aboutsummaryrefslogtreecommitdiff
path: root/src/xmpp_stream_out.erl
diff options
context:
space:
mode:
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>2018-05-26 09:06:24 +0300
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>2018-05-26 09:06:24 +0300
commitfc77051b68a923d958f876193fd5745af34208db (patch)
treebf7a3c2b1aee6c10ced2e1086bfc24a7731a631d /src/xmpp_stream_out.erl
parentmod_muc_sql: Fix export to SQL (diff)
Don't call Mod:function() in xmpp_stream callbacks
If a callback function is not defined by the `Mod` then a call to code_server process is performed. Under heavy load this may cause code_server to get overloaded. We now avoid this.
Diffstat (limited to 'src/xmpp_stream_out.erl')
-rw-r--r--src/xmpp_stream_out.erl197
1 files changed, 111 insertions, 86 deletions
diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl
index b2367a09b..da0a14e22 100644
--- a/src/xmpp_stream_out.erl
+++ b/src/xmpp_stream_out.erl
@@ -266,9 +266,9 @@ init([Mod, _SockMod, From, To, Opts]) ->
end.
-spec handle_call(term(), term(), state()) -> noreply().
-handle_call(Call, From, #{mod := Mod} = State) ->
- noreply(try Mod:handle_call(Call, From, State)
- catch _:undef -> State
+handle_call(Call, From, State) ->
+ noreply(try callback(handle_call, Call, From, State)
+ catch _:{?MODULE, undef} -> State
end).
-spec handle_cast(term(), state()) -> noreply().
@@ -311,9 +311,9 @@ handle_cast({close, Reason}, State) ->
true -> State1;
false -> process_stream_end({socket, Reason}, State)
end);
-handle_cast(Cast, #{mod := Mod} = State) ->
- noreply(try Mod:handle_cast(Cast, State)
- catch _:undef -> State
+handle_cast(Cast, State) ->
+ noreply(try callback(handle_cast, Cast, State)
+ catch _:{?MODULE, undef} -> State
end).
-spec handle_info(term(), state()) -> noreply().
@@ -348,20 +348,20 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
send_pkt(State1, Err)
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS, mod := Mod, codec_options := Opts} = State) ->
+ #{xmlns := NS, codec_options := Opts} = State) ->
noreply(
try xmpp:decode(El, NS, Opts) of
Pkt ->
- State1 = try Mod:handle_recv(El, Pkt, State)
- catch _:undef -> State
+ State1 = try callback(handle_recv, El, Pkt, State)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_element(Pkt, State1)
end
catch _:{xmpp_codec, Why} ->
- State1 = try Mod:handle_recv(El, {error, Why}, State)
- catch _:undef -> State
+ State1 = try callback(handle_recv, El, {error, Why}, State)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@@ -369,21 +369,21 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
- #{mod := Mod} = State) ->
- noreply(try Mod:handle_cdata(Data, State)
- catch _:undef -> State
+ State) ->
+ noreply(try callback(handle_cdata, Data, State)
+ catch _:{?MODULE, undef} -> State
end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_end({socket, closed}, State));
-handle_info(timeout, #{mod := Mod, lang := Lang} = State) ->
+handle_info(timeout, #{lang := Lang} = State) ->
Disconnected = is_disconnected(State),
- noreply(try Mod:handle_timeout(State)
- catch _:undef when not Disconnected ->
+ noreply(try callback(handle_timeout, State)
+ catch _:{?MODULE, undef} when not Disconnected ->
Txt = <<"Idle connection">>,
send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
- _:undef ->
+ _:{?MODULE, undef} ->
stop(State)
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
@@ -404,26 +404,26 @@ handle_info({tcp_closed, _}, State) ->
handle_info({'$gen_event', closed}, State);
handle_info({tcp_error, _, Reason}, State) ->
noreply(process_stream_end({socket, Reason}, State));
-handle_info(Info, #{mod := Mod} = State) ->
- noreply(try Mod:handle_info(Info, State)
- catch _:undef -> State
+handle_info(Info, State) ->
+ noreply(try callback(handle_info, Info, State)
+ catch _:{?MODULE, undef} -> State
end).
-spec terminate(term(), state()) -> any().
-terminate(Reason, #{mod := Mod} = State) ->
+terminate(Reason, State) ->
case get(already_terminated) of
true ->
State;
_ ->
put(already_terminated, true),
- try Mod:terminate(Reason, State)
- catch _:undef -> ok
+ try callback(terminate, Reason, State)
+ catch _:{?MODULE, undef} -> ok
end,
send_trailer(State)
end.
-code_change(OldVsn, #{mod := Mod} = State, Extra) ->
- Mod:code_change(OldVsn, State, Extra).
+code_change(OldVsn, State, Extra) ->
+ callback(code_change, OldVsn, State, Extra).
%%%===================================================================
%%% Internal functions
@@ -458,11 +458,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
-process_stream_end(Reason, #{mod := Mod} = State) ->
+process_stream_end(Reason, State) ->
State1 = State#{stream_timeout => infinity,
stream_state => disconnected},
- try Mod:handle_stream_end(Reason, State1)
- catch _:undef -> stop(State1)
+ try callback(handle_stream_end, Reason, State1)
+ catch _:{?MODULE, undef} -> stop(State1)
end.
-spec process_stream(stream_start(), state()) -> state().
@@ -475,10 +475,10 @@ process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
send_pkt(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang, id = ID,
version = Version} = StreamStart,
- #{mod := Mod} = State) ->
+ State) ->
State1 = State#{stream_remote_id => ID, lang => Lang},
- State2 = try Mod:handle_stream_start(StreamStart, State1)
- catch _:undef -> State1
+ State2 = try callback(handle_stream_start, StreamStart, State1)
+ catch _:{?MODULE, undef} -> State1
end,
case is_disconnected(State2) of
true -> State2;
@@ -522,16 +522,16 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
-spec process_features(stream_features(), state()) -> state().
process_features(StreamFeatures,
- #{stream_authenticated := true, mod := Mod} = State) ->
- State1 = try Mod:handle_authenticated_features(StreamFeatures, State)
- catch _:undef -> State
+ #{stream_authenticated := true} = State) ->
+ State1 = try callback(handle_authenticated_features, StreamFeatures, State)
+ catch _:{?MODULE, undef} -> State
end,
process_stream_established(State1);
process_features(StreamFeatures,
#{stream_encrypted := Encrypted,
- mod := Mod, lang := Lang} = State) ->
- State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State)
- catch _:undef -> State
+ lang := Lang} = State) ->
+ State1 = try callback(handle_unauthenticated_features, StreamFeatures, State)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@@ -582,12 +582,12 @@ process_features(StreamFeatures,
process_stream_established(#{stream_state := StateName} = State)
when StateName == disconnected; StateName == established ->
State;
-process_stream_established(#{mod := Mod} = State) ->
+process_stream_established(State) ->
State1 = State#{stream_authenticated := true,
stream_state => established,
stream_timeout => infinity},
- try Mod:handle_stream_established(State1)
- catch _:undef -> State1
+ try callback(handle_stream_established, State1)
+ catch _:{?MODULE, undef} -> State1
end.
-spec process_sasl_mechanisms([binary()], state()) -> state().
@@ -620,7 +620,7 @@ process_starttls(#{socket := Socket} = State) ->
-spec process_stream_downgrade(stream_start(), state()) -> state().
process_stream_downgrade(StreamStart,
- #{mod := Mod, lang := Lang,
+ #{lang := Lang,
stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
if not Encrypted and TLSRequired ->
@@ -628,18 +628,17 @@ process_stream_downgrade(StreamStart,
send_pkt(State, xmpp:serr_policy_violation(Txt, Lang));
true ->
State1 = State#{stream_state => downgraded},
- try Mod:handle_stream_downgraded(StreamStart, State1)
- catch _:undef ->
+ try callback(handle_stream_downgraded, StreamStart, State1)
+ catch _:{?MODULE, undef} ->
send_pkt(State1, xmpp:serr_unsupported_version())
end
end.
-spec process_cert_verification(state()) -> state().
process_cert_verification(#{stream_encrypted := true,
- stream_verified := false,
- mod := Mod} = State) ->
- case try Mod:tls_verify(State)
- catch _:undef -> true
+ stream_verified := false} = State) ->
+ case try callback(tls_verify, State)
+ catch _:{?MODULE, undef} -> true
end of
true ->
case xmpp_stream_pkix:authenticate(State) of
@@ -655,8 +654,7 @@ process_cert_verification(State) ->
State.
-spec process_sasl_success(state()) -> state().
-process_sasl_success(#{mod := Mod,
- socket := Socket} = State) ->
+process_sasl_success(#{socket := Socket} = State) ->
Socket1 = xmpp_socket:reset_stream(Socket),
State0 = State#{socket => Socket1},
State1 = State0#{stream_id => new_id(),
@@ -667,8 +665,8 @@ process_sasl_success(#{mod := Mod,
case is_disconnected(State2) of
true -> State2;
false ->
- try Mod:handle_auth_success(<<"EXTERNAL">>, State2)
- catch _:undef -> State2
+ try callback(handle_auth_success, <<"EXTERNAL">>, State2)
+ catch _:{?MODULE, undef} -> State2
end
end.
@@ -677,27 +675,27 @@ process_sasl_failure(#sasl_failure{} = Failure, State) ->
Reason = format("Peer responded with error: ~s",
[format_sasl_failure(Failure)]),
process_sasl_failure(Reason, State);
-process_sasl_failure(Reason, #{mod := Mod} = State) ->
- try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State)
- catch _:undef -> process_stream_end({auth, Reason}, State)
+process_sasl_failure(Reason, State) ->
+ try callback(handle_auth_failure, <<"EXTERNAL">>, {auth, Reason}, State)
+ catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State)
end.
-spec process_packet(xmpp_element(), state()) -> state().
-process_packet(Pkt, #{mod := Mod} = State) ->
- try Mod:handle_packet(Pkt, State)
- catch _:undef -> State
+process_packet(Pkt, State) ->
+ try callback(handle_packet, Pkt, State)
+ catch _:{?MODULE, undef} -> State
end.
-spec is_starttls_required(state()) -> boolean().
-is_starttls_required(#{mod := Mod} = State) ->
- try Mod:tls_required(State)
- catch _:undef -> false
+is_starttls_required(State) ->
+ try callback(tls_required, State)
+ catch _:{?MODULE, undef} -> false
end.
-spec is_starttls_available(state()) -> boolean().
-is_starttls_available(#{mod := Mod} = State) ->
- try Mod:tls_enabled(State)
- catch _:undef -> true
+is_starttls_available(State) ->
+ try callback(tls_enabled, State)
+ catch _:{?MODULE, undef} -> true
end.
-spec send_header(state()) -> state().
@@ -731,10 +729,10 @@ send_header(#{remote_server := RemoteServer,
end.
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
-send_pkt(#{mod := Mod} = State, Pkt) ->
+send_pkt(State, Pkt) ->
Result = socket_send(State, Pkt),
- State1 = try Mod:handle_send(Pkt, Result, State)
- catch _:undef -> State
+ State1 = try callback(handle_send, Pkt, Result, State)
+ catch _:{?MODULE, undef} -> State
end,
case Result of
_ when is_record(Pkt, stream_error) ->
@@ -795,10 +793,10 @@ close_socket(State) ->
stream_state => disconnected}.
-spec starttls(term(), state()) -> {ok, term()} | {error, tls_error_reason()}.
-starttls(Socket, #{mod := Mod, xmlns := NS,
+starttls(Socket, #{xmlns := NS,
remote_server := RemoteServer} = State) ->
- TLSOpts = try Mod:tls_options(State)
- catch _:undef -> []
+ TLSOpts = try callback(tls_options, State)
+ catch _:{?MODULE, undef} -> []
end,
SNI = idna_to_ascii(RemoteServer),
ALPN = case NS of
@@ -1077,32 +1075,59 @@ get_addr_type({_, _, _, _}) -> inet;
get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
-spec get_dns_timeout(state()) -> timeout().
-get_dns_timeout(#{mod := Mod} = State) ->
- try Mod:dns_timeout(State)
- catch _:undef -> timer:seconds(10)
+get_dns_timeout(State) ->
+ try callback(dns_timeout, State)
+ catch _:{?MODULE, undef} -> timer:seconds(10)
end.
-spec get_dns_retries(state()) -> non_neg_integer().
-get_dns_retries(#{mod := Mod} = State) ->
- try Mod:dns_retries(State)
- catch _:undef -> 2
+get_dns_retries(State) ->
+ try callback(dns_retries, State)
+ catch _:{?MODULE, undef} -> 2
end.
-spec get_default_port(state()) -> inet:port_number().
-get_default_port(#{mod := Mod, xmlns := NS} = State) ->
- try Mod:default_port(State)
- catch _:undef when NS == ?NS_SERVER -> 5269;
- _:undef when NS == ?NS_CLIENT -> 5222
+get_default_port(#{xmlns := NS} = State) ->
+ try callback(default_port, State)
+ catch _:{?MODULE, undef} when NS == ?NS_SERVER -> 5269;
+ _:{?MODULE, undef} when NS == ?NS_CLIENT -> 5222
end.
-spec get_address_families(state()) -> [inet:address_family()].
-get_address_families(#{mod := Mod} = State) ->
- try Mod:address_families(State)
- catch _:undef -> [inet, inet6]
+get_address_families(State) ->
+ try callback(address_families, State)
+ catch _:{?MODULE, undef} -> [inet, inet6]
end.
-spec get_connect_timeout(state()) -> timeout().
-get_connect_timeout(#{mod := Mod} = State) ->
- try Mod:connect_timeout(State)
- catch _:undef -> timer:seconds(10)
+get_connect_timeout(State) ->
+ try callback(connect_timeout, State)
+ catch _:{?MODULE, undef} -> timer:seconds(10)
+ end.
+
+%%%===================================================================
+%%% Callbacks
+%%%===================================================================
+callback(F, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 1) of
+ true -> Mod:F(State);
+ false -> erlang:error({?MODULE, undef})
+ end.
+
+callback(F, Arg1, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 2) of
+ true -> Mod:F(Arg1, State);
+ false -> erlang:error({?MODULE, undef})
+ end.
+
+callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
+ %% code_change/3 callback is a special snowflake
+ case erlang:function_exported(Mod, code_change, 3) of
+ true -> Mod:code_change(OldVsn, State, Extra);
+ false -> {ok, State}
+ end;
+callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 3) of
+ true -> Mod:F(Arg1, Arg2, State);
+ false -> erlang:error({?MODULE, undef})
end.