diff options
author | Evgeniy Khramtsov <ekhramtsov@process-one.net> | 2018-05-26 09:06:24 +0300 |
---|---|---|
committer | Evgeniy Khramtsov <ekhramtsov@process-one.net> | 2018-05-26 09:06:24 +0300 |
commit | fc77051b68a923d958f876193fd5745af34208db (patch) | |
tree | bf7a3c2b1aee6c10ced2e1086bfc24a7731a631d /src/xmpp_stream_out.erl | |
parent | mod_muc_sql: Fix export to SQL (diff) |
Don't call Mod:function() in xmpp_stream callbacks
If a callback function is not defined by the `Mod` then
a call to code_server process is performed. Under heavy load
this may cause code_server to get overloaded. We now avoid this.
Diffstat (limited to 'src/xmpp_stream_out.erl')
-rw-r--r-- | src/xmpp_stream_out.erl | 197 |
1 files changed, 111 insertions, 86 deletions
diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl index b2367a09b..da0a14e22 100644 --- a/src/xmpp_stream_out.erl +++ b/src/xmpp_stream_out.erl @@ -266,9 +266,9 @@ init([Mod, _SockMod, From, To, Opts]) -> end. -spec handle_call(term(), term(), state()) -> noreply(). -handle_call(Call, From, #{mod := Mod} = State) -> - noreply(try Mod:handle_call(Call, From, State) - catch _:undef -> State +handle_call(Call, From, State) -> + noreply(try callback(handle_call, Call, From, State) + catch _:{?MODULE, undef} -> State end). -spec handle_cast(term(), state()) -> noreply(). @@ -311,9 +311,9 @@ handle_cast({close, Reason}, State) -> true -> State1; false -> process_stream_end({socket, Reason}, State) end); -handle_cast(Cast, #{mod := Mod} = State) -> - noreply(try Mod:handle_cast(Cast, State) - catch _:undef -> State +handle_cast(Cast, State) -> + noreply(try callback(handle_cast, Cast, State) + catch _:{?MODULE, undef} -> State end). -spec handle_info(term(), state()) -> noreply(). @@ -348,20 +348,20 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> send_pkt(State1, Err) end); handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS, mod := Mod, codec_options := Opts} = State) -> + #{xmlns := NS, codec_options := Opts} = State) -> noreply( try xmpp:decode(El, NS, Opts) of Pkt -> - State1 = try Mod:handle_recv(El, Pkt, State) - catch _:undef -> State + State1 = try callback(handle_recv, El, Pkt, State) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; false -> process_element(Pkt, State1) end catch _:{xmpp_codec, Why} -> - State1 = try Mod:handle_recv(El, {error, Why}, State) - catch _:undef -> State + State1 = try callback(handle_recv, El, {error, Why}, State) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; @@ -369,21 +369,21 @@ handle_info({'$gen_event', {xmlstreamelement, El}}, end end); handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, - #{mod := Mod} = State) -> - noreply(try Mod:handle_cdata(Data, State) - catch _:undef -> State + State) -> + noreply(try callback(handle_cdata, Data, State) + catch _:{?MODULE, undef} -> State end); handle_info({'$gen_event', {xmlstreamend, _}}, State) -> noreply(process_stream_end({stream, reset}, State)); handle_info({'$gen_event', closed}, State) -> noreply(process_stream_end({socket, closed}, State)); -handle_info(timeout, #{mod := Mod, lang := Lang} = State) -> +handle_info(timeout, #{lang := Lang} = State) -> Disconnected = is_disconnected(State), - noreply(try Mod:handle_timeout(State) - catch _:undef when not Disconnected -> + noreply(try callback(handle_timeout, State) + catch _:{?MODULE, undef} when not Disconnected -> Txt = <<"Idle connection">>, send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang)); - _:undef -> + _:{?MODULE, undef} -> stop(State) end); handle_info({'DOWN', MRef, _Type, _Object, _Info}, @@ -404,26 +404,26 @@ handle_info({tcp_closed, _}, State) -> handle_info({'$gen_event', closed}, State); handle_info({tcp_error, _, Reason}, State) -> noreply(process_stream_end({socket, Reason}, State)); -handle_info(Info, #{mod := Mod} = State) -> - noreply(try Mod:handle_info(Info, State) - catch _:undef -> State +handle_info(Info, State) -> + noreply(try callback(handle_info, Info, State) + catch _:{?MODULE, undef} -> State end). -spec terminate(term(), state()) -> any(). -terminate(Reason, #{mod := Mod} = State) -> +terminate(Reason, State) -> case get(already_terminated) of true -> State; _ -> put(already_terminated, true), - try Mod:terminate(Reason, State) - catch _:undef -> ok + try callback(terminate, Reason, State) + catch _:{?MODULE, undef} -> ok end, send_trailer(State) end. -code_change(OldVsn, #{mod := Mod} = State, Extra) -> - Mod:code_change(OldVsn, State, Extra). +code_change(OldVsn, State, Extra) -> + callback(code_change, OldVsn, State, Extra). %%%=================================================================== %%% Internal functions @@ -458,11 +458,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> -spec process_stream_end(stop_reason(), state()) -> state(). process_stream_end(_, #{stream_state := disconnected} = State) -> State; -process_stream_end(Reason, #{mod := Mod} = State) -> +process_stream_end(Reason, State) -> State1 = State#{stream_timeout => infinity, stream_state => disconnected}, - try Mod:handle_stream_end(Reason, State1) - catch _:undef -> stop(State1) + try callback(handle_stream_end, Reason, State1) + catch _:{?MODULE, undef} -> stop(State1) end. -spec process_stream(stream_start(), state()) -> state(). @@ -475,10 +475,10 @@ process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> send_pkt(State, xmpp:serr_unsupported_version()); process_stream(#stream_start{lang = Lang, id = ID, version = Version} = StreamStart, - #{mod := Mod} = State) -> + State) -> State1 = State#{stream_remote_id => ID, lang => Lang}, - State2 = try Mod:handle_stream_start(StreamStart, State1) - catch _:undef -> State1 + State2 = try callback(handle_stream_start, StreamStart, State1) + catch _:{?MODULE, undef} -> State1 end, case is_disconnected(State2) of true -> State2; @@ -522,16 +522,16 @@ process_element(Pkt, #{stream_state := StateName} = State) -> -spec process_features(stream_features(), state()) -> state(). process_features(StreamFeatures, - #{stream_authenticated := true, mod := Mod} = State) -> - State1 = try Mod:handle_authenticated_features(StreamFeatures, State) - catch _:undef -> State + #{stream_authenticated := true} = State) -> + State1 = try callback(handle_authenticated_features, StreamFeatures, State) + catch _:{?MODULE, undef} -> State end, process_stream_established(State1); process_features(StreamFeatures, #{stream_encrypted := Encrypted, - mod := Mod, lang := Lang} = State) -> - State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State) - catch _:undef -> State + lang := Lang} = State) -> + State1 = try callback(handle_unauthenticated_features, StreamFeatures, State) + catch _:{?MODULE, undef} -> State end, case is_disconnected(State1) of true -> State1; @@ -582,12 +582,12 @@ process_features(StreamFeatures, process_stream_established(#{stream_state := StateName} = State) when StateName == disconnected; StateName == established -> State; -process_stream_established(#{mod := Mod} = State) -> +process_stream_established(State) -> State1 = State#{stream_authenticated := true, stream_state => established, stream_timeout => infinity}, - try Mod:handle_stream_established(State1) - catch _:undef -> State1 + try callback(handle_stream_established, State1) + catch _:{?MODULE, undef} -> State1 end. -spec process_sasl_mechanisms([binary()], state()) -> state(). @@ -620,7 +620,7 @@ process_starttls(#{socket := Socket} = State) -> -spec process_stream_downgrade(stream_start(), state()) -> state(). process_stream_downgrade(StreamStart, - #{mod := Mod, lang := Lang, + #{lang := Lang, stream_encrypted := Encrypted} = State) -> TLSRequired = is_starttls_required(State), if not Encrypted and TLSRequired -> @@ -628,18 +628,17 @@ process_stream_downgrade(StreamStart, send_pkt(State, xmpp:serr_policy_violation(Txt, Lang)); true -> State1 = State#{stream_state => downgraded}, - try Mod:handle_stream_downgraded(StreamStart, State1) - catch _:undef -> + try callback(handle_stream_downgraded, StreamStart, State1) + catch _:{?MODULE, undef} -> send_pkt(State1, xmpp:serr_unsupported_version()) end end. -spec process_cert_verification(state()) -> state(). process_cert_verification(#{stream_encrypted := true, - stream_verified := false, - mod := Mod} = State) -> - case try Mod:tls_verify(State) - catch _:undef -> true + stream_verified := false} = State) -> + case try callback(tls_verify, State) + catch _:{?MODULE, undef} -> true end of true -> case xmpp_stream_pkix:authenticate(State) of @@ -655,8 +654,7 @@ process_cert_verification(State) -> State. -spec process_sasl_success(state()) -> state(). -process_sasl_success(#{mod := Mod, - socket := Socket} = State) -> +process_sasl_success(#{socket := Socket} = State) -> Socket1 = xmpp_socket:reset_stream(Socket), State0 = State#{socket => Socket1}, State1 = State0#{stream_id => new_id(), @@ -667,8 +665,8 @@ process_sasl_success(#{mod := Mod, case is_disconnected(State2) of true -> State2; false -> - try Mod:handle_auth_success(<<"EXTERNAL">>, State2) - catch _:undef -> State2 + try callback(handle_auth_success, <<"EXTERNAL">>, State2) + catch _:{?MODULE, undef} -> State2 end end. @@ -677,27 +675,27 @@ process_sasl_failure(#sasl_failure{} = Failure, State) -> Reason = format("Peer responded with error: ~s", [format_sasl_failure(Failure)]), process_sasl_failure(Reason, State); -process_sasl_failure(Reason, #{mod := Mod} = State) -> - try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State) - catch _:undef -> process_stream_end({auth, Reason}, State) +process_sasl_failure(Reason, State) -> + try callback(handle_auth_failure, <<"EXTERNAL">>, {auth, Reason}, State) + catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State) end. -spec process_packet(xmpp_element(), state()) -> state(). -process_packet(Pkt, #{mod := Mod} = State) -> - try Mod:handle_packet(Pkt, State) - catch _:undef -> State +process_packet(Pkt, State) -> + try callback(handle_packet, Pkt, State) + catch _:{?MODULE, undef} -> State end. -spec is_starttls_required(state()) -> boolean(). -is_starttls_required(#{mod := Mod} = State) -> - try Mod:tls_required(State) - catch _:undef -> false +is_starttls_required(State) -> + try callback(tls_required, State) + catch _:{?MODULE, undef} -> false end. -spec is_starttls_available(state()) -> boolean(). -is_starttls_available(#{mod := Mod} = State) -> - try Mod:tls_enabled(State) - catch _:undef -> true +is_starttls_available(State) -> + try callback(tls_enabled, State) + catch _:{?MODULE, undef} -> true end. -spec send_header(state()) -> state(). @@ -731,10 +729,10 @@ send_header(#{remote_server := RemoteServer, end. -spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). -send_pkt(#{mod := Mod} = State, Pkt) -> +send_pkt(State, Pkt) -> Result = socket_send(State, Pkt), - State1 = try Mod:handle_send(Pkt, Result, State) - catch _:undef -> State + State1 = try callback(handle_send, Pkt, Result, State) + catch _:{?MODULE, undef} -> State end, case Result of _ when is_record(Pkt, stream_error) -> @@ -795,10 +793,10 @@ close_socket(State) -> stream_state => disconnected}. -spec starttls(term(), state()) -> {ok, term()} | {error, tls_error_reason()}. -starttls(Socket, #{mod := Mod, xmlns := NS, +starttls(Socket, #{xmlns := NS, remote_server := RemoteServer} = State) -> - TLSOpts = try Mod:tls_options(State) - catch _:undef -> [] + TLSOpts = try callback(tls_options, State) + catch _:{?MODULE, undef} -> [] end, SNI = idna_to_ascii(RemoteServer), ALPN = case NS of @@ -1077,32 +1075,59 @@ get_addr_type({_, _, _, _}) -> inet; get_addr_type({_, _, _, _, _, _, _, _}) -> inet6. -spec get_dns_timeout(state()) -> timeout(). -get_dns_timeout(#{mod := Mod} = State) -> - try Mod:dns_timeout(State) - catch _:undef -> timer:seconds(10) +get_dns_timeout(State) -> + try callback(dns_timeout, State) + catch _:{?MODULE, undef} -> timer:seconds(10) end. -spec get_dns_retries(state()) -> non_neg_integer(). -get_dns_retries(#{mod := Mod} = State) -> - try Mod:dns_retries(State) - catch _:undef -> 2 +get_dns_retries(State) -> + try callback(dns_retries, State) + catch _:{?MODULE, undef} -> 2 end. -spec get_default_port(state()) -> inet:port_number(). -get_default_port(#{mod := Mod, xmlns := NS} = State) -> - try Mod:default_port(State) - catch _:undef when NS == ?NS_SERVER -> 5269; - _:undef when NS == ?NS_CLIENT -> 5222 +get_default_port(#{xmlns := NS} = State) -> + try callback(default_port, State) + catch _:{?MODULE, undef} when NS == ?NS_SERVER -> 5269; + _:{?MODULE, undef} when NS == ?NS_CLIENT -> 5222 end. -spec get_address_families(state()) -> [inet:address_family()]. -get_address_families(#{mod := Mod} = State) -> - try Mod:address_families(State) - catch _:undef -> [inet, inet6] +get_address_families(State) -> + try callback(address_families, State) + catch _:{?MODULE, undef} -> [inet, inet6] end. -spec get_connect_timeout(state()) -> timeout(). -get_connect_timeout(#{mod := Mod} = State) -> - try Mod:connect_timeout(State) - catch _:undef -> timer:seconds(10) +get_connect_timeout(State) -> + try callback(connect_timeout, State) + catch _:{?MODULE, undef} -> timer:seconds(10) + end. + +%%%=================================================================== +%%% Callbacks +%%%=================================================================== +callback(F, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 1) of + true -> Mod:F(State); + false -> erlang:error({?MODULE, undef}) + end. + +callback(F, Arg1, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 2) of + true -> Mod:F(Arg1, State); + false -> erlang:error({?MODULE, undef}) + end. + +callback(code_change, OldVsn, #{mod := Mod} = State, Extra) -> + %% code_change/3 callback is a special snowflake + case erlang:function_exported(Mod, code_change, 3) of + true -> Mod:code_change(OldVsn, State, Extra); + false -> {ok, State} + end; +callback(F, Arg1, Arg2, #{mod := Mod} = State) -> + case erlang:function_exported(Mod, F, 3) of + true -> Mod:F(Arg1, Arg2, State); + false -> erlang:error({?MODULE, undef}) end. |