aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/xmpp_stream_in.erl240
-rw-r--r--src/xmpp_stream_out.erl197
2 files changed, 247 insertions, 190 deletions
diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl
index 55fa3a4bf..675425bd0 100644
--- a/src/xmpp_stream_in.erl
+++ b/src/xmpp_stream_in.erl
@@ -210,14 +210,14 @@ format_error(Err) ->
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
-init([Module, {_SockMod, Socket}, Opts]) ->
+init([Mod, {_SockMod, Socket}, Opts]) ->
Encrypted = proplists:get_bool(tls, Opts),
SocketMonitor = xmpp_socket:monitor(Socket),
case xmpp_socket:peername(Socket) of
{ok, IP} ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State = #{owner => self(),
- mod => Module,
+ mod => Mod,
socket => Socket,
socket_monitor => SocketMonitor,
stream_timeout => {timer:seconds(30), Time},
@@ -238,15 +238,15 @@ init([Module, {_SockMod, Socket}, Opts]) ->
resource => <<"">>,
lserver => <<"">>,
ip => IP},
- case try Module:init([State, Opts])
+ case try Mod:init([State, Opts])
catch _:undef -> {ok, State}
end of
{ok, State1} when not Encrypted ->
{_, State2, Timeout} = noreply(State1),
{ok, State2, Timeout};
{ok, State1} when Encrypted ->
- TLSOpts = try Module:tls_options(State1)
- catch _:undef -> []
+ TLSOpts = try callback(tls_options, State1)
+ catch _:{?MODULE, undef} -> []
end,
case xmpp_socket:starttls(Socket, TLSOpts) of
{ok, TLSSocket} ->
@@ -276,14 +276,14 @@ handle_cast({close, Reason}, State) ->
true -> State1;
false -> process_stream_end({socket, Reason}, State)
end);
-handle_cast(Cast, #{mod := Mod} = State) ->
- noreply(try Mod:handle_cast(Cast, State)
- catch _:undef -> State
+handle_cast(Cast, State) ->
+ noreply(try callback(handle_cast, Cast, State)
+ catch _:{?MODULE, undef} -> State
end).
-handle_call(Call, From, #{mod := Mod} = State) ->
- noreply(try Mod:handle_call(Call, From, State)
- catch _:undef -> State
+handle_call(Call, From, State) ->
+ noreply(try callback(handle_call, Call, From, State)
+ catch _:{?MODULE, undef} -> State
end).
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
@@ -343,20 +343,20 @@ handle_info({'$gen_event', El}, #{stream_state := wait_for_stream} = State) ->
false -> send_pkt(State1, xmpp:serr_invalid_xml())
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS, mod := Mod, codec_options := Opts} = State) ->
+ #{xmlns := NS, codec_options := Opts} = State) ->
noreply(
try xmpp:decode(El, NS, Opts) of
Pkt ->
- State1 = try Mod:handle_recv(El, Pkt, State)
- catch _:undef -> State
+ State1 = try callback(handle_recv, El, Pkt, State)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_element(Pkt, State1)
end
catch _:{xmpp_codec, Why} ->
- State1 = try Mod:handle_recv(El, {error, Why}, State)
- catch _:undef -> State
+ State1 = try callback(handle_recv, El, {error, Why}, State)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@@ -364,17 +364,17 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
- #{mod := Mod} = State) ->
- noreply(try Mod:handle_cdata(Data, State)
- catch _:undef -> State
+ State) ->
+ noreply(try callback(handle_cdata, Data, State)
+ catch _:{?MODULE, undef} -> State
end);
-handle_info(timeout, #{mod := Mod, lang := Lang} = State) ->
+handle_info(timeout, #{lang := Lang} = State) ->
Disconnected = is_disconnected(State),
- noreply(try Mod:handle_timeout(State)
- catch _:undef when not Disconnected ->
+ noreply(try callback(handle_timeout, State)
+ catch _:{?MODULE, undef} when not Disconnected ->
Txt = <<"Idle connection">>,
send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
- _:undef ->
+ _:{?MODULE, undef} ->
stop(State)
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
@@ -395,25 +395,25 @@ handle_info({tcp_closed, _}, State) ->
handle_info({'$gen_event', closed}, State);
handle_info({tcp_error, _, Reason}, State) ->
noreply(process_stream_end({socket, Reason}, State));
-handle_info(Info, #{mod := Mod} = State) ->
- noreply(try Mod:handle_info(Info, State)
- catch _:undef -> State
+handle_info(Info, State) ->
+ noreply(try callback(handle_info, Info, State)
+ catch _:{?MODULE, undef} -> State
end).
-terminate(Reason, #{mod := Mod} = State) ->
+terminate(Reason, State) ->
case get(already_terminated) of
true ->
State;
_ ->
put(already_terminated, true),
- try Mod:terminate(Reason, State)
- catch _:undef -> ok
+ try callback(terminate, Reason, State)
+ catch _:{?MODULE, undef} -> ok
end,
send_trailer(State)
end.
-code_change(OldVsn, #{mod := Mod} = State, Extra) ->
- Mod:code_change(OldVsn, State, Extra).
+code_change(OldVsn, State, Extra) ->
+ callback(code_change, OldVsn, State, Extra).
%%%===================================================================
%%% Internal functions
@@ -464,11 +464,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
-process_stream_end(Reason, #{mod := Mod} = State) ->
+process_stream_end(Reason, State) ->
State1 = State#{stream_timeout => infinity,
stream_state => disconnected},
- try Mod:handle_stream_end(Reason, State1)
- catch _:undef -> stop(State1)
+ try callback(handle_stream_end, Reason, State1)
+ catch _:{?MODULE, undef} -> stop(State1)
end.
-spec process_stream(stream_start(), state()) -> state().
@@ -503,17 +503,17 @@ process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
Txt = <<"Improper 'to' attribute">>,
send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
- #{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
+ #{xmlns := ?NS_COMPONENT} = State) ->
State1 = State#{remote_server => RemoteServer,
stream_state => wait_for_handshake},
- try Mod:handle_stream_start(StreamStart, State1)
- catch _:undef -> State1
+ try callback(handle_stream_start, StreamStart, State1)
+ catch _:{?MODULE, undef} -> State1
end;
process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
from = From} = StreamStart,
#{stream_authenticated := Authenticated,
stream_restarted := StreamWasRestarted,
- mod := Mod, xmlns := NS, resource := Resource,
+ xmlns := NS, resource := Resource,
stream_encrypted := Encrypted} = State) ->
State1 = if not StreamWasRestarted ->
State#{server => Server, lserver => LServer};
@@ -526,8 +526,8 @@ process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
_ ->
State1
end,
- State3 = try Mod:handle_stream_start(StreamStart, State2)
- catch _:undef -> State2
+ State3 = try callback(handle_stream_start, StreamStart, State2)
+ catch _:{?MODULE, undef} -> State2
end,
case is_disconnected(State3) of
true -> State3;
@@ -604,21 +604,21 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
end.
-spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
-process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
+process_unauthenticated_packet(Pkt, State) ->
NewPkt = set_lang(Pkt, State),
- try Mod:handle_unauthenticated_packet(NewPkt, State)
- catch _:undef ->
+ try callback(handle_unauthenticated_packet, NewPkt, State)
+ catch _:{?MODULE, undef} ->
Err = xmpp:serr_not_authorized(),
send(State, Err)
end.
-spec process_authenticated_packet(xmpp_element(), state()) -> state().
-process_authenticated_packet(Pkt, #{mod := Mod} = State) ->
+process_authenticated_packet(Pkt, State) ->
Pkt1 = set_lang(Pkt, State),
case set_from_to(Pkt1, State) of
{ok, Pkt2} ->
- try Mod:handle_authenticated_packet(Pkt2, State)
- catch _:undef ->
+ try callback(handle_authenticated_packet, Pkt2, State)
+ catch _:{?MODULE, undef} ->
Err = xmpp:err_service_unavailable(),
send_error(State, Pkt, Err)
end;
@@ -628,10 +628,10 @@ process_authenticated_packet(Pkt, #{mod := Mod} = State) ->
-spec process_bind(xmpp_element(), state()) -> state().
process_bind(#iq{type = set, sub_els = [_]} = Pkt,
- #{xmlns := ?NS_CLIENT, mod := Mod, lang := MyLang} = State) ->
+ #{xmlns := ?NS_CLIENT, lang := MyLang} = State) ->
try xmpp:try_subtag(Pkt, #bind{}) of
#bind{resource = R} ->
- case Mod:bind(R, State) of
+ case callback(bind, R, State) of
{ok, #{user := U, server := S, resource := NewR} = State1}
when NewR /= <<"">> ->
Reply = #bind{jid = jid:make(U, S, NewR)},
@@ -641,8 +641,8 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
send_error(State1, Pkt, Err)
end;
_ ->
- try Mod:handle_unbinded_packet(Pkt, State)
- catch _:undef ->
+ try callback(handle_unbinded_packet, Pkt, State)
+ catch _:{?MODULE, undef} ->
Err = xmpp:err_not_authorized(),
send_error(State, Pkt, Err)
end
@@ -652,19 +652,19 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
Err = xmpp:err_bad_request(Txt, Lang),
send_error(State, Pkt, Err)
end;
-process_bind(Pkt, #{mod := Mod} = State) ->
- try Mod:handle_unbinded_packet(Pkt, State)
- catch _:undef ->
+process_bind(Pkt, State) ->
+ try callback(handle_unbinded_packet, Pkt, State)
+ catch _:{?MODULE, undef} ->
Err = xmpp:err_not_authorized(),
send_error(State, Pkt, Err)
end.
-spec process_handshake(handshake(), state()) -> state().
process_handshake(#handshake{data = Digest},
- #{mod := Mod, stream_id := StreamID,
+ #{stream_id := StreamID,
remote_server := RemoteServer} = State) ->
- GetPW = try Mod:get_password_fun(State)
- catch _:undef -> fun(_) -> {false, undefined} end
+ GetPW = try callback(get_password_fun, State)
+ catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end
end,
AuthRes = case GetPW(<<"">>) of
{false, _} ->
@@ -674,9 +674,9 @@ process_handshake(#handshake{data = Digest},
end,
case AuthRes of
true ->
- State1 = try Mod:handle_auth_success(
+ State1 = try callback(handle_auth_success,
RemoteServer, <<"handshake">>, undefined, State)
- catch _:undef -> State
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@@ -685,9 +685,9 @@ process_handshake(#handshake{data = Digest},
process_stream_established(State2)
end;
false ->
- State1 = try Mod:handle_auth_failure(
+ State1 = try callback(handle_auth_failure,
RemoteServer, <<"handshake">>, <<"not authorized">>, State)
- catch _:undef -> State
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@@ -699,12 +699,12 @@ process_handshake(#handshake{data = Digest},
process_stream_established(#{stream_state := StateName} = State)
when StateName == disconnected; StateName == established ->
State;
-process_stream_established(#{mod := Mod} = State) ->
+process_stream_established(State) ->
State1 = State#{stream_authenticated => true,
stream_state => established,
stream_timeout => infinity},
- try Mod:handle_stream_established(State1)
- catch _:undef -> State1
+ try callback(handle_stream_established, State1)
+ catch _:{?MODULE, undef} -> State1
end.
-spec process_compress(compress(), state()) -> state().
@@ -714,9 +714,9 @@ process_compress(#compress{},
when Compressed or not Authenticated ->
send_pkt(State, #compress_failure{reason = 'setup-failed'});
process_compress(#compress{methods = HisMethods},
- #{socket := Socket, mod := Mod} = State) ->
- MyMethods = try Mod:compress_methods(State)
- catch _:undef -> []
+ #{socket := Socket} = State) ->
+ MyMethods = try callback(compress_methods, State)
+ catch _:{?MODULE, undef} -> []
end,
CommonMethods = lists_intersection(MyMethods, HisMethods),
case lists:member(<<"zlib">>, CommonMethods) of
@@ -745,12 +745,11 @@ process_compress(#compress{methods = HisMethods},
-spec process_starttls(state()) -> state().
process_starttls(#{stream_encrypted := true} = State) ->
process_starttls_failure(already_encrypted, State);
-process_starttls(#{socket := Socket,
- mod := Mod} = State) ->
+process_starttls(#{socket := Socket} = State) ->
case is_starttls_available(State) of
true ->
- TLSOpts = try Mod:tls_options(State)
- catch _:undef -> []
+ TLSOpts = try callback(tls_options, State)
+ catch _:{?MODULE, undef} -> []
end,
case xmpp_socket:starttls(Socket, TLSOpts) of
{ok, TLSSocket} ->
@@ -782,7 +781,7 @@ process_starttls_failure(Why, State) ->
-spec process_sasl_request(sasl_auth(), state()) -> state().
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
- #{mod := Mod, lserver := LServer} = State) ->
+ #{lserver := LServer} = State) ->
State1 = State#{sasl_mech => Mech},
Mechs = get_sasl_mechanisms(State1),
case lists:member(Mech, Mechs) of
@@ -795,14 +794,14 @@ process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
end,
process_sasl_result(Res, State1);
true ->
- GetPW = try Mod:get_password_fun(State1)
- catch _:undef -> fun(_) -> false end
+ GetPW = try callback(get_password_fun, State1)
+ catch _:{?MODULE, undef} -> fun(_) -> false end
end,
- CheckPW = try Mod:check_password_fun(State1)
- catch _:undef -> fun(_, _, _) -> false end
+ CheckPW = try callback(check_password_fun, State1)
+ catch _:{?MODULE, undef} -> fun(_, _, _) -> false end
end,
- CheckPWDigest = try Mod:check_password_digest_fun(State1)
- catch _:undef -> fun(_, _, _, _, _) -> false end
+ CheckPWDigest = try callback(check_password_digest_fun, State1)
+ catch _:{?MODULE, undef} -> fun(_, _, _, _, _) -> false end
end,
SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
GetPW, CheckPW, CheckPWDigest),
@@ -831,13 +830,13 @@ process_sasl_result({error, Reason, User}, State) ->
-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
process_sasl_success(Props, ServerOut,
#{socket := Socket,
- mod := Mod, sasl_mech := Mech} = State) ->
+ sasl_mech := Mech} = State) ->
User = identity(Props),
AuthModule = proplists:get_value(auth_module, Props),
Socket1 = xmpp_socket:reset_stream(Socket),
State0 = State#{socket => Socket1},
- State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State0)
- catch _:undef -> State
+ State1 = try callback(handle_auth_success, User, Mech, AuthModule, State0)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@@ -865,10 +864,10 @@ process_sasl_continue(ServerOut, NewSASLState, State) ->
-spec process_sasl_failure(atom(), binary(), state()) -> state().
process_sasl_failure(Err, User,
- #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) ->
+ #{sasl_mech := Mech, lang := Lang} = State) ->
{Reason, Text} = format_sasl_error(Mech, Err),
- State1 = try Mod:handle_auth_failure(User, Mech, Text, State)
- catch _:undef -> State
+ State1 = try callback(handle_auth_failure, User, Mech, Text, State)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@@ -906,21 +905,21 @@ send_features(State) ->
State.
-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
-get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
+get_sasl_mechanisms(#{stream_encrypted := Encrypted,
xmlns := NS, lserver := LServer} = State) ->
Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer);
true -> []
end,
- TLSVerify = try Mod:tls_verify(State)
- catch _:undef -> false
+ TLSVerify = try callback(tls_verify, State)
+ catch _:{?MODULE, undef} -> false
end,
Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
[<<"EXTERNAL">>|Mechs];
true ->
Mechs
end,
- try Mod:sasl_mechanisms(Mechs1, State)
- catch _:undef -> Mechs1
+ try callback(sasl_mechanisms, Mechs1, State)
+ catch _:{?MODULE, undef} -> Mechs1
end.
-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
@@ -937,12 +936,12 @@ get_sasl_feature(_) ->
[].
-spec get_compress_feature(state()) -> [compression()].
-get_compress_feature(#{stream_compressed := false, mod := Mod,
+get_compress_feature(#{stream_compressed := false,
stream_authenticated := true} = State) ->
- try Mod:compress_methods(State) of
+ try callback(compress_methods, State) of
[] -> [];
Ms -> [#compression{methods = Ms}]
- catch _:undef ->
+ catch _:{?MODULE, undef} ->
[]
end;
get_compress_feature(_) ->
@@ -978,25 +977,25 @@ get_session_feature(_) ->
[].
-spec get_other_features(state()) -> [xmpp_element()].
-get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
+get_other_features(#{stream_authenticated := Auth} = State) ->
try
- if Auth -> Mod:authenticated_stream_features(State);
- true -> Mod:unauthenticated_stream_features(State)
+ if Auth -> callback(authenticated_stream_features, State);
+ true -> callback(unauthenticated_stream_features, State)
end
- catch _:undef ->
+ catch _:{?MODULE, undef} ->
[]
end.
-spec is_starttls_available(state()) -> boolean().
-is_starttls_available(#{mod := Mod} = State) ->
- try Mod:tls_enabled(State)
- catch _:undef -> true
+is_starttls_available(State) ->
+ try callback(tls_enabled, State)
+ catch _:{?MODULE, undef} -> true
end.
-spec is_starttls_required(state()) -> boolean().
-is_starttls_required(#{mod := Mod} = State) ->
- try Mod:tls_required(State)
- catch _:undef -> false
+is_starttls_required(State) ->
+ try callback(tls_required, State)
+ catch _:{?MODULE, undef} -> false
end.
-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
@@ -1076,10 +1075,10 @@ send_header(State, _) ->
State.
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
-send_pkt(#{mod := Mod} = State, Pkt) ->
+send_pkt(State, Pkt) ->
Result = socket_send(State, Pkt),
- State1 = try Mod:handle_send(Pkt, Result, State)
- catch _:undef -> State
+ State1 = try callback(handle_send, Pkt, Result, State)
+ catch _:{?MODULE, undef} -> State
end,
case Result of
_ when is_record(Pkt, stream_error) ->
@@ -1200,3 +1199,36 @@ identity(Props) ->
<<>> -> proplists:get_value(username, Props, <<>>);
AuthzId -> AuthzId
end.
+
+%%%===================================================================
+%%% Callbacks
+%%%===================================================================
+callback(F, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 1) of
+ true -> Mod:F(State);
+ false -> erlang:error({?MODULE, undef})
+ end.
+
+callback(F, Arg1, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 2) of
+ true -> Mod:F(Arg1, State);
+ false -> erlang:error({?MODULE, undef})
+ end.
+
+callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
+ %% code_change/3 callback is a special snowflake
+ case erlang:function_exported(Mod, code_change, 3) of
+ true -> Mod:code_change(OldVsn, State, Extra);
+ false -> {ok, State}
+ end;
+callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 3) of
+ true -> Mod:F(Arg1, Arg2, State);
+ false -> erlang:error({?MODULE, undef})
+ end.
+
+callback(F, Arg1, Arg2, Arg3, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 4) of
+ true -> Mod:F(Arg1, Arg2, Arg3, State);
+ false -> erlang:error({?MODULE, undef})
+ end.
diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl
index b2367a09b..da0a14e22 100644
--- a/src/xmpp_stream_out.erl
+++ b/src/xmpp_stream_out.erl
@@ -266,9 +266,9 @@ init([Mod, _SockMod, From, To, Opts]) ->
end.
-spec handle_call(term(), term(), state()) -> noreply().
-handle_call(Call, From, #{mod := Mod} = State) ->
- noreply(try Mod:handle_call(Call, From, State)
- catch _:undef -> State
+handle_call(Call, From, State) ->
+ noreply(try callback(handle_call, Call, From, State)
+ catch _:{?MODULE, undef} -> State
end).
-spec handle_cast(term(), state()) -> noreply().
@@ -311,9 +311,9 @@ handle_cast({close, Reason}, State) ->
true -> State1;
false -> process_stream_end({socket, Reason}, State)
end);
-handle_cast(Cast, #{mod := Mod} = State) ->
- noreply(try Mod:handle_cast(Cast, State)
- catch _:undef -> State
+handle_cast(Cast, State) ->
+ noreply(try callback(handle_cast, Cast, State)
+ catch _:{?MODULE, undef} -> State
end).
-spec handle_info(term(), state()) -> noreply().
@@ -348,20 +348,20 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
send_pkt(State1, Err)
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS, mod := Mod, codec_options := Opts} = State) ->
+ #{xmlns := NS, codec_options := Opts} = State) ->
noreply(
try xmpp:decode(El, NS, Opts) of
Pkt ->
- State1 = try Mod:handle_recv(El, Pkt, State)
- catch _:undef -> State
+ State1 = try callback(handle_recv, El, Pkt, State)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
false -> process_element(Pkt, State1)
end
catch _:{xmpp_codec, Why} ->
- State1 = try Mod:handle_recv(El, {error, Why}, State)
- catch _:undef -> State
+ State1 = try callback(handle_recv, El, {error, Why}, State)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@@ -369,21 +369,21 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
- #{mod := Mod} = State) ->
- noreply(try Mod:handle_cdata(Data, State)
- catch _:undef -> State
+ State) ->
+ noreply(try callback(handle_cdata, Data, State)
+ catch _:{?MODULE, undef} -> State
end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) ->
noreply(process_stream_end({socket, closed}, State));
-handle_info(timeout, #{mod := Mod, lang := Lang} = State) ->
+handle_info(timeout, #{lang := Lang} = State) ->
Disconnected = is_disconnected(State),
- noreply(try Mod:handle_timeout(State)
- catch _:undef when not Disconnected ->
+ noreply(try callback(handle_timeout, State)
+ catch _:{?MODULE, undef} when not Disconnected ->
Txt = <<"Idle connection">>,
send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
- _:undef ->
+ _:{?MODULE, undef} ->
stop(State)
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
@@ -404,26 +404,26 @@ handle_info({tcp_closed, _}, State) ->
handle_info({'$gen_event', closed}, State);
handle_info({tcp_error, _, Reason}, State) ->
noreply(process_stream_end({socket, Reason}, State));
-handle_info(Info, #{mod := Mod} = State) ->
- noreply(try Mod:handle_info(Info, State)
- catch _:undef -> State
+handle_info(Info, State) ->
+ noreply(try callback(handle_info, Info, State)
+ catch _:{?MODULE, undef} -> State
end).
-spec terminate(term(), state()) -> any().
-terminate(Reason, #{mod := Mod} = State) ->
+terminate(Reason, State) ->
case get(already_terminated) of
true ->
State;
_ ->
put(already_terminated, true),
- try Mod:terminate(Reason, State)
- catch _:undef -> ok
+ try callback(terminate, Reason, State)
+ catch _:{?MODULE, undef} -> ok
end,
send_trailer(State)
end.
-code_change(OldVsn, #{mod := Mod} = State, Extra) ->
- Mod:code_change(OldVsn, State, Extra).
+code_change(OldVsn, State, Extra) ->
+ callback(code_change, OldVsn, State, Extra).
%%%===================================================================
%%% Internal functions
@@ -458,11 +458,11 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
-process_stream_end(Reason, #{mod := Mod} = State) ->
+process_stream_end(Reason, State) ->
State1 = State#{stream_timeout => infinity,
stream_state => disconnected},
- try Mod:handle_stream_end(Reason, State1)
- catch _:undef -> stop(State1)
+ try callback(handle_stream_end, Reason, State1)
+ catch _:{?MODULE, undef} -> stop(State1)
end.
-spec process_stream(stream_start(), state()) -> state().
@@ -475,10 +475,10 @@ process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
send_pkt(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang, id = ID,
version = Version} = StreamStart,
- #{mod := Mod} = State) ->
+ State) ->
State1 = State#{stream_remote_id => ID, lang => Lang},
- State2 = try Mod:handle_stream_start(StreamStart, State1)
- catch _:undef -> State1
+ State2 = try callback(handle_stream_start, StreamStart, State1)
+ catch _:{?MODULE, undef} -> State1
end,
case is_disconnected(State2) of
true -> State2;
@@ -522,16 +522,16 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
-spec process_features(stream_features(), state()) -> state().
process_features(StreamFeatures,
- #{stream_authenticated := true, mod := Mod} = State) ->
- State1 = try Mod:handle_authenticated_features(StreamFeatures, State)
- catch _:undef -> State
+ #{stream_authenticated := true} = State) ->
+ State1 = try callback(handle_authenticated_features, StreamFeatures, State)
+ catch _:{?MODULE, undef} -> State
end,
process_stream_established(State1);
process_features(StreamFeatures,
#{stream_encrypted := Encrypted,
- mod := Mod, lang := Lang} = State) ->
- State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State)
- catch _:undef -> State
+ lang := Lang} = State) ->
+ State1 = try callback(handle_unauthenticated_features, StreamFeatures, State)
+ catch _:{?MODULE, undef} -> State
end,
case is_disconnected(State1) of
true -> State1;
@@ -582,12 +582,12 @@ process_features(StreamFeatures,
process_stream_established(#{stream_state := StateName} = State)
when StateName == disconnected; StateName == established ->
State;
-process_stream_established(#{mod := Mod} = State) ->
+process_stream_established(State) ->
State1 = State#{stream_authenticated := true,
stream_state => established,
stream_timeout => infinity},
- try Mod:handle_stream_established(State1)
- catch _:undef -> State1
+ try callback(handle_stream_established, State1)
+ catch _:{?MODULE, undef} -> State1
end.
-spec process_sasl_mechanisms([binary()], state()) -> state().
@@ -620,7 +620,7 @@ process_starttls(#{socket := Socket} = State) ->
-spec process_stream_downgrade(stream_start(), state()) -> state().
process_stream_downgrade(StreamStart,
- #{mod := Mod, lang := Lang,
+ #{lang := Lang,
stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
if not Encrypted and TLSRequired ->
@@ -628,18 +628,17 @@ process_stream_downgrade(StreamStart,
send_pkt(State, xmpp:serr_policy_violation(Txt, Lang));
true ->
State1 = State#{stream_state => downgraded},
- try Mod:handle_stream_downgraded(StreamStart, State1)
- catch _:undef ->
+ try callback(handle_stream_downgraded, StreamStart, State1)
+ catch _:{?MODULE, undef} ->
send_pkt(State1, xmpp:serr_unsupported_version())
end
end.
-spec process_cert_verification(state()) -> state().
process_cert_verification(#{stream_encrypted := true,
- stream_verified := false,
- mod := Mod} = State) ->
- case try Mod:tls_verify(State)
- catch _:undef -> true
+ stream_verified := false} = State) ->
+ case try callback(tls_verify, State)
+ catch _:{?MODULE, undef} -> true
end of
true ->
case xmpp_stream_pkix:authenticate(State) of
@@ -655,8 +654,7 @@ process_cert_verification(State) ->
State.
-spec process_sasl_success(state()) -> state().
-process_sasl_success(#{mod := Mod,
- socket := Socket} = State) ->
+process_sasl_success(#{socket := Socket} = State) ->
Socket1 = xmpp_socket:reset_stream(Socket),
State0 = State#{socket => Socket1},
State1 = State0#{stream_id => new_id(),
@@ -667,8 +665,8 @@ process_sasl_success(#{mod := Mod,
case is_disconnected(State2) of
true -> State2;
false ->
- try Mod:handle_auth_success(<<"EXTERNAL">>, State2)
- catch _:undef -> State2
+ try callback(handle_auth_success, <<"EXTERNAL">>, State2)
+ catch _:{?MODULE, undef} -> State2
end
end.
@@ -677,27 +675,27 @@ process_sasl_failure(#sasl_failure{} = Failure, State) ->
Reason = format("Peer responded with error: ~s",
[format_sasl_failure(Failure)]),
process_sasl_failure(Reason, State);
-process_sasl_failure(Reason, #{mod := Mod} = State) ->
- try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State)
- catch _:undef -> process_stream_end({auth, Reason}, State)
+process_sasl_failure(Reason, State) ->
+ try callback(handle_auth_failure, <<"EXTERNAL">>, {auth, Reason}, State)
+ catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State)
end.
-spec process_packet(xmpp_element(), state()) -> state().
-process_packet(Pkt, #{mod := Mod} = State) ->
- try Mod:handle_packet(Pkt, State)
- catch _:undef -> State
+process_packet(Pkt, State) ->
+ try callback(handle_packet, Pkt, State)
+ catch _:{?MODULE, undef} -> State
end.
-spec is_starttls_required(state()) -> boolean().
-is_starttls_required(#{mod := Mod} = State) ->
- try Mod:tls_required(State)
- catch _:undef -> false
+is_starttls_required(State) ->
+ try callback(tls_required, State)
+ catch _:{?MODULE, undef} -> false
end.
-spec is_starttls_available(state()) -> boolean().
-is_starttls_available(#{mod := Mod} = State) ->
- try Mod:tls_enabled(State)
- catch _:undef -> true
+is_starttls_available(State) ->
+ try callback(tls_enabled, State)
+ catch _:{?MODULE, undef} -> true
end.
-spec send_header(state()) -> state().
@@ -731,10 +729,10 @@ send_header(#{remote_server := RemoteServer,
end.
-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
-send_pkt(#{mod := Mod} = State, Pkt) ->
+send_pkt(State, Pkt) ->
Result = socket_send(State, Pkt),
- State1 = try Mod:handle_send(Pkt, Result, State)
- catch _:undef -> State
+ State1 = try callback(handle_send, Pkt, Result, State)
+ catch _:{?MODULE, undef} -> State
end,
case Result of
_ when is_record(Pkt, stream_error) ->
@@ -795,10 +793,10 @@ close_socket(State) ->
stream_state => disconnected}.
-spec starttls(term(), state()) -> {ok, term()} | {error, tls_error_reason()}.
-starttls(Socket, #{mod := Mod, xmlns := NS,
+starttls(Socket, #{xmlns := NS,
remote_server := RemoteServer} = State) ->
- TLSOpts = try Mod:tls_options(State)
- catch _:undef -> []
+ TLSOpts = try callback(tls_options, State)
+ catch _:{?MODULE, undef} -> []
end,
SNI = idna_to_ascii(RemoteServer),
ALPN = case NS of
@@ -1077,32 +1075,59 @@ get_addr_type({_, _, _, _}) -> inet;
get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
-spec get_dns_timeout(state()) -> timeout().
-get_dns_timeout(#{mod := Mod} = State) ->
- try Mod:dns_timeout(State)
- catch _:undef -> timer:seconds(10)
+get_dns_timeout(State) ->
+ try callback(dns_timeout, State)
+ catch _:{?MODULE, undef} -> timer:seconds(10)
end.
-spec get_dns_retries(state()) -> non_neg_integer().
-get_dns_retries(#{mod := Mod} = State) ->
- try Mod:dns_retries(State)
- catch _:undef -> 2
+get_dns_retries(State) ->
+ try callback(dns_retries, State)
+ catch _:{?MODULE, undef} -> 2
end.
-spec get_default_port(state()) -> inet:port_number().
-get_default_port(#{mod := Mod, xmlns := NS} = State) ->
- try Mod:default_port(State)
- catch _:undef when NS == ?NS_SERVER -> 5269;
- _:undef when NS == ?NS_CLIENT -> 5222
+get_default_port(#{xmlns := NS} = State) ->
+ try callback(default_port, State)
+ catch _:{?MODULE, undef} when NS == ?NS_SERVER -> 5269;
+ _:{?MODULE, undef} when NS == ?NS_CLIENT -> 5222
end.
-spec get_address_families(state()) -> [inet:address_family()].
-get_address_families(#{mod := Mod} = State) ->
- try Mod:address_families(State)
- catch _:undef -> [inet, inet6]
+get_address_families(State) ->
+ try callback(address_families, State)
+ catch _:{?MODULE, undef} -> [inet, inet6]
end.
-spec get_connect_timeout(state()) -> timeout().
-get_connect_timeout(#{mod := Mod} = State) ->
- try Mod:connect_timeout(State)
- catch _:undef -> timer:seconds(10)
+get_connect_timeout(State) ->
+ try callback(connect_timeout, State)
+ catch _:{?MODULE, undef} -> timer:seconds(10)
+ end.
+
+%%%===================================================================
+%%% Callbacks
+%%%===================================================================
+callback(F, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 1) of
+ true -> Mod:F(State);
+ false -> erlang:error({?MODULE, undef})
+ end.
+
+callback(F, Arg1, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 2) of
+ true -> Mod:F(Arg1, State);
+ false -> erlang:error({?MODULE, undef})
+ end.
+
+callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
+ %% code_change/3 callback is a special snowflake
+ case erlang:function_exported(Mod, code_change, 3) of
+ true -> Mod:code_change(OldVsn, State, Extra);
+ false -> {ok, State}
+ end;
+callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
+ case erlang:function_exported(Mod, F, 3) of
+ true -> Mod:F(Arg1, Arg2, State);
+ false -> erlang:error({?MODULE, undef})
end.