diff options
Diffstat (limited to 'src/ejabberd_s2s_out.erl')
-rw-r--r-- | src/ejabberd_s2s_out.erl | 929 |
1 files changed, 356 insertions, 573 deletions
diff --git a/src/ejabberd_s2s_out.erl b/src/ejabberd_s2s_out.erl index ae3433a6a..076ba2d3b 100644 --- a/src/ejabberd_s2s_out.erl +++ b/src/ejabberd_s2s_out.erl @@ -50,8 +50,7 @@ -include("ejabberd.hrl"). -include("logger.hrl"). - --include("jlib.hrl"). +-include("xmpp.hrl"). -record(state, {socket :: ejabberd_socket:socket_state(), @@ -75,6 +74,17 @@ bridge :: {atom(), atom()}, timer = make_ref() :: reference()}). +-type state_name() :: open_socket | wait_for_stream | + wait_for_validation | wait_for_features | + wait_for_auth_result | wait_for_starttls_proceed | + relay_to_bridge | reopen_socket | wait_before_retry | + stream_established. +-type state() :: #state{}. +-type fsm_stop() :: {stop, normal, state()}. +-type fsm_next() :: {next_state, state_name(), state(), non_neg_integer()} | + {next_state, state_name(), state()}. +-type fsm_transition() :: fsm_stop() | fsm_next(). + %%-define(DBGFSM, true). -ifdef(DBGFSM). @@ -96,23 +106,6 @@ %% Specified in miliseconds. Default value is 5 minutes. -define(MAX_RETRY_DELAY, 300000). --define(STREAM_HEADER, - <<"<?xml version='1.0'?><stream:stream " - "xmlns:stream='http://etherx.jabber.org/stream" - "s' xmlns='jabber:server' xmlns:db='jabber:ser" - "ver:dialback' from='~s' to='~s'~s>">>). - --define(STREAM_TRAILER, <<"</stream:stream>">>). - --define(INVALID_NAMESPACE_ERR, - fxml:element_to_binary(?SERR_INVALID_NAMESPACE)). - --define(HOST_UNKNOWN_ERR, - fxml:element_to_binary(?SERR_HOST_UNKNOWN)). - --define(INVALID_XML_ERR, - fxml:element_to_binary(?SERR_XML_NOT_WELL_FORMED)). - -define(SOCKET_DEFAULT_RESULT, {error, badarg}). %%%---------------------------------------------------------------------- @@ -229,17 +222,13 @@ open_socket(init, StateData) -> ?SOCKET_DEFAULT_RESULT, AddrList) of {ok, Socket} -> - Version = if StateData#state.use_v10 -> - <<" version='1.0'">>; - true -> <<"">> + Version = if StateData#state.use_v10 -> {1,0}; + true -> undefined end, NewStateData = StateData#state{socket = Socket, tls_enabled = false, streamid = new_id()}, - send_text(NewStateData, - io_lib:format(?STREAM_HEADER, - [StateData#state.myname, - StateData#state.server, Version])), + send_header(NewStateData, Version), {next_state, wait_for_stream, NewStateData, ?FSMTIMEOUT}; {error, _Reason} -> @@ -259,18 +248,8 @@ open_socket(init, StateData) -> _ -> wait_before_reconnect(StateData) end end; -open_socket(closed, StateData) -> - ?INFO_MSG("s2s connection: ~s -> ~s (stopped in " - "open socket)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -open_socket(timeout, StateData) -> - ?INFO_MSG("s2s connection: ~s -> ~s (timeout in " - "open socket)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -open_socket(_, StateData) -> - {next_state, open_socket, StateData}. +open_socket(Event, StateData) -> + handle_unexpected_event(Event, open_socket, StateData). open_socket1({_, _, _, _} = Addr, Port) -> open_socket2(inet, Addr, Port); @@ -309,466 +288,215 @@ open_socket2(Type, Addr, Port) -> %%---------------------------------------------------------------------- -wait_for_stream({xmlstreamstart, _Name, Attrs}, - StateData) -> - {CertCheckRes, CertCheckMsg, StateData0} = - if StateData#state.tls_certverify, StateData#state.tls_enabled -> - {Res, Msg} = - ejabberd_s2s:check_peer_certificate(ejabberd_socket, - StateData#state.socket, - StateData#state.server), - ?DEBUG("Certificate verification result for ~s: ~s", - [StateData#state.server, Msg]), - {Res, Msg, StateData#state{tls_certverify = false}}; +wait_for_stream({xmlstreamstart, Name, Attrs}, StateData0) -> + {CertCheckRes, CertCheckMsg, StateData} = + if StateData0#state.tls_certverify, StateData0#state.tls_enabled -> + {Res, Msg} = + ejabberd_s2s:check_peer_certificate(ejabberd_socket, + StateData0#state.socket, + StateData0#state.server), + ?DEBUG("Certificate verification result for ~s: ~s", + [StateData0#state.server, Msg]), + {Res, Msg, StateData0#state{tls_certverify = false}}; true -> - {no_verify, <<"Not verified">>, StateData} + {no_verify, <<"Not verified">>, StateData0} end, - RemoteStreamID = fxml:get_attr_s(<<"id">>, Attrs), - NewStateData = StateData0#state{remote_streamid = RemoteStreamID}, - case {fxml:get_attr_s(<<"xmlns">>, Attrs), - fxml:get_attr_s(<<"xmlns:db">>, Attrs), - fxml:get_attr_s(<<"version">>, Attrs) == <<"1.0">>} - of - _ when CertCheckRes == error -> - send_text(NewStateData, - <<(fxml:element_to_binary(?SERRT_POLICY_VIOLATION(<<"en">>, - CertCheckMsg)))/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)", - [NewStateData#state.myname, - NewStateData#state.server, - CertCheckMsg]), - {stop, normal, NewStateData}; - {<<"jabber:server">>, <<"jabber:server:dialback">>, - false} -> - send_db_request(NewStateData); - {<<"jabber:server">>, <<"jabber:server:dialback">>, - true} - when NewStateData#state.use_v10 -> - {next_state, wait_for_features, NewStateData, ?FSMTIMEOUT}; - %% Clause added to handle Tigase's workaround for an old ejabberd bug: - {<<"jabber:server">>, <<"jabber:server:dialback">>, - true} - when not NewStateData#state.use_v10 -> - send_db_request(NewStateData); - {<<"jabber:server">>, <<"">>, true} - when NewStateData#state.use_v10 -> - {next_state, wait_for_features, - NewStateData#state{db_enabled = false}, ?FSMTIMEOUT}; - {NSProvided, DB, _} -> - send_text(NewStateData, ?INVALID_NAMESPACE_ERR), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid " - "namespace).~nNamespace provided: ~p~nNamespac" - "e expected: \"jabber:server\"~nxmlns:db " - "provided: ~p~nAll attributes: ~p", - [NewStateData#state.myname, NewStateData#state.server, - NSProvided, DB, Attrs]), - {stop, normal, NewStateData} + try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of + _ when CertCheckRes == error -> + send_element(StateData, + xmpp:serr_policy_violation(CertCheckMsg, ?MYLANG)), + ?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)", + [StateData#state.myname, StateData#state.server, + CertCheckMsg]), + {stop, normal, StateData}; + #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM} + when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM -> + send_element(StateData, xmpp:serr_invalid_namespace()), + {stop, normal, StateData}; + #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID, + version = V} when V /= {1,0} -> + send_db_request(StateData#state{remote_streamid = ID}); + #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID} + when StateData#state.use_v10 -> + {next_state, wait_for_features, + StateData#state{remote_streamid = ID}, ?FSMTIMEOUT}; + #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID} + when not StateData#state.use_v10 -> + %% Handle Tigase's workaround for an old ejabberd bug: + send_db_request(StateData#state{remote_streamid = ID}); + #stream_start{id = ID} when StateData#state.use_v10 -> + {next_state, wait_for_features, + StateData#state{db_enabled = false, remote_streamid = ID}, + ?FSMTIMEOUT}; + #stream_start{} -> + send_element(StateData, xmpp:serr_invalid_namespace()), + {stop, normal, StateData}; + _ -> + send_element(StateData, xmpp:serr_invalid_xml()), + {stop, normal, StateData} + catch _:{xmpp_codec, Why} -> + Txt = xmpp:format_error(Why), + send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), + {stop, normal, StateData} end; -wait_for_stream({xmlstreamerror, _}, StateData) -> - send_text(StateData, - <<(?INVALID_XML_ERR)/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid " - "xml)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -wait_for_stream({xmlstreamend, _Name}, StateData) -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (xmlstreamend)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -wait_for_stream(timeout, StateData) -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (timeout " - "in wait_for_stream)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -wait_for_stream(closed, StateData) -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (close " - "in wait_for_stream)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}. - -wait_for_validation({xmlstreamelement, El}, +wait_for_stream(Event, StateData) -> + handle_unexpected_event(Event, wait_for_stream, StateData). + +wait_for_validation({xmlstreamelement, El}, StateData) -> + decode_element(El, wait_for_validation, StateData); +wait_for_validation(#db_result{to = To, from = From, type = Type}, StateData) -> + ?DEBUG("recv result: ~p", [{From, To, Type}]), + case {Type, StateData#state.tls_enabled, StateData#state.tls_required} of + {valid, Enabled, Required} when (Enabled == true) or (Required == false) -> + send_queue(StateData, StateData#state.queue), + ?INFO_MSG("Connection established: ~s -> ~s with " + "TLS=~p", + [StateData#state.myname, StateData#state.server, + StateData#state.tls_enabled]), + ejabberd_hooks:run(s2s_connect_hook, + [StateData#state.myname, + StateData#state.server]), + {next_state, stream_established, StateData#state{queue = queue:new()}}; + {valid, Enabled, Required} when (Enabled == false) and (Required == true) -> + ?INFO_MSG("Closing s2s connection: ~s -> ~s (TLS " + "is required but unavailable)", + [StateData#state.myname, StateData#state.server]), + {stop, normal, StateData}; + _ -> + ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid " + "dialback key result)", + [StateData#state.myname, StateData#state.server]), + {stop, normal, StateData} + end; +wait_for_validation(#db_verify{to = To, from = From, id = Id, type = Type}, StateData) -> - case is_verify_res(El) of - {result, To, From, Id, Type} -> - ?DEBUG("recv result: ~p", [{From, To, Id, Type}]), - case {Type, StateData#state.tls_enabled, - StateData#state.tls_required} - of - {<<"valid">>, Enabled, Required} - when (Enabled == true) or (Required == false) -> - send_queue(StateData, StateData#state.queue), - ?INFO_MSG("Connection established: ~s -> ~s with " - "TLS=~p", - [StateData#state.myname, StateData#state.server, - StateData#state.tls_enabled]), - ejabberd_hooks:run(s2s_connect_hook, - [StateData#state.myname, - StateData#state.server]), - {next_state, stream_established, - StateData#state{queue = queue:new()}}; - {<<"valid">>, Enabled, Required} - when (Enabled == false) and (Required == true) -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (TLS " - "is required but unavailable)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; - _ -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid " - "dialback key)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData} - end; - {verify, To, From, Id, Type} -> - ?DEBUG("recv verify: ~p", [{From, To, Id, Type}]), - case StateData#state.verify of - false -> - NextState = wait_for_validation, - {next_state, NextState, StateData, - get_timeout_interval(NextState)}; - {Pid, _Key, _SID} -> - case Type of - <<"valid">> -> - p1_fsm:send_event(Pid, - {valid, StateData#state.server, - StateData#state.myname}); - _ -> - p1_fsm:send_event(Pid, - {invalid, StateData#state.server, - StateData#state.myname}) - end, - if StateData#state.verify == false -> - {stop, normal, StateData}; - true -> - NextState = wait_for_validation, - {next_state, NextState, StateData, - get_timeout_interval(NextState)} - end - end; - _ -> - {next_state, wait_for_validation, StateData, - (?FSMTIMEOUT) * 3} + ?DEBUG("recv verify: ~p", [{From, To, Id, Type}]), + case StateData#state.verify of + false -> + NextState = wait_for_validation, + {next_state, NextState, StateData, get_timeout_interval(NextState)}; + {Pid, _Key, _SID} -> + case Type of + valid -> + p1_fsm:send_event(Pid, + {valid, StateData#state.server, + StateData#state.myname}); + _ -> + p1_fsm:send_event(Pid, + {invalid, StateData#state.server, + StateData#state.myname}) + end, + if StateData#state.verify == false -> + {stop, normal, StateData}; + true -> + NextState = wait_for_validation, + {next_state, NextState, StateData, get_timeout_interval(NextState)} + end end; -wait_for_validation({xmlstreamend, _Name}, StateData) -> - ?INFO_MSG("wait for validation: ~s -> ~s (xmlstreamend)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -wait_for_validation({xmlstreamerror, _}, StateData) -> - ?INFO_MSG("wait for validation: ~s -> ~s (xmlstreamerror)", - [StateData#state.myname, StateData#state.server]), - send_text(StateData, - <<(?INVALID_XML_ERR)/binary, - (?STREAM_TRAILER)/binary>>), - {stop, normal, StateData}; wait_for_validation(timeout, #state{verify = {VPid, VKey, SID}} = StateData) - when is_pid(VPid) and is_binary(VKey) and - is_binary(SID) -> - ?DEBUG("wait_for_validation: ~s -> ~s (timeout " - "in verify connection)", + when is_pid(VPid) and is_binary(VKey) and is_binary(SID) -> + ?DEBUG("wait_for_validation: ~s -> ~s (timeout in verify connection)", [StateData#state.myname, StateData#state.server]), {stop, normal, StateData}; -wait_for_validation(timeout, StateData) -> - ?INFO_MSG("wait_for_validation: ~s -> ~s (connect " - "timeout)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -wait_for_validation(closed, StateData) -> - ?INFO_MSG("wait for validation: ~s -> ~s (closed)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}. +wait_for_validation(Event, StateData) -> + handle_unexpected_event(Event, wait_for_validation, StateData). wait_for_features({xmlstreamelement, El}, StateData) -> - case El of - #xmlel{name = <<"stream:features">>, children = Els} -> - {SASLEXT, StartTLS, StartTLSRequired} = lists:foldl(fun - (#xmlel{name = - <<"mechanisms">>, - attrs = - Attrs1, - children - = - Els1} = - _El1, - {_SEXT, STLS, - STLSReq} = - Acc) -> - case - fxml:get_attr_s(<<"xmlns">>, - Attrs1) - of - ?NS_SASL -> - NewSEXT = - lists:any(fun - (#xmlel{name - = - <<"mechanism">>, - children - = - Els2}) -> - case - fxml:get_cdata(Els2) - of - <<"EXTERNAL">> -> - true; - _ -> - false - end; - (_) -> - false - end, - Els1), - {NewSEXT, - STLS, - STLSReq}; - _ -> Acc - end; - (#xmlel{name = - <<"starttls">>, - attrs = - Attrs1} = - El1, - {SEXT, _STLS, - _STLSReq} = - Acc) -> - case - fxml:get_attr_s(<<"xmlns">>, - Attrs1) - of - ?NS_TLS -> - Req = - case - fxml:get_subtag(El1, - <<"required">>) - of - #xmlel{} -> - true; - false -> - false - end, - {SEXT, - true, - Req}; - _ -> Acc - end; - (_, Acc) -> Acc - end, - {false, false, - false}, - Els), - if not SASLEXT and not StartTLS and - StateData#state.authenticated -> - send_queue(StateData, StateData#state.queue), - ?INFO_MSG("Connection established: ~s -> ~s with " - "SASL EXTERNAL and TLS=~p", - [StateData#state.myname, StateData#state.server, - StateData#state.tls_enabled]), - ejabberd_hooks:run(s2s_connect_hook, - [StateData#state.myname, - StateData#state.server]), - {next_state, stream_established, - StateData#state{queue = queue:new()}}; - SASLEXT and StateData#state.try_auth and - (StateData#state.new /= false) and - (StateData#state.tls_enabled or - not StateData#state.tls_required) -> - send_element(StateData, - #xmlel{name = <<"auth">>, - attrs = - [{<<"xmlns">>, ?NS_SASL}, - {<<"mechanism">>, <<"EXTERNAL">>}], - children = - [{xmlcdata, - jlib:encode_base64(StateData#state.myname)}]}), - {next_state, wait_for_auth_result, - StateData#state{try_auth = false}, ?FSMTIMEOUT}; - StartTLS and StateData#state.tls and - not StateData#state.tls_enabled -> - send_element(StateData, - #xmlel{name = <<"starttls">>, - attrs = [{<<"xmlns">>, ?NS_TLS}], - children = []}), - {next_state, wait_for_starttls_proceed, StateData, - ?FSMTIMEOUT}; - StartTLSRequired and not StateData#state.tls -> - ?DEBUG("restarted: ~p", - [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:close(StateData#state.socket), - {next_state, reopen_socket, - StateData#state{socket = undefined, use_v10 = false}, - ?FSMTIMEOUT}; - StateData#state.db_enabled -> - send_db_request(StateData); - true -> - ?DEBUG("restarted: ~p", - [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:close(StateData#state.socket), - {next_state, reopen_socket, - StateData#state{socket = undefined, use_v10 = false}, - ?FSMTIMEOUT} - end; - _ -> - send_text(StateData, - <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad " - "format)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData} + decode_element(El, wait_for_features, StateData); +wait_for_features(#stream_features{sub_els = Els}, StateData) -> + {SASLEXT, StartTLS, StartTLSRequired} = + lists:foldl( + fun(#sasl_mechanisms{list = Mechs}, {_, STLS, STLSReq}) -> + {lists:member(<<"EXTERNAL">>, Mechs), STLS, STLSReq}; + (#starttls{required = Required}, {SEXT, _, _}) -> + {SEXT, true, Required}; + (_, Acc) -> + Acc + end, {false, false, false}, Els), + if not SASLEXT and not StartTLS and StateData#state.authenticated -> + send_queue(StateData, StateData#state.queue), + ?INFO_MSG("Connection established: ~s -> ~s with " + "SASL EXTERNAL and TLS=~p", + [StateData#state.myname, StateData#state.server, + StateData#state.tls_enabled]), + ejabberd_hooks:run(s2s_connect_hook, + [StateData#state.myname, + StateData#state.server]), + {next_state, stream_established, + StateData#state{queue = queue:new()}}; + SASLEXT and StateData#state.try_auth and + (StateData#state.new /= false) and + (StateData#state.tls_enabled or + not StateData#state.tls_required) -> + send_element(StateData, + #sasl_auth{mechanism = <<"EXTERNAL">>, + text = StateData#state.myname}), + {next_state, wait_for_auth_result, + StateData#state{try_auth = false}, ?FSMTIMEOUT}; + StartTLS and StateData#state.tls and + not StateData#state.tls_enabled -> + send_element(StateData, #starttls{}), + {next_state, wait_for_starttls_proceed, StateData, ?FSMTIMEOUT}; + StartTLSRequired and not StateData#state.tls -> + ?DEBUG("restarted: ~p", + [{StateData#state.myname, StateData#state.server}]), + ejabberd_socket:close(StateData#state.socket), + {next_state, reopen_socket, + StateData#state{socket = undefined, use_v10 = false}, + ?FSMTIMEOUT}; + StateData#state.db_enabled -> + send_db_request(StateData); + true -> + ?DEBUG("restarted: ~p", + [{StateData#state.myname, StateData#state.server}]), + ejabberd_socket:close(StateData#state.socket), + {next_state, reopen_socket, + StateData#state{socket = undefined, use_v10 = false}, ?FSMTIMEOUT} end; -wait_for_features({xmlstreamend, _Name}, StateData) -> - ?INFO_MSG("wait_for_features: xmlstreamend", []), - {stop, normal, StateData}; -wait_for_features({xmlstreamerror, _}, StateData) -> - send_text(StateData, - <<(?INVALID_XML_ERR)/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("wait for features: xmlstreamerror", []), - {stop, normal, StateData}; -wait_for_features(timeout, StateData) -> - ?INFO_MSG("wait for features: timeout", []), - {stop, normal, StateData}; -wait_for_features(closed, StateData) -> - ?INFO_MSG("wait for features: closed", []), - {stop, normal, StateData}. - -wait_for_auth_result({xmlstreamelement, El}, - StateData) -> - case El of - #xmlel{name = <<"success">>, attrs = Attrs} -> - case fxml:get_attr_s(<<"xmlns">>, Attrs) of - ?NS_SASL -> - ?DEBUG("auth: ~p", - [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:reset_stream(StateData#state.socket), - send_text(StateData, - io_lib:format(?STREAM_HEADER, - [StateData#state.myname, - StateData#state.server, - <<" version='1.0'">>])), - {next_state, wait_for_stream, - StateData#state{streamid = new_id(), - authenticated = true}, - ?FSMTIMEOUT}; - _ -> - send_text(StateData, - <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad " - "format)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData} - end; - #xmlel{name = <<"failure">>, attrs = Attrs} -> - case fxml:get_attr_s(<<"xmlns">>, Attrs) of - ?NS_SASL -> - ?DEBUG("restarted: ~p", - [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:close(StateData#state.socket), - {next_state, reopen_socket, - StateData#state{socket = undefined}, ?FSMTIMEOUT}; - _ -> - send_text(StateData, - <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad " - "format)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData} - end; - _ -> - send_text(StateData, - <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad " - "format)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData} - end; -wait_for_auth_result({xmlstreamend, _Name}, - StateData) -> - ?INFO_MSG("wait for auth result: xmlstreamend", []), - {stop, normal, StateData}; -wait_for_auth_result({xmlstreamerror, _}, StateData) -> - send_text(StateData, - <<(?INVALID_XML_ERR)/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("wait for auth result: xmlstreamerror", []), - {stop, normal, StateData}; -wait_for_auth_result(timeout, StateData) -> - ?INFO_MSG("wait for auth result: timeout", []), - {stop, normal, StateData}; -wait_for_auth_result(closed, StateData) -> - ?INFO_MSG("wait for auth result: closed", []), - {stop, normal, StateData}. - -wait_for_starttls_proceed({xmlstreamelement, El}, - StateData) -> - case El of - #xmlel{name = <<"proceed">>, attrs = Attrs} -> - case fxml:get_attr_s(<<"xmlns">>, Attrs) of - ?NS_TLS -> - ?DEBUG("starttls: ~p", - [{StateData#state.myname, StateData#state.server}]), - Socket = StateData#state.socket, - TLSOpts = case - ejabberd_config:get_option( - {domain_certfile, StateData#state.myname}, - fun iolist_to_binary/1) - of - undefined -> StateData#state.tls_options; - CertFile -> - [{certfile, CertFile} - | lists:keydelete(certfile, 1, - StateData#state.tls_options)] - end, - TLSSocket = ejabberd_socket:starttls(Socket, TLSOpts), - NewStateData = StateData#state{socket = TLSSocket, - streamid = new_id(), - tls_enabled = true, - tls_options = TLSOpts}, - send_text(NewStateData, - io_lib:format(?STREAM_HEADER, - [NewStateData#state.myname, - NewStateData#state.server, - <<" version='1.0'">>])), - {next_state, wait_for_stream, NewStateData, - ?FSMTIMEOUT}; - _ -> - send_text(StateData, - <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad " - "format)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData} - end; - _ -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad " - "format)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData} - end; -wait_for_starttls_proceed({xmlstreamend, _Name}, - StateData) -> - ?INFO_MSG("wait for starttls proceed: xmlstreamend", - []), - {stop, normal, StateData}; -wait_for_starttls_proceed({xmlstreamerror, _}, - StateData) -> - send_text(StateData, - <<(?INVALID_XML_ERR)/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("wait for starttls proceed: xmlstreamerror", - []), - {stop, normal, StateData}; -wait_for_starttls_proceed(timeout, StateData) -> - ?INFO_MSG("wait for starttls proceed: timeout", []), - {stop, normal, StateData}; -wait_for_starttls_proceed(closed, StateData) -> - ?INFO_MSG("wait for starttls proceed: closed", []), - {stop, normal, StateData}. +wait_for_features(Event, StateData) -> + handle_unexpected_event(Event, wait_for_features, StateData). + +wait_for_auth_result({xmlstreamelement, El}, StateData) -> + decode_element(El, wait_for_auth_result, StateData); +wait_for_auth_result(#sasl_success{}, StateData) -> + ?DEBUG("auth: ~p", [{StateData#state.myname, StateData#state.server}]), + ejabberd_socket:reset_stream(StateData#state.socket), + send_header(StateData, {1,0}), + {next_state, wait_for_stream, + StateData#state{streamid = new_id(), authenticated = true}, + ?FSMTIMEOUT}; +wait_for_auth_result(#sasl_failure{}, StateData) -> + ?DEBUG("restarted: ~p", [{StateData#state.myname, StateData#state.server}]), + ejabberd_socket:close(StateData#state.socket), + {next_state, reopen_socket, + StateData#state{socket = undefined}, ?FSMTIMEOUT}; +wait_for_auth_result(Event, StateData) -> + handle_unexpected_event(Event, wait_for_auth_result, StateData). + +wait_for_starttls_proceed({xmlstreamelement, El}, StateData) -> + decode_element(El, wait_for_starttls_proceed, StateData); +wait_for_starttls_proceed(#starttls_proceed{}, StateData) -> + ?DEBUG("starttls: ~p", [{StateData#state.myname, StateData#state.server}]), + Socket = StateData#state.socket, + TLSOpts = case ejabberd_config:get_option( + {domain_certfile, StateData#state.myname}, + fun iolist_to_binary/1) of + undefined -> StateData#state.tls_options; + CertFile -> + [{certfile, CertFile} + | lists:keydelete(certfile, 1, + StateData#state.tls_options)] + end, + TLSSocket = ejabberd_socket:starttls(Socket, TLSOpts), + NewStateData = StateData#state{socket = TLSSocket, + streamid = new_id(), + tls_enabled = true, + tls_options = TLSOpts}, + send_header(NewStateData, {1,0}), + {next_state, wait_for_stream, NewStateData, ?FSMTIMEOUT}; +wait_for_starttls_proceed(Event, StateData) -> + handle_unexpected_event(Event, wait_for_starttls_proceed, StateData). reopen_socket({xmlstreamelement, _El}, StateData) -> {next_state, reopen_socket, StateData, ?FSMTIMEOUT}; @@ -797,47 +525,70 @@ relay_to_bridge(_Event, StateData) -> {next_state, relay_to_bridge, StateData}. stream_established({xmlstreamelement, El}, StateData) -> - ?DEBUG("s2S stream established", []), - case is_verify_res(El) of - {verify, VTo, VFrom, VId, VType} -> - ?DEBUG("recv verify: ~p", [{VFrom, VTo, VId, VType}]), - case StateData#state.verify of - {VPid, _VKey, _SID} -> - case VType of - <<"valid">> -> - p1_fsm:send_event(VPid, - {valid, StateData#state.server, - StateData#state.myname}); - _ -> - p1_fsm:send_event(VPid, - {invalid, StateData#state.server, - StateData#state.myname}) - end; - _ -> ok - end; - _ -> ok + decode_element(El, stream_established, StateData); +stream_established(#db_verify{to = VTo, from = VFrom, id = VId, type = VType}, + StateData) -> + ?DEBUG("recv verify: ~p", [{VFrom, VTo, VId, VType}]), + case StateData#state.verify of + {VPid, _VKey, _SID} -> + case VType of + valid -> + p1_fsm:send_event(VPid, + {valid, StateData#state.server, + StateData#state.myname}); + _ -> + p1_fsm:send_event(VPid, + {invalid, StateData#state.server, + StateData#state.myname}) + end; + _ -> ok end, {next_state, stream_established, StateData}; -stream_established({xmlstreamend, _Name}, StateData) -> - ?INFO_MSG("Connection closed in stream established: " - "~s -> ~s (xmlstreamend)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -stream_established({xmlstreamerror, _}, StateData) -> - send_text(StateData, - <<(?INVALID_XML_ERR)/binary, - (?STREAM_TRAILER)/binary>>), - ?INFO_MSG("stream established: ~s -> ~s (xmlstreamerror)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -stream_established(timeout, StateData) -> - ?INFO_MSG("stream established: ~s -> ~s (timeout)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -stream_established(closed, StateData) -> - ?INFO_MSG("stream established: ~s -> ~s (closed)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}. +stream_established(Event, StateData) -> + handle_unexpected_event(Event, stream_established, StateData). + +-spec handle_unexpected_event(term(), state_name(), state()) -> fsm_transition(). +handle_unexpected_event(Event, StateName, StateData) -> + case Event of + {xmlstreamerror, _} -> + send_element(StateData, xmpp:serr_not_well_formed()), + ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " + "got invalid XML from peer", + [StateData#state.myname, StateData#state.server, + StateName]), + {stop, normal, StateData}; + {xmlstreamend, _} -> + ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " + "XML stream closed by peer", + [StateData#state.myname, StateData#state.server, + StateName]), + {stop, normal, StateData}; + timeout -> + send_element(StateData, xmpp:serr_connection_timeout()), + ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " + "timed out during establishing an XML stream", + [StateData#state.myname, StateData#state.server, + StateName]), + {stop, normal, StateData}; + closed -> + ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " + "connection socket closed", + [StateData#state.myname, StateData#state.server, + StateName]), + {stop, normal, StateData}; + Pkt when StateName == wait_for_stream; + StateName == wait_for_features; + StateName == wait_for_auth_result; + StateName == wait_for_starttls_proceed -> + send_element(StateData, xmpp:serr_bad_format()), + ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " + "got unexpected event ~p", + [StateData#state.myname, StateData#state.server, + StateName, Pkt]), + {stop, normal, StateData}; + _ -> + {next_state, StateName, StateData, get_timeout_interval(StateName)} + end. %%---------------------------------------------------------------------- %% Func: StateName/3 @@ -917,7 +668,7 @@ handle_info({send_element, El}, StateName, StateData) -> %% In this state we bounce all message: We are waiting before %% trying to reconnect wait_before_retry -> - bounce_element(El, ?ERR_REMOTE_SERVER_NOT_FOUND), + bounce_element(El, xmpp:err_remote_server_not_found()), {next_state, StateName, StateData}; relay_to_bridge -> {Mod, Fun} = StateData#state.bridge, @@ -926,7 +677,7 @@ handle_info({send_element, El}, StateName, StateData) -> {'EXIT', Reason} -> ?ERROR_MSG("Error while relaying to bridge: ~p", [Reason]), - bounce_element(El, ?ERR_INTERNAL_SERVER_ERROR), + bounce_element(El, xmpp:err_internal_server_error()), wait_before_reconnect(StateData); _ -> {next_state, StateName, StateData} end; @@ -966,12 +717,13 @@ terminate(Reason, StateName, StateData) -> StateData#state.server}, self()) end, - bounce_queue(StateData#state.queue, - ?ERR_REMOTE_SERVER_NOT_FOUND), - bounce_messages(?ERR_REMOTE_SERVER_NOT_FOUND), + bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()), + bounce_messages(xmpp:err_remote_server_not_found()), case StateData#state.socket of undefined -> ok; - _Socket -> ejabberd_socket:close(StateData#state.socket) + _Socket -> + catch send_trailer(StateData), + ejabberd_socket:close(StateData#state.socket) end, ok. @@ -981,12 +733,32 @@ print_state(State) -> State. %%% Internal functions %%%---------------------------------------------------------------------- +-spec send_text(state(), iodata()) -> ok. send_text(StateData, Text) -> + ?DEBUG("Send Text on stream = ~s", [Text]), ejabberd_socket:send(StateData#state.socket, Text). +-spec send_element(state(), xmpp_element()) -> ok. send_element(StateData, El) -> - send_text(StateData, fxml:element_to_binary(El)). - + El1 = xmpp:encode(El, ?NS_SERVER), + send_text(StateData, fxml:element_to_binary(El1)). + +-spec send_header(state(), undefined | {integer(), integer()}) -> ok. +send_header(StateData, Version) -> + Header = xmpp:encode( + #stream_start{xmlns = ?NS_SERVER, + stream_xmlns = ?NS_STREAM, + db_xmlns = ?NS_SERVER_DIALBACK, + from = jid:make(StateData#state.myname), + to = jid:make(StateData#state.server), + version = Version}), + send_text(StateData, fxml:element_to_header(Header)). + +-spec send_trailer(state()) -> ok. +send_trailer(StateData) -> + send_text(StateData, <<"</stream:stream>">>). + +-spec send_queue(state(), queue:queue()) -> ok. send_queue(StateData, Q) -> case queue:out(Q) of {{value, El}, Q1} -> @@ -995,20 +767,13 @@ send_queue(StateData, Q) -> end. %% Bounce a single message (xmlelement) +-spec bounce_element(stanza(), stanza_error()) -> ok. bounce_element(El, Error) -> - #xmlel{attrs = Attrs} = El, - case fxml:get_attr_s(<<"type">>, Attrs) of - <<"error">> -> ok; - <<"result">> -> ok; - _ -> - Err = jlib:make_error_reply(El, Error), - From = jid:from_string(fxml:get_tag_attr_s(<<"from">>, - El)), - To = jid:from_string(fxml:get_tag_attr_s(<<"to">>, - El)), - ejabberd_router:route(To, From, Err) - end. + From = xmpp:get_from(El), + To = xmpp:get_to(El), + ejabberd_router:route_error(To, From, El, Error). +-spec bounce_queue(queue:queue(), stanza_error()) -> ok. bounce_queue(Q, Error) -> case queue:out(Q) of {{value, El}, Q1} -> @@ -1016,12 +781,15 @@ bounce_queue(Q, Error) -> {empty, _} -> ok end. +-spec new_id() -> binary(). new_id() -> randoms:get_string(). +-spec cancel_timer(reference()) -> ok. cancel_timer(Timer) -> erlang:cancel_timer(Timer), receive {timeout, Timer, _} -> ok after 0 -> ok end. +-spec bounce_messages(stanza_error()) -> ok. bounce_messages(Error) -> receive {send_element, El} -> @@ -1029,6 +797,7 @@ bounce_messages(Error) -> after 0 -> ok end. +-spec send_db_request(state()) -> fsm_transition(). send_db_request(StateData) -> Server = StateData#state.server, New = case StateData#state.new of @@ -1045,22 +814,18 @@ send_db_request(StateData) -> {StateData#state.myname, Server}, StateData#state.remote_streamid), send_element(StateData, - #xmlel{name = <<"db:result">>, - attrs = - [{<<"from">>, StateData#state.myname}, - {<<"to">>, Server}], - children = [{xmlcdata, Key1}]}) + #db_result{from = StateData#state.myname, + to = Server, + key = Key1}) end, case StateData#state.verify of false -> ok; {_Pid, Key2, SID} -> send_element(StateData, - #xmlel{name = <<"db:verify">>, - attrs = - [{<<"from">>, StateData#state.myname}, - {<<"to">>, StateData#state.server}, - {<<"id">>, SID}], - children = [{xmlcdata, Key2}]}) + #db_verify{from = StateData#state.myname, + to = StateData#state.server, + id = SID, + key = Key2}) end, {next_state, wait_for_validation, NewStateData, (?FSMTIMEOUT) * 6} @@ -1068,20 +833,6 @@ send_db_request(StateData) -> _:_ -> {stop, normal, NewStateData} end. -is_verify_res(#xmlel{name = Name, attrs = Attrs}) - when Name == <<"db:result">> -> - {result, fxml:get_attr_s(<<"to">>, Attrs), - fxml:get_attr_s(<<"from">>, Attrs), - fxml:get_attr_s(<<"id">>, Attrs), - fxml:get_attr_s(<<"type">>, Attrs)}; -is_verify_res(#xmlel{name = Name, attrs = Attrs}) - when Name == <<"db:verify">> -> - {verify, fxml:get_attr_s(<<"to">>, Attrs), - fxml:get_attr_s(<<"from">>, Attrs), - fxml:get_attr_s(<<"id">>, Attrs), - fxml:get_attr_s(<<"type">>, Attrs)}; -is_verify_res(_) -> false. - %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% SRV support @@ -1189,12 +940,14 @@ get_addrs(Host, Family) -> [] end. +-spec outgoing_s2s_port() -> pos_integer(). outgoing_s2s_port() -> ejabberd_config:get_option( outgoing_s2s_port, fun(I) when is_integer(I), I > 0, I =< 65536 -> I end, 5269). +-spec outgoing_s2s_families() -> [ipv4 | ipv6]. outgoing_s2s_families() -> ejabberd_config:get_option( outgoing_s2s_families, @@ -1206,6 +959,7 @@ outgoing_s2s_families() -> Families end, [ipv4, ipv6]). +-spec outgoing_s2s_timeout() -> pos_integer(). outgoing_s2s_timeout() -> ejabberd_config:get_option( outgoing_s2s_timeout, @@ -1255,21 +1009,24 @@ log_s2s_out(_, Myname, Server, Tls) -> %% Calculate timeout depending on which state we are in: %% Can return integer > 0 | infinity +-spec get_timeout_interval(state_name()) -> pos_integer() | infinity. get_timeout_interval(StateName) -> case StateName of %% Validation implies dialback: Networking can take longer: wait_for_validation -> (?FSMTIMEOUT) * 6; %% When stream is established, we only rely on S2S Timeout timer: stream_established -> infinity; + relay_to_bridge -> infinity; + open_socket -> infinity; _ -> ?FSMTIMEOUT end. %% This function is intended to be called at the end of a state %% function that want to wait for a reconnect delay before stopping. +-spec wait_before_reconnect(state()) -> fsm_next(). wait_before_reconnect(StateData) -> - bounce_queue(StateData#state.queue, - ?ERR_REMOTE_SERVER_NOT_FOUND), - bounce_messages(?ERR_REMOTE_SERVER_NOT_FOUND), + bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()), + bounce_messages(xmpp:err_remote_server_not_found()), cancel_timer(StateData#state.timer), Delay = case StateData#state.delay_to_retry of undefined_delay -> @@ -1281,6 +1038,7 @@ wait_before_reconnect(StateData) -> StateData#state{timer = Timer, delay_to_retry = Delay, queue = queue:new()}}. +-spec get_max_retry_delay() -> pos_integer(). get_max_retry_delay() -> case ejabberd_config:get_option( s2s_max_retry_delay, @@ -1290,6 +1048,7 @@ get_max_retry_delay() -> end. %% Terminate s2s_out connections that are in state wait_before_retry +-spec terminate_if_waiting_delay(binary(), binary()) -> ok. terminate_if_waiting_delay(From, To) -> FromTo = {From, To}, Pids = ejabberd_s2s:get_connections_pids(FromTo), @@ -1298,6 +1057,7 @@ terminate_if_waiting_delay(From, To) -> end, Pids). +-spec fsm_limit_opts() -> [{max_queue, pos_integer()}]. fsm_limit_opts() -> case ejabberd_config:get_option( max_fsm_queue, @@ -1306,6 +1066,29 @@ fsm_limit_opts() -> N -> [{max_queue, N}] end. +-spec decode_element(xmlel(), state_name(), state()) -> fsm_next(). +decode_element(#xmlel{} = El, StateName, StateData) -> + Opts = if StateName == stream_established -> + [ignore_els]; + true -> + [] + end, + try xmpp:decode(El, ?NS_SERVER, Opts) of + Pkt -> ?MODULE:StateName(Pkt, StateData) + catch error:{xmpp_codec, Why} -> + Type = xmpp:get_type(El), + case xmpp:is_stanza(El) of + true when Type /= <<"result">>, Type /= <<"error">> -> + Lang = xmpp:get_lang(El), + Txt = xmpp:format_error(Why), + Err = xmpp:make_error(El, xmpp:err_bad_request(Txt, Lang)), + send_element(StateData, Err); + false -> + ok + end, + {next_state, StateName, StateData, get_timeout_interval(StateName)} + end. + opt_type(domain_certfile) -> fun iolist_to_binary/1; opt_type(max_fsm_queue) -> fun (I) when is_integer(I), I > 0 -> I end; |