diff options
Diffstat (limited to 'src/ejabberd_s2s_in.erl')
-rw-r--r-- | src/ejabberd_s2s_in.erl | 673 |
1 files changed, 321 insertions, 352 deletions
diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index d8d0a400a..395a0fce7 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -42,7 +42,7 @@ -include("ejabberd.hrl"). -include("logger.hrl"). --include("jlib.hrl"). +-include("xmpp.hrl"). -define(DICT, dict). @@ -62,40 +62,19 @@ connections = (?DICT):new() :: ?TDICT, timer = make_ref() :: reference()}). -%-define(DBGFSM, true). +-type state_name() :: wait_for_stream | wait_for_feature_request | stream_established. +-type state() :: #state{}. +-type fsm_next() :: {next_state, state_name(), state()}. +-type fsm_stop() :: {stop, normal, state()}. +-type fsm_transition() :: fsm_stop() | fsm_next(). +%%-define(DBGFSM, true). -ifdef(DBGFSM). - -define(FSMOPTS, [{debug, [trace]}]). - -else. - -define(FSMOPTS, []). - -endif. --define(STREAM_HEADER(Version), - <<"<?xml version='1.0'?><stream:stream " - "xmlns:stream='http://etherx.jabber.org/stream" - "s' xmlns='jabber:server' xmlns:db='jabber:ser" - "ver:dialback' id='", - (StateData#state.streamid)/binary, "'", Version/binary, - ">">>). - --define(STREAM_TRAILER, <<"</stream:stream>">>). - --define(INVALID_NAMESPACE_ERR, - fxml:element_to_binary(?SERR_INVALID_NAMESPACE)). - --define(HOST_UNKNOWN_ERR, - fxml:element_to_binary(?SERR_HOST_UNKNOWN)). - --define(INVALID_FROM_ERR, - fxml:element_to_binary(?SERR_INVALID_FROM)). - --define(INVALID_XML_ERR, - fxml:element_to_binary(?SERR_XML_NOT_WELL_FORMED)). - start(SockData, Opts) -> supervisor:start_child(ejabberd_s2s_in_sup, [SockData, Opts]). @@ -185,351 +164,294 @@ init([{SockMod, Socket}, Opts]) -> %% {next_state, NextStateName, NextStateData, Timeout} | %% {stop, Reason, NewStateData} %%---------------------------------------------------------------------- - -wait_for_stream({xmlstreamstart, _Name, Attrs}, - StateData) -> - case {fxml:get_attr_s(<<"xmlns">>, Attrs), - fxml:get_attr_s(<<"xmlns:db">>, Attrs), - fxml:get_attr_s(<<"to">>, Attrs), - fxml:get_attr_s(<<"version">>, Attrs) == <<"1.0">>} - of - {<<"jabber:server">>, _, Server, true} - when StateData#state.tls and - not StateData#state.authenticated -> - send_text(StateData, - ?STREAM_HEADER(<<" version='1.0'">>)), - Auth = if StateData#state.tls_enabled -> - case jid:nameprep(fxml:get_attr_s(<<"from">>, Attrs)) of - From when From /= <<"">>, From /= error -> - {Result, Message} = - ejabberd_s2s:check_peer_certificate(StateData#state.sockmod, - StateData#state.socket, - From), - {Result, From, Message}; - _ -> - {error, <<"(unknown)">>, - <<"Got no valid 'from' attribute">>} - end; - true -> - {no_verify, <<"(unknown)">>, - <<"TLS not (yet) enabled">>} - end, - StartTLS = if StateData#state.tls_enabled -> []; - not StateData#state.tls_enabled and +wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> + try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of + #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM} + when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM -> + send_header(StateData, {1,0}), + send_element(StateData, xmpp:serr_invalid_namespace()), + {stop, normal, StateData}; + #stream_start{to = #jid{lserver = Server}, + from = From, version = {1,0}} + when StateData#state.tls and not StateData#state.authenticated -> + send_header(StateData, {1,0}), + Auth = if StateData#state.tls_enabled -> + case From of + #jid{} -> + {Result, Message} = + ejabberd_s2s:check_peer_certificate( + StateData#state.sockmod, + StateData#state.socket, + From#jid.lserver), + {Result, From#jid.lserver, Message}; + undefined -> + {error, <<"(unknown)">>, + <<"Got no valid 'from' attribute">>} + end; + true -> + {no_verify, <<"(unknown)">>, <<"TLS not (yet) enabled">>} + end, + StartTLS = if StateData#state.tls_enabled -> []; + not StateData#state.tls_enabled and not StateData#state.tls_required -> - [#xmlel{name = <<"starttls">>, - attrs = [{<<"xmlns">>, ?NS_TLS}], - children = []}]; - not StateData#state.tls_enabled and + [#starttls{required = false}]; + not StateData#state.tls_enabled and StateData#state.tls_required -> - [#xmlel{name = <<"starttls">>, - attrs = [{<<"xmlns">>, ?NS_TLS}], - children = - [#xmlel{name = <<"required">>, - attrs = [], children = []}]}] - end, - case Auth of - {error, RemoteServer, CertError} - when StateData#state.tls_certverify -> - ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)", - [StateData#state.server, RemoteServer, CertError]), - send_text(StateData, - <<(fxml:element_to_binary(?SERRT_POLICY_VIOLATION(<<"en">>, - CertError)))/binary, - (?STREAM_TRAILER)/binary>>), - {stop, normal, StateData}; - {VerifyResult, RemoteServer, Msg} -> - {SASL, NewStateData} = case VerifyResult of - ok -> - {[#xmlel{name = <<"mechanisms">>, - attrs = [{<<"xmlns">>, ?NS_SASL}], - children = - [#xmlel{name = <<"mechanism">>, - attrs = [], - children = - [{xmlcdata, - <<"EXTERNAL">>}]}]}], - StateData#state{auth_domain = RemoteServer}}; - error -> - ?DEBUG("Won't accept certificate of ~s: ~s", - [RemoteServer, Msg]), - {[], StateData}; - no_verify -> - {[], StateData} - end, - send_element(NewStateData, - #xmlel{name = <<"stream:features">>, attrs = [], - children = - SASL ++ - StartTLS ++ - ejabberd_hooks:run_fold(s2s_stream_features, - Server, [], - [Server])}), - {next_state, wait_for_feature_request, - NewStateData#state{server = Server}} - end; - {<<"jabber:server">>, _, Server, true} - when StateData#state.authenticated -> - send_text(StateData, - ?STREAM_HEADER(<<" version='1.0'">>)), - send_element(StateData, - #xmlel{name = <<"stream:features">>, attrs = [], - children = - ejabberd_hooks:run_fold(s2s_stream_features, - Server, [], - [Server])}), - {next_state, stream_established, StateData}; - {<<"jabber:server">>, <<"jabber:server:dialback">>, - _Server, _} when - (StateData#state.tls_required and StateData#state.tls_enabled) - or (not StateData#state.tls_required) -> - send_text(StateData, ?STREAM_HEADER(<<"">>)), - {next_state, stream_established, StateData}; - _ -> - send_text(StateData, ?INVALID_NAMESPACE_ERR), - {stop, normal, StateData} + [#starttls{required = true}] + end, + case Auth of + {error, RemoteServer, CertError} + when StateData#state.tls_certverify -> + ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)", + [StateData#state.server, RemoteServer, CertError]), + send_element(StateData, + xmpp:serr_policy_violation(CertError, ?MYLANG)), + {stop, normal, StateData}; + {VerifyResult, RemoteServer, Msg} -> + {SASL, NewStateData} = + case VerifyResult of + ok -> + {[#sasl_mechanisms{list = [<<"EXTERNAL">>]}], + StateData#state{auth_domain = RemoteServer}}; + error -> + ?DEBUG("Won't accept certificate of ~s: ~s", + [RemoteServer, Msg]), + {[], StateData}; + no_verify -> + {[], StateData} + end, + send_element(NewStateData, + #stream_features{ + sub_els = SASL ++ StartTLS ++ + ejabberd_hooks:run_fold( + s2s_stream_features, Server, [], + [Server])}), + {next_state, wait_for_feature_request, + NewStateData#state{server = Server}} + end; + #stream_start{to = #jid{lserver = Server}, + version = {1,0}} when StateData#state.authenticated -> + send_header(StateData, {1,0}), + send_element(StateData, + #stream_features{ + sub_els = ejabberd_hooks:run_fold( + s2s_stream_features, Server, [], + [Server])}), + {next_state, stream_established, StateData}; + #stream_start{db_xmlns = ?NS_SERVER_DIALBACK} + when (StateData#state.tls_required and StateData#state.tls_enabled) + or (not StateData#state.tls_required) -> + send_header(StateData, undefined), + {next_state, stream_established, StateData}; + #stream_start{} -> + send_header(StateData, {1,0}), + send_element(StateData, xmpp:serr_undefined_condition()), + {stop, normal, StateData}; + _ -> + send_header(StateData, {1,0}), + send_element(StateData, xmpp:serr_invalid_xml()), + {stop, normal, StateData} + catch _:{xmpp_codec, Why} -> + Txt = xmpp:format_error(Why), + send_header(StateData, {1,0}), + send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), + {stop, normal, StateData} end; wait_for_stream({xmlstreamerror, _}, StateData) -> - send_text(StateData, - <<(?STREAM_HEADER(<<"">>))/binary, - (?INVALID_XML_ERR)/binary, (?STREAM_TRAILER)/binary>>), + send_header(StateData, {1,0}), + send_element(StateData, xmpp:serr_not_well_formed()), {stop, normal, StateData}; wait_for_stream(timeout, StateData) -> + send_header(StateData, {1,0}), + send_element(StateData, xmpp:serr_connection_timeout()), {stop, normal, StateData}; wait_for_stream(closed, StateData) -> {stop, normal, StateData}. -wait_for_feature_request({xmlstreamelement, El}, - StateData) -> - #xmlel{name = Name, attrs = Attrs} = El, - TLS = StateData#state.tls, - TLSEnabled = StateData#state.tls_enabled, - SockMod = - (StateData#state.sockmod):get_sockmod(StateData#state.socket), - case {fxml:get_attr_s(<<"xmlns">>, Attrs), Name} of - {?NS_TLS, <<"starttls">>} - when TLS == true, TLSEnabled == false, - SockMod == gen_tcp -> - ?DEBUG("starttls", []), - Socket = StateData#state.socket, - TLSOpts1 = case - ejabberd_config:get_option( - {domain_certfile, StateData#state.server}, - fun iolist_to_binary/1) of - undefined -> StateData#state.tls_options; - CertFile -> - [{certfile, CertFile} | lists:keydelete(certfile, 1, - StateData#state.tls_options)] - end, - TLSOpts = case ejabberd_config:get_option( - {s2s_tls_compression, StateData#state.server}, - fun(true) -> true; - (false) -> false - end, false) of - true -> lists:delete(compression_none, TLSOpts1); - false -> [compression_none | TLSOpts1] - end, - TLSSocket = (StateData#state.sockmod):starttls(Socket, - TLSOpts, - fxml:element_to_binary(#xmlel{name - = - <<"proceed">>, - attrs - = - [{<<"xmlns">>, - ?NS_TLS}], - children - = - []})), - {next_state, wait_for_stream, - StateData#state{socket = TLSSocket, streamid = new_id(), - tls_enabled = true, tls_options = TLSOpts}}; - {?NS_SASL, <<"auth">>} when TLSEnabled -> - Mech = fxml:get_attr_s(<<"mechanism">>, Attrs), - case Mech of - <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> -> - AuthDomain = StateData#state.auth_domain, - AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>, - AuthDomain), - if AllowRemoteHost -> - (StateData#state.sockmod):reset_stream(StateData#state.socket), - send_element(StateData, - #xmlel{name = <<"success">>, - attrs = [{<<"xmlns">>, ?NS_SASL}], - children = []}), - ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)", - [AuthDomain, StateData#state.tls_enabled]), - change_shaper(StateData, <<"">>, - jid:make(<<"">>, AuthDomain, <<"">>)), - {next_state, wait_for_stream, - StateData#state{streamid = new_id(), - authenticated = true}}; - true -> - send_element(StateData, - #xmlel{name = <<"failure">>, - attrs = [{<<"xmlns">>, ?NS_SASL}], - children = []}), - send_text(StateData, ?STREAM_TRAILER), - {stop, normal, StateData} - end; - _ -> - send_element(StateData, - #xmlel{name = <<"failure">>, - attrs = [{<<"xmlns">>, ?NS_SASL}], - children = - [#xmlel{name = <<"invalid-mechanism">>, - attrs = [], children = []}]}), - {stop, normal, StateData} - end; - _ -> - stream_established({xmlstreamelement, El}, StateData) +wait_for_feature_request({xmlstreamelement, El}, StateData) -> + decode_element(El, wait_for_feature_request, StateData); +wait_for_feature_request(#starttls{}, + #state{tls = true, tls_enabled = false} = StateData) -> + case (StateData#state.sockmod):get_sockmod(StateData#state.socket) of + gen_tcp -> + ?DEBUG("starttls", []), + Socket = StateData#state.socket, + TLSOpts1 = case + ejabberd_config:get_option( + {domain_certfile, StateData#state.server}, + fun iolist_to_binary/1) of + undefined -> StateData#state.tls_options; + CertFile -> + lists:keystore(certfile, 1, + StateData#state.tls_options, + {certfile, CertFile}) + end, + TLSOpts2 = case ejabberd_config:get_option( + {s2s_cafile, StateData#state.server}, + fun iolist_to_binary/1) of + undefined -> TLSOpts1; + CAFile -> + lists:keystore(cafile, 1, TLSOpts1, + {cafile, CAFile}) + end, + TLSOpts = case ejabberd_config:get_option( + {s2s_tls_compression, StateData#state.server}, + fun(true) -> true; + (false) -> false + end, false) of + true -> lists:delete(compression_none, TLSOpts2); + false -> [compression_none | TLSOpts2] + end, + TLSSocket = (StateData#state.sockmod):starttls( + Socket, TLSOpts, + fxml:element_to_binary( + xmpp:encode(#starttls_proceed{}))), + {next_state, wait_for_stream, + StateData#state{socket = TLSSocket, streamid = new_id(), + tls_enabled = true, tls_options = TLSOpts}}; + _ -> + send_element(StateData, #starttls_failure{}), + {stop, normal, StateData} end; -wait_for_feature_request({xmlstreamend, _Name}, - StateData) -> - send_text(StateData, ?STREAM_TRAILER), +wait_for_feature_request(#sasl_auth{mechanism = Mech}, + #state{tls_enabled = true} = StateData) -> + case Mech of + <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> -> + AuthDomain = StateData#state.auth_domain, + AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>, AuthDomain), + if AllowRemoteHost -> + (StateData#state.sockmod):reset_stream(StateData#state.socket), + send_element(StateData, #sasl_success{}), + ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)", + [AuthDomain, StateData#state.tls_enabled]), + change_shaper(StateData, <<"">>, jid:make(AuthDomain)), + {next_state, wait_for_stream, + StateData#state{streamid = new_id(), + authenticated = true}}; + true -> + Txt = xmpp:mk_text(<<"Denied by ACL">>, ?MYLANG), + send_element(StateData, + #sasl_failure{reason = 'not-authorized', + text = Txt}), + {stop, normal, StateData} + end; + _ -> + send_element(StateData, #sasl_failure{reason = 'invalid-mechanism'}), + {stop, normal, StateData} + end; +wait_for_feature_request({xmlstreamend, _Name}, StateData) -> {stop, normal, StateData}; -wait_for_feature_request({xmlstreamerror, _}, - StateData) -> - send_text(StateData, - <<(?INVALID_XML_ERR)/binary, - (?STREAM_TRAILER)/binary>>), +wait_for_feature_request({xmlstreamerror, _}, StateData) -> + send_element(StateData, xmpp:serr_not_well_formed()), {stop, normal, StateData}; wait_for_feature_request(closed, StateData) -> - {stop, normal, StateData}. + {stop, normal, StateData}; +wait_for_feature_request(_Pkt, #state{tls_required = TLSRequired, + tls_enabled = TLSEnabled} = StateData) + when TLSRequired and not TLSEnabled -> + Txt = <<"Use of STARTTLS required">>, + send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)), + {stop, normal, StateData}; +wait_for_feature_request(El, StateData) -> + stream_established({xmlstreamelement, El}, StateData). stream_established({xmlstreamelement, El}, StateData) -> cancel_timer(StateData#state.timer), Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - case is_key_packet(El) of - {key, To, From, Id, Key} -> - ?DEBUG("GET KEY: ~p", [{To, From, Id, Key}]), - LTo = jid:nameprep(To), - LFrom = jid:nameprep(From), - case {ejabberd_s2s:allow_host(LTo, LFrom), - lists:member(LTo, - ejabberd_router:dirty_get_all_domains())} - of - {true, true} -> - ejabberd_s2s_out:terminate_if_waiting_delay(LTo, LFrom), - ejabberd_s2s_out:start(LTo, LFrom, - {verify, self(), Key, - StateData#state.streamid}), - Conns = (?DICT):store({LFrom, LTo}, - wait_for_verification, - StateData#state.connections), - change_shaper(StateData, LTo, - jid:make(<<"">>, LFrom, <<"">>)), - {next_state, stream_established, - StateData#state{connections = Conns, timer = Timer}}; - {_, false} -> - send_text(StateData, ?HOST_UNKNOWN_ERR), - {stop, normal, StateData}; - {false, _} -> - send_text(StateData, ?INVALID_FROM_ERR), - {stop, normal, StateData} - end; - {verify, To, From, Id, Key} -> - ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]), - LTo = jid:nameprep(To), - LFrom = jid:nameprep(From), - Type = case ejabberd_s2s:make_key({LTo, LFrom}, Id) of - Key -> <<"valid">>; - _ -> <<"invalid">> - end, - send_element(StateData, - #xmlel{name = <<"db:verify">>, - attrs = - [{<<"from">>, To}, {<<"to">>, From}, - {<<"id">>, Id}, {<<"type">>, Type}], - children = []}), - {next_state, stream_established, - StateData#state{timer = Timer}}; - _ -> - NewEl = jlib:remove_attr(<<"xmlns">>, El), - #xmlel{name = Name, attrs = Attrs} = NewEl, - From_s = fxml:get_attr_s(<<"from">>, Attrs), - From = jid:from_string(From_s), - To_s = fxml:get_attr_s(<<"to">>, Attrs), - To = jid:from_string(To_s), - if (To /= error) and (From /= error) -> - LFrom = From#jid.lserver, - LTo = To#jid.lserver, - if StateData#state.authenticated -> - case LFrom == StateData#state.auth_domain andalso - lists:member(LTo, - ejabberd_router:dirty_get_all_domains()) - of - true -> - if (Name == <<"iq">>) or (Name == <<"message">>) - or (Name == <<"presence">>) -> - ejabberd_hooks:run(s2s_receive_packet, LTo, - [From, To, NewEl]), - ejabberd_router:route(From, To, NewEl); - true -> error - end; - false -> error - end; - true -> - case (?DICT):find({LFrom, LTo}, - StateData#state.connections) - of - {ok, established} -> - if (Name == <<"iq">>) or (Name == <<"message">>) - or (Name == <<"presence">>) -> - ejabberd_hooks:run(s2s_receive_packet, LTo, - [From, To, NewEl]), - ejabberd_router:route(From, To, NewEl); - true -> error - end; - _ -> error - end - end; - true -> error - end, - ejabberd_hooks:run(s2s_loop_debug, - [{xmlstreamelement, El}]), - {next_state, stream_established, - StateData#state{timer = Timer}} + decode_element(El, stream_established, StateData#state{timer = Timer}); +stream_established(#db_result{to = To, from = From, key = Key}, + StateData) -> + ?DEBUG("GET KEY: ~p", [{To, From, Key}]), + case {ejabberd_s2s:allow_host(To, From), + lists:member(To, ejabberd_router:dirty_get_all_domains())} of + {true, true} -> + ejabberd_s2s_out:terminate_if_waiting_delay(To, From), + ejabberd_s2s_out:start(To, From, + {verify, self(), Key, + StateData#state.streamid}), + Conns = (?DICT):store({From, To}, + wait_for_verification, + StateData#state.connections), + change_shaper(StateData, To, jid:make(From)), + {next_state, stream_established, + StateData#state{connections = Conns}}; + {_, false} -> + send_element(StateData, xmpp:serr_host_unknown()), + {stop, normal, StateData}; + {false, _} -> + send_element(StateData, xmpp:serr_invalid_from()), + {stop, normal, StateData} end; +stream_established(#db_verify{to = To, from = From, id = Id, key = Key}, + StateData) -> + ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]), + Type = case ejabberd_s2s:make_key({To, From}, Id) of + Key -> valid; + _ -> invalid + end, + send_element(StateData, + #db_verify{from = To, to = From, id = Id, type = Type}), + {next_state, stream_established, StateData}; +stream_established(Pkt, StateData) when ?is_stanza(Pkt) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + if To /= undefined, From /= undefined -> + LFrom = From#jid.lserver, + LTo = To#jid.lserver, + if StateData#state.authenticated -> + case LFrom == StateData#state.auth_domain andalso + lists:member(LTo, ejabberd_router:dirty_get_all_domains()) of + true -> + ejabberd_hooks:run(s2s_receive_packet, LTo, + [From, To, Pkt]), + ejabberd_router:route(From, To, Pkt); + false -> + send_error(StateData, Pkt, xmpp:err_not_authorized()) + end; + true -> + case (?DICT):find({LFrom, LTo}, StateData#state.connections) of + {ok, established} -> + ejabberd_hooks:run(s2s_receive_packet, LTo, + [From, To, Pkt]), + ejabberd_router:route(From, To, Pkt); + _ -> + send_error(StateData, Pkt, xmpp:err_not_authorized()) + end + end; + true -> + send_error(StateData, Pkt, xmpp:err_jid_malformed()) + end, + ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]), + {next_state, stream_established, StateData}; stream_established({valid, From, To}, StateData) -> send_element(StateData, - #xmlel{name = <<"db:result">>, - attrs = - [{<<"from">>, To}, {<<"to">>, From}, - {<<"type">>, <<"valid">>}], - children = []}), + #db_result{from = To, to = From, type = valid}), ?INFO_MSG("Accepted s2s dialback authentication for ~s (TLS=~p)", [From, StateData#state.tls_enabled]), - LFrom = jid:nameprep(From), - LTo = jid:nameprep(To), NSD = StateData#state{connections = - (?DICT):store({LFrom, LTo}, established, + (?DICT):store({From, To}, established, StateData#state.connections)}, {next_state, stream_established, NSD}; stream_established({invalid, From, To}, StateData) -> send_element(StateData, - #xmlel{name = <<"db:result">>, - attrs = - [{<<"from">>, To}, {<<"to">>, From}, - {<<"type">>, <<"invalid">>}], - children = []}), - LFrom = jid:nameprep(From), - LTo = jid:nameprep(To), + #db_result{from = To, to = From, type = invalid}), NSD = StateData#state{connections = - (?DICT):erase({LFrom, LTo}, + (?DICT):erase({From, To}, StateData#state.connections)}, {next_state, stream_established, NSD}; stream_established({xmlstreamend, _Name}, StateData) -> {stop, normal, StateData}; stream_established({xmlstreamerror, _}, StateData) -> - send_text(StateData, - <<(?INVALID_XML_ERR)/binary, - (?STREAM_TRAILER)/binary>>), + send_element(StateData, xmpp:serr_not_well_formed()), {stop, normal, StateData}; stream_established(timeout, StateData) -> + send_element(StateData, xmpp:serr_connection_timeout()), {stop, normal, StateData}; stream_established(closed, StateData) -> - {stop, normal, StateData}. + {stop, normal, StateData}; +stream_established(Pkt, StateData) -> + ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]), + {next_state, stream_established, StateData}. %%---------------------------------------------------------------------- %% Func: StateName/3 @@ -589,8 +511,14 @@ code_change(_OldVsn, StateName, StateData, _Extra) -> handle_info({send_text, Text}, StateName, StateData) -> send_text(StateData, Text), {next_state, StateName, StateData}; -handle_info({timeout, Timer, _}, _StateName, +handle_info({timeout, Timer, _}, StateName, #state{timer = Timer} = StateData) -> + if StateName == wait_for_stream -> + send_header(StateData, undefined); + true -> + ok + end, + send_element(StateData, xmpp:serr_connection_timeout()), {stop, normal, StateData}; handle_info(_, StateName, StateData) -> {next_state, StateName, StateData}. @@ -603,6 +531,7 @@ terminate(Reason, _StateName, StateData) -> || Host <- get_external_hosts(StateData)]; _ -> ok end, + catch send_trailer(StateData), (StateData#state.sockmod):close(StateData#state.socket), ok. @@ -621,39 +550,55 @@ print_state(State) -> State. %%% Internal functions %%%---------------------------------------------------------------------- +-spec send_text(state(), iodata()) -> ok. send_text(StateData, Text) -> (StateData#state.sockmod):send(StateData#state.socket, Text). +-spec send_element(state(), xmpp_element()) -> ok. send_element(StateData, El) -> - send_text(StateData, fxml:element_to_binary(El)). + El1 = xmpp:encode(El, ?NS_SERVER), + send_text(StateData, fxml:element_to_binary(El1)). + +-spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok. +send_error(StateData, Stanza, Error) -> + Type = xmpp:get_type(Stanza), + if Type == error; Type == result; + Type == <<"error">>; Type == <<"result">> -> + ok; + true -> + send_element(StateData, xmpp:make_error(Stanza, Error)) + end. +-spec send_trailer(state()) -> ok. +send_trailer(StateData) -> + send_text(StateData, <<"</stream:stream>">>). + +-spec send_header(state(), undefined | {integer(), integer()}) -> ok. +send_header(StateData, Version) -> + Header = xmpp:encode( + #stream_start{xmlns = ?NS_SERVER, + stream_xmlns = ?NS_STREAM, + db_xmlns = ?NS_SERVER_DIALBACK, + id = StateData#state.streamid, + version = Version}), + send_text(StateData, fxml:element_to_header(Header)). + +-spec change_shaper(state(), binary(), jid()) -> ok. change_shaper(StateData, Host, JID) -> Shaper = acl:match_rule(Host, StateData#state.shaper, JID), (StateData#state.sockmod):change_shaper(StateData#state.socket, Shaper). +-spec new_id() -> binary(). new_id() -> randoms:get_string(). +-spec cancel_timer(reference()) -> ok. cancel_timer(Timer) -> erlang:cancel_timer(Timer), receive {timeout, Timer, _} -> ok after 0 -> ok end. -is_key_packet(#xmlel{name = Name, attrs = Attrs, - children = Els}) - when Name == <<"db:result">> -> - {key, fxml:get_attr_s(<<"to">>, Attrs), - fxml:get_attr_s(<<"from">>, Attrs), - fxml:get_attr_s(<<"id">>, Attrs), fxml:get_cdata(Els)}; -is_key_packet(#xmlel{name = Name, attrs = Attrs, - children = Els}) - when Name == <<"db:verify">> -> - {verify, fxml:get_attr_s(<<"to">>, Attrs), - fxml:get_attr_s(<<"from">>, Attrs), - fxml:get_attr_s(<<"id">>, Attrs), fxml:get_cdata(Els)}; -is_key_packet(_) -> false. - fsm_limit_opts(Opts) -> case lists:keysearch(max_fsm_queue, 1, Opts) of {value, {_, N}} when is_integer(N) -> [{max_queue, N}]; @@ -666,10 +611,34 @@ fsm_limit_opts(Opts) -> end end. +-spec decode_element(xmlel() | xmpp_element(), state_name(), state()) -> fsm_transition(). +decode_element(#xmlel{} = El, StateName, StateData) -> + Opts = if StateName == stream_established -> + [ignore_els]; + true -> + [] + end, + try xmpp:decode(El, ?NS_SERVER, Opts) of + Pkt -> ?MODULE:StateName(Pkt, StateData) + catch error:{xmpp_codec, Why} -> + case xmpp:is_stanza(El) of + true -> + Lang = xmpp:get_lang(El), + Txt = xmpp:format_error(Why), + send_error(StateData, El, xmpp:err_bad_request(Txt, Lang)); + false -> + ok + end, + {next_state, StateName, StateData} + end; +decode_element(Pkt, StateName, StateData) -> + ?MODULE:StateName(Pkt, StateData). + opt_type(domain_certfile) -> fun iolist_to_binary/1; opt_type(max_fsm_queue) -> fun (I) when is_integer(I), I > 0 -> I end; opt_type(s2s_certfile) -> fun iolist_to_binary/1; +opt_type(s2s_cafile) -> fun iolist_to_binary/1; opt_type(s2s_ciphers) -> fun iolist_to_binary/1; opt_type(s2s_dhfile) -> fun iolist_to_binary/1; opt_type(s2s_protocol_options) -> @@ -691,6 +660,6 @@ opt_type(s2s_use_starttls) -> (required_trusted) -> required_trusted end; opt_type(_) -> - [domain_certfile, max_fsm_queue, s2s_certfile, + [domain_certfile, max_fsm_queue, s2s_certfile, s2s_cafile, s2s_ciphers, s2s_dhfile, s2s_protocol_options, s2s_tls_compression, s2s_use_starttls]. |