aboutsummaryrefslogtreecommitdiff
path: root/src/ejabberd_s2s_in.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/ejabberd_s2s_in.erl')
-rw-r--r--src/ejabberd_s2s_in.erl673
1 files changed, 321 insertions, 352 deletions
diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl
index d8d0a400a..395a0fce7 100644
--- a/src/ejabberd_s2s_in.erl
+++ b/src/ejabberd_s2s_in.erl
@@ -42,7 +42,7 @@
-include("ejabberd.hrl").
-include("logger.hrl").
--include("jlib.hrl").
+-include("xmpp.hrl").
-define(DICT, dict).
@@ -62,40 +62,19 @@
connections = (?DICT):new() :: ?TDICT,
timer = make_ref() :: reference()}).
-%-define(DBGFSM, true).
+-type state_name() :: wait_for_stream | wait_for_feature_request | stream_established.
+-type state() :: #state{}.
+-type fsm_next() :: {next_state, state_name(), state()}.
+-type fsm_stop() :: {stop, normal, state()}.
+-type fsm_transition() :: fsm_stop() | fsm_next().
+%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-
-define(FSMOPTS, [{debug, [trace]}]).
-
-else.
-
-define(FSMOPTS, []).
-
-endif.
--define(STREAM_HEADER(Version),
- <<"<?xml version='1.0'?><stream:stream "
- "xmlns:stream='http://etherx.jabber.org/stream"
- "s' xmlns='jabber:server' xmlns:db='jabber:ser"
- "ver:dialback' id='",
- (StateData#state.streamid)/binary, "'", Version/binary,
- ">">>).
-
--define(STREAM_TRAILER, <<"</stream:stream>">>).
-
--define(INVALID_NAMESPACE_ERR,
- fxml:element_to_binary(?SERR_INVALID_NAMESPACE)).
-
--define(HOST_UNKNOWN_ERR,
- fxml:element_to_binary(?SERR_HOST_UNKNOWN)).
-
--define(INVALID_FROM_ERR,
- fxml:element_to_binary(?SERR_INVALID_FROM)).
-
--define(INVALID_XML_ERR,
- fxml:element_to_binary(?SERR_XML_NOT_WELL_FORMED)).
-
start(SockData, Opts) ->
supervisor:start_child(ejabberd_s2s_in_sup,
[SockData, Opts]).
@@ -185,351 +164,294 @@ init([{SockMod, Socket}, Opts]) ->
%% {next_state, NextStateName, NextStateData, Timeout} |
%% {stop, Reason, NewStateData}
%%----------------------------------------------------------------------
-
-wait_for_stream({xmlstreamstart, _Name, Attrs},
- StateData) ->
- case {fxml:get_attr_s(<<"xmlns">>, Attrs),
- fxml:get_attr_s(<<"xmlns:db">>, Attrs),
- fxml:get_attr_s(<<"to">>, Attrs),
- fxml:get_attr_s(<<"version">>, Attrs) == <<"1.0">>}
- of
- {<<"jabber:server">>, _, Server, true}
- when StateData#state.tls and
- not StateData#state.authenticated ->
- send_text(StateData,
- ?STREAM_HEADER(<<" version='1.0'">>)),
- Auth = if StateData#state.tls_enabled ->
- case jid:nameprep(fxml:get_attr_s(<<"from">>, Attrs)) of
- From when From /= <<"">>, From /= error ->
- {Result, Message} =
- ejabberd_s2s:check_peer_certificate(StateData#state.sockmod,
- StateData#state.socket,
- From),
- {Result, From, Message};
- _ ->
- {error, <<"(unknown)">>,
- <<"Got no valid 'from' attribute">>}
- end;
- true ->
- {no_verify, <<"(unknown)">>,
- <<"TLS not (yet) enabled">>}
- end,
- StartTLS = if StateData#state.tls_enabled -> [];
- not StateData#state.tls_enabled and
+wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
+ try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
+ #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM}
+ when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM ->
+ send_header(StateData, {1,0}),
+ send_element(StateData, xmpp:serr_invalid_namespace()),
+ {stop, normal, StateData};
+ #stream_start{to = #jid{lserver = Server},
+ from = From, version = {1,0}}
+ when StateData#state.tls and not StateData#state.authenticated ->
+ send_header(StateData, {1,0}),
+ Auth = if StateData#state.tls_enabled ->
+ case From of
+ #jid{} ->
+ {Result, Message} =
+ ejabberd_s2s:check_peer_certificate(
+ StateData#state.sockmod,
+ StateData#state.socket,
+ From#jid.lserver),
+ {Result, From#jid.lserver, Message};
+ undefined ->
+ {error, <<"(unknown)">>,
+ <<"Got no valid 'from' attribute">>}
+ end;
+ true ->
+ {no_verify, <<"(unknown)">>, <<"TLS not (yet) enabled">>}
+ end,
+ StartTLS = if StateData#state.tls_enabled -> [];
+ not StateData#state.tls_enabled and
not StateData#state.tls_required ->
- [#xmlel{name = <<"starttls">>,
- attrs = [{<<"xmlns">>, ?NS_TLS}],
- children = []}];
- not StateData#state.tls_enabled and
+ [#starttls{required = false}];
+ not StateData#state.tls_enabled and
StateData#state.tls_required ->
- [#xmlel{name = <<"starttls">>,
- attrs = [{<<"xmlns">>, ?NS_TLS}],
- children =
- [#xmlel{name = <<"required">>,
- attrs = [], children = []}]}]
- end,
- case Auth of
- {error, RemoteServer, CertError}
- when StateData#state.tls_certverify ->
- ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)",
- [StateData#state.server, RemoteServer, CertError]),
- send_text(StateData,
- <<(fxml:element_to_binary(?SERRT_POLICY_VIOLATION(<<"en">>,
- CertError)))/binary,
- (?STREAM_TRAILER)/binary>>),
- {stop, normal, StateData};
- {VerifyResult, RemoteServer, Msg} ->
- {SASL, NewStateData} = case VerifyResult of
- ok ->
- {[#xmlel{name = <<"mechanisms">>,
- attrs = [{<<"xmlns">>, ?NS_SASL}],
- children =
- [#xmlel{name = <<"mechanism">>,
- attrs = [],
- children =
- [{xmlcdata,
- <<"EXTERNAL">>}]}]}],
- StateData#state{auth_domain = RemoteServer}};
- error ->
- ?DEBUG("Won't accept certificate of ~s: ~s",
- [RemoteServer, Msg]),
- {[], StateData};
- no_verify ->
- {[], StateData}
- end,
- send_element(NewStateData,
- #xmlel{name = <<"stream:features">>, attrs = [],
- children =
- SASL ++
- StartTLS ++
- ejabberd_hooks:run_fold(s2s_stream_features,
- Server, [],
- [Server])}),
- {next_state, wait_for_feature_request,
- NewStateData#state{server = Server}}
- end;
- {<<"jabber:server">>, _, Server, true}
- when StateData#state.authenticated ->
- send_text(StateData,
- ?STREAM_HEADER(<<" version='1.0'">>)),
- send_element(StateData,
- #xmlel{name = <<"stream:features">>, attrs = [],
- children =
- ejabberd_hooks:run_fold(s2s_stream_features,
- Server, [],
- [Server])}),
- {next_state, stream_established, StateData};
- {<<"jabber:server">>, <<"jabber:server:dialback">>,
- _Server, _} when
- (StateData#state.tls_required and StateData#state.tls_enabled)
- or (not StateData#state.tls_required) ->
- send_text(StateData, ?STREAM_HEADER(<<"">>)),
- {next_state, stream_established, StateData};
- _ ->
- send_text(StateData, ?INVALID_NAMESPACE_ERR),
- {stop, normal, StateData}
+ [#starttls{required = true}]
+ end,
+ case Auth of
+ {error, RemoteServer, CertError}
+ when StateData#state.tls_certverify ->
+ ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)",
+ [StateData#state.server, RemoteServer, CertError]),
+ send_element(StateData,
+ xmpp:serr_policy_violation(CertError, ?MYLANG)),
+ {stop, normal, StateData};
+ {VerifyResult, RemoteServer, Msg} ->
+ {SASL, NewStateData} =
+ case VerifyResult of
+ ok ->
+ {[#sasl_mechanisms{list = [<<"EXTERNAL">>]}],
+ StateData#state{auth_domain = RemoteServer}};
+ error ->
+ ?DEBUG("Won't accept certificate of ~s: ~s",
+ [RemoteServer, Msg]),
+ {[], StateData};
+ no_verify ->
+ {[], StateData}
+ end,
+ send_element(NewStateData,
+ #stream_features{
+ sub_els = SASL ++ StartTLS ++
+ ejabberd_hooks:run_fold(
+ s2s_stream_features, Server, [],
+ [Server])}),
+ {next_state, wait_for_feature_request,
+ NewStateData#state{server = Server}}
+ end;
+ #stream_start{to = #jid{lserver = Server},
+ version = {1,0}} when StateData#state.authenticated ->
+ send_header(StateData, {1,0}),
+ send_element(StateData,
+ #stream_features{
+ sub_els = ejabberd_hooks:run_fold(
+ s2s_stream_features, Server, [],
+ [Server])}),
+ {next_state, stream_established, StateData};
+ #stream_start{db_xmlns = ?NS_SERVER_DIALBACK}
+ when (StateData#state.tls_required and StateData#state.tls_enabled)
+ or (not StateData#state.tls_required) ->
+ send_header(StateData, undefined),
+ {next_state, stream_established, StateData};
+ #stream_start{} ->
+ send_header(StateData, {1,0}),
+ send_element(StateData, xmpp:serr_undefined_condition()),
+ {stop, normal, StateData};
+ _ ->
+ send_header(StateData, {1,0}),
+ send_element(StateData, xmpp:serr_invalid_xml()),
+ {stop, normal, StateData}
+ catch _:{xmpp_codec, Why} ->
+ Txt = xmpp:format_error(Why),
+ send_header(StateData, {1,0}),
+ send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)),
+ {stop, normal, StateData}
end;
wait_for_stream({xmlstreamerror, _}, StateData) ->
- send_text(StateData,
- <<(?STREAM_HEADER(<<"">>))/binary,
- (?INVALID_XML_ERR)/binary, (?STREAM_TRAILER)/binary>>),
+ send_header(StateData, {1,0}),
+ send_element(StateData, xmpp:serr_not_well_formed()),
{stop, normal, StateData};
wait_for_stream(timeout, StateData) ->
+ send_header(StateData, {1,0}),
+ send_element(StateData, xmpp:serr_connection_timeout()),
{stop, normal, StateData};
wait_for_stream(closed, StateData) ->
{stop, normal, StateData}.
-wait_for_feature_request({xmlstreamelement, El},
- StateData) ->
- #xmlel{name = Name, attrs = Attrs} = El,
- TLS = StateData#state.tls,
- TLSEnabled = StateData#state.tls_enabled,
- SockMod =
- (StateData#state.sockmod):get_sockmod(StateData#state.socket),
- case {fxml:get_attr_s(<<"xmlns">>, Attrs), Name} of
- {?NS_TLS, <<"starttls">>}
- when TLS == true, TLSEnabled == false,
- SockMod == gen_tcp ->
- ?DEBUG("starttls", []),
- Socket = StateData#state.socket,
- TLSOpts1 = case
- ejabberd_config:get_option(
- {domain_certfile, StateData#state.server},
- fun iolist_to_binary/1) of
- undefined -> StateData#state.tls_options;
- CertFile ->
- [{certfile, CertFile} | lists:keydelete(certfile, 1,
- StateData#state.tls_options)]
- end,
- TLSOpts = case ejabberd_config:get_option(
- {s2s_tls_compression, StateData#state.server},
- fun(true) -> true;
- (false) -> false
- end, false) of
- true -> lists:delete(compression_none, TLSOpts1);
- false -> [compression_none | TLSOpts1]
- end,
- TLSSocket = (StateData#state.sockmod):starttls(Socket,
- TLSOpts,
- fxml:element_to_binary(#xmlel{name
- =
- <<"proceed">>,
- attrs
- =
- [{<<"xmlns">>,
- ?NS_TLS}],
- children
- =
- []})),
- {next_state, wait_for_stream,
- StateData#state{socket = TLSSocket, streamid = new_id(),
- tls_enabled = true, tls_options = TLSOpts}};
- {?NS_SASL, <<"auth">>} when TLSEnabled ->
- Mech = fxml:get_attr_s(<<"mechanism">>, Attrs),
- case Mech of
- <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> ->
- AuthDomain = StateData#state.auth_domain,
- AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>,
- AuthDomain),
- if AllowRemoteHost ->
- (StateData#state.sockmod):reset_stream(StateData#state.socket),
- send_element(StateData,
- #xmlel{name = <<"success">>,
- attrs = [{<<"xmlns">>, ?NS_SASL}],
- children = []}),
- ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)",
- [AuthDomain, StateData#state.tls_enabled]),
- change_shaper(StateData, <<"">>,
- jid:make(<<"">>, AuthDomain, <<"">>)),
- {next_state, wait_for_stream,
- StateData#state{streamid = new_id(),
- authenticated = true}};
- true ->
- send_element(StateData,
- #xmlel{name = <<"failure">>,
- attrs = [{<<"xmlns">>, ?NS_SASL}],
- children = []}),
- send_text(StateData, ?STREAM_TRAILER),
- {stop, normal, StateData}
- end;
- _ ->
- send_element(StateData,
- #xmlel{name = <<"failure">>,
- attrs = [{<<"xmlns">>, ?NS_SASL}],
- children =
- [#xmlel{name = <<"invalid-mechanism">>,
- attrs = [], children = []}]}),
- {stop, normal, StateData}
- end;
- _ ->
- stream_established({xmlstreamelement, El}, StateData)
+wait_for_feature_request({xmlstreamelement, El}, StateData) ->
+ decode_element(El, wait_for_feature_request, StateData);
+wait_for_feature_request(#starttls{},
+ #state{tls = true, tls_enabled = false} = StateData) ->
+ case (StateData#state.sockmod):get_sockmod(StateData#state.socket) of
+ gen_tcp ->
+ ?DEBUG("starttls", []),
+ Socket = StateData#state.socket,
+ TLSOpts1 = case
+ ejabberd_config:get_option(
+ {domain_certfile, StateData#state.server},
+ fun iolist_to_binary/1) of
+ undefined -> StateData#state.tls_options;
+ CertFile ->
+ lists:keystore(certfile, 1,
+ StateData#state.tls_options,
+ {certfile, CertFile})
+ end,
+ TLSOpts2 = case ejabberd_config:get_option(
+ {s2s_cafile, StateData#state.server},
+ fun iolist_to_binary/1) of
+ undefined -> TLSOpts1;
+ CAFile ->
+ lists:keystore(cafile, 1, TLSOpts1,
+ {cafile, CAFile})
+ end,
+ TLSOpts = case ejabberd_config:get_option(
+ {s2s_tls_compression, StateData#state.server},
+ fun(true) -> true;
+ (false) -> false
+ end, false) of
+ true -> lists:delete(compression_none, TLSOpts2);
+ false -> [compression_none | TLSOpts2]
+ end,
+ TLSSocket = (StateData#state.sockmod):starttls(
+ Socket, TLSOpts,
+ fxml:element_to_binary(
+ xmpp:encode(#starttls_proceed{}))),
+ {next_state, wait_for_stream,
+ StateData#state{socket = TLSSocket, streamid = new_id(),
+ tls_enabled = true, tls_options = TLSOpts}};
+ _ ->
+ send_element(StateData, #starttls_failure{}),
+ {stop, normal, StateData}
end;
-wait_for_feature_request({xmlstreamend, _Name},
- StateData) ->
- send_text(StateData, ?STREAM_TRAILER),
+wait_for_feature_request(#sasl_auth{mechanism = Mech},
+ #state{tls_enabled = true} = StateData) ->
+ case Mech of
+ <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> ->
+ AuthDomain = StateData#state.auth_domain,
+ AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>, AuthDomain),
+ if AllowRemoteHost ->
+ (StateData#state.sockmod):reset_stream(StateData#state.socket),
+ send_element(StateData, #sasl_success{}),
+ ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)",
+ [AuthDomain, StateData#state.tls_enabled]),
+ change_shaper(StateData, <<"">>, jid:make(AuthDomain)),
+ {next_state, wait_for_stream,
+ StateData#state{streamid = new_id(),
+ authenticated = true}};
+ true ->
+ Txt = xmpp:mk_text(<<"Denied by ACL">>, ?MYLANG),
+ send_element(StateData,
+ #sasl_failure{reason = 'not-authorized',
+ text = Txt}),
+ {stop, normal, StateData}
+ end;
+ _ ->
+ send_element(StateData, #sasl_failure{reason = 'invalid-mechanism'}),
+ {stop, normal, StateData}
+ end;
+wait_for_feature_request({xmlstreamend, _Name}, StateData) ->
{stop, normal, StateData};
-wait_for_feature_request({xmlstreamerror, _},
- StateData) ->
- send_text(StateData,
- <<(?INVALID_XML_ERR)/binary,
- (?STREAM_TRAILER)/binary>>),
+wait_for_feature_request({xmlstreamerror, _}, StateData) ->
+ send_element(StateData, xmpp:serr_not_well_formed()),
{stop, normal, StateData};
wait_for_feature_request(closed, StateData) ->
- {stop, normal, StateData}.
+ {stop, normal, StateData};
+wait_for_feature_request(_Pkt, #state{tls_required = TLSRequired,
+ tls_enabled = TLSEnabled} = StateData)
+ when TLSRequired and not TLSEnabled ->
+ Txt = <<"Use of STARTTLS required">>,
+ send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)),
+ {stop, normal, StateData};
+wait_for_feature_request(El, StateData) ->
+ stream_established({xmlstreamelement, El}, StateData).
stream_established({xmlstreamelement, El}, StateData) ->
cancel_timer(StateData#state.timer),
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
- case is_key_packet(El) of
- {key, To, From, Id, Key} ->
- ?DEBUG("GET KEY: ~p", [{To, From, Id, Key}]),
- LTo = jid:nameprep(To),
- LFrom = jid:nameprep(From),
- case {ejabberd_s2s:allow_host(LTo, LFrom),
- lists:member(LTo,
- ejabberd_router:dirty_get_all_domains())}
- of
- {true, true} ->
- ejabberd_s2s_out:terminate_if_waiting_delay(LTo, LFrom),
- ejabberd_s2s_out:start(LTo, LFrom,
- {verify, self(), Key,
- StateData#state.streamid}),
- Conns = (?DICT):store({LFrom, LTo},
- wait_for_verification,
- StateData#state.connections),
- change_shaper(StateData, LTo,
- jid:make(<<"">>, LFrom, <<"">>)),
- {next_state, stream_established,
- StateData#state{connections = Conns, timer = Timer}};
- {_, false} ->
- send_text(StateData, ?HOST_UNKNOWN_ERR),
- {stop, normal, StateData};
- {false, _} ->
- send_text(StateData, ?INVALID_FROM_ERR),
- {stop, normal, StateData}
- end;
- {verify, To, From, Id, Key} ->
- ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]),
- LTo = jid:nameprep(To),
- LFrom = jid:nameprep(From),
- Type = case ejabberd_s2s:make_key({LTo, LFrom}, Id) of
- Key -> <<"valid">>;
- _ -> <<"invalid">>
- end,
- send_element(StateData,
- #xmlel{name = <<"db:verify">>,
- attrs =
- [{<<"from">>, To}, {<<"to">>, From},
- {<<"id">>, Id}, {<<"type">>, Type}],
- children = []}),
- {next_state, stream_established,
- StateData#state{timer = Timer}};
- _ ->
- NewEl = jlib:remove_attr(<<"xmlns">>, El),
- #xmlel{name = Name, attrs = Attrs} = NewEl,
- From_s = fxml:get_attr_s(<<"from">>, Attrs),
- From = jid:from_string(From_s),
- To_s = fxml:get_attr_s(<<"to">>, Attrs),
- To = jid:from_string(To_s),
- if (To /= error) and (From /= error) ->
- LFrom = From#jid.lserver,
- LTo = To#jid.lserver,
- if StateData#state.authenticated ->
- case LFrom == StateData#state.auth_domain andalso
- lists:member(LTo,
- ejabberd_router:dirty_get_all_domains())
- of
- true ->
- if (Name == <<"iq">>) or (Name == <<"message">>)
- or (Name == <<"presence">>) ->
- ejabberd_hooks:run(s2s_receive_packet, LTo,
- [From, To, NewEl]),
- ejabberd_router:route(From, To, NewEl);
- true -> error
- end;
- false -> error
- end;
- true ->
- case (?DICT):find({LFrom, LTo},
- StateData#state.connections)
- of
- {ok, established} ->
- if (Name == <<"iq">>) or (Name == <<"message">>)
- or (Name == <<"presence">>) ->
- ejabberd_hooks:run(s2s_receive_packet, LTo,
- [From, To, NewEl]),
- ejabberd_router:route(From, To, NewEl);
- true -> error
- end;
- _ -> error
- end
- end;
- true -> error
- end,
- ejabberd_hooks:run(s2s_loop_debug,
- [{xmlstreamelement, El}]),
- {next_state, stream_established,
- StateData#state{timer = Timer}}
+ decode_element(El, stream_established, StateData#state{timer = Timer});
+stream_established(#db_result{to = To, from = From, key = Key},
+ StateData) ->
+ ?DEBUG("GET KEY: ~p", [{To, From, Key}]),
+ case {ejabberd_s2s:allow_host(To, From),
+ lists:member(To, ejabberd_router:dirty_get_all_domains())} of
+ {true, true} ->
+ ejabberd_s2s_out:terminate_if_waiting_delay(To, From),
+ ejabberd_s2s_out:start(To, From,
+ {verify, self(), Key,
+ StateData#state.streamid}),
+ Conns = (?DICT):store({From, To},
+ wait_for_verification,
+ StateData#state.connections),
+ change_shaper(StateData, To, jid:make(From)),
+ {next_state, stream_established,
+ StateData#state{connections = Conns}};
+ {_, false} ->
+ send_element(StateData, xmpp:serr_host_unknown()),
+ {stop, normal, StateData};
+ {false, _} ->
+ send_element(StateData, xmpp:serr_invalid_from()),
+ {stop, normal, StateData}
end;
+stream_established(#db_verify{to = To, from = From, id = Id, key = Key},
+ StateData) ->
+ ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]),
+ Type = case ejabberd_s2s:make_key({To, From}, Id) of
+ Key -> valid;
+ _ -> invalid
+ end,
+ send_element(StateData,
+ #db_verify{from = To, to = From, id = Id, type = Type}),
+ {next_state, stream_established, StateData};
+stream_established(Pkt, StateData) when ?is_stanza(Pkt) ->
+ From = xmpp:get_from(Pkt),
+ To = xmpp:get_to(Pkt),
+ if To /= undefined, From /= undefined ->
+ LFrom = From#jid.lserver,
+ LTo = To#jid.lserver,
+ if StateData#state.authenticated ->
+ case LFrom == StateData#state.auth_domain andalso
+ lists:member(LTo, ejabberd_router:dirty_get_all_domains()) of
+ true ->
+ ejabberd_hooks:run(s2s_receive_packet, LTo,
+ [From, To, Pkt]),
+ ejabberd_router:route(From, To, Pkt);
+ false ->
+ send_error(StateData, Pkt, xmpp:err_not_authorized())
+ end;
+ true ->
+ case (?DICT):find({LFrom, LTo}, StateData#state.connections) of
+ {ok, established} ->
+ ejabberd_hooks:run(s2s_receive_packet, LTo,
+ [From, To, Pkt]),
+ ejabberd_router:route(From, To, Pkt);
+ _ ->
+ send_error(StateData, Pkt, xmpp:err_not_authorized())
+ end
+ end;
+ true ->
+ send_error(StateData, Pkt, xmpp:err_jid_malformed())
+ end,
+ ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
+ {next_state, stream_established, StateData};
stream_established({valid, From, To}, StateData) ->
send_element(StateData,
- #xmlel{name = <<"db:result">>,
- attrs =
- [{<<"from">>, To}, {<<"to">>, From},
- {<<"type">>, <<"valid">>}],
- children = []}),
+ #db_result{from = To, to = From, type = valid}),
?INFO_MSG("Accepted s2s dialback authentication for ~s (TLS=~p)",
[From, StateData#state.tls_enabled]),
- LFrom = jid:nameprep(From),
- LTo = jid:nameprep(To),
NSD = StateData#state{connections =
- (?DICT):store({LFrom, LTo}, established,
+ (?DICT):store({From, To}, established,
StateData#state.connections)},
{next_state, stream_established, NSD};
stream_established({invalid, From, To}, StateData) ->
send_element(StateData,
- #xmlel{name = <<"db:result">>,
- attrs =
- [{<<"from">>, To}, {<<"to">>, From},
- {<<"type">>, <<"invalid">>}],
- children = []}),
- LFrom = jid:nameprep(From),
- LTo = jid:nameprep(To),
+ #db_result{from = To, to = From, type = invalid}),
NSD = StateData#state{connections =
- (?DICT):erase({LFrom, LTo},
+ (?DICT):erase({From, To},
StateData#state.connections)},
{next_state, stream_established, NSD};
stream_established({xmlstreamend, _Name}, StateData) ->
{stop, normal, StateData};
stream_established({xmlstreamerror, _}, StateData) ->
- send_text(StateData,
- <<(?INVALID_XML_ERR)/binary,
- (?STREAM_TRAILER)/binary>>),
+ send_element(StateData, xmpp:serr_not_well_formed()),
{stop, normal, StateData};
stream_established(timeout, StateData) ->
+ send_element(StateData, xmpp:serr_connection_timeout()),
{stop, normal, StateData};
stream_established(closed, StateData) ->
- {stop, normal, StateData}.
+ {stop, normal, StateData};
+stream_established(Pkt, StateData) ->
+ ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
+ {next_state, stream_established, StateData}.
%%----------------------------------------------------------------------
%% Func: StateName/3
@@ -589,8 +511,14 @@ code_change(_OldVsn, StateName, StateData, _Extra) ->
handle_info({send_text, Text}, StateName, StateData) ->
send_text(StateData, Text),
{next_state, StateName, StateData};
-handle_info({timeout, Timer, _}, _StateName,
+handle_info({timeout, Timer, _}, StateName,
#state{timer = Timer} = StateData) ->
+ if StateName == wait_for_stream ->
+ send_header(StateData, undefined);
+ true ->
+ ok
+ end,
+ send_element(StateData, xmpp:serr_connection_timeout()),
{stop, normal, StateData};
handle_info(_, StateName, StateData) ->
{next_state, StateName, StateData}.
@@ -603,6 +531,7 @@ terminate(Reason, _StateName, StateData) ->
|| Host <- get_external_hosts(StateData)];
_ -> ok
end,
+ catch send_trailer(StateData),
(StateData#state.sockmod):close(StateData#state.socket),
ok.
@@ -621,39 +550,55 @@ print_state(State) -> State.
%%% Internal functions
%%%----------------------------------------------------------------------
+-spec send_text(state(), iodata()) -> ok.
send_text(StateData, Text) ->
(StateData#state.sockmod):send(StateData#state.socket,
Text).
+-spec send_element(state(), xmpp_element()) -> ok.
send_element(StateData, El) ->
- send_text(StateData, fxml:element_to_binary(El)).
+ El1 = xmpp:encode(El, ?NS_SERVER),
+ send_text(StateData, fxml:element_to_binary(El1)).
+
+-spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok.
+send_error(StateData, Stanza, Error) ->
+ Type = xmpp:get_type(Stanza),
+ if Type == error; Type == result;
+ Type == <<"error">>; Type == <<"result">> ->
+ ok;
+ true ->
+ send_element(StateData, xmpp:make_error(Stanza, Error))
+ end.
+-spec send_trailer(state()) -> ok.
+send_trailer(StateData) ->
+ send_text(StateData, <<"</stream:stream>">>).
+
+-spec send_header(state(), undefined | {integer(), integer()}) -> ok.
+send_header(StateData, Version) ->
+ Header = xmpp:encode(
+ #stream_start{xmlns = ?NS_SERVER,
+ stream_xmlns = ?NS_STREAM,
+ db_xmlns = ?NS_SERVER_DIALBACK,
+ id = StateData#state.streamid,
+ version = Version}),
+ send_text(StateData, fxml:element_to_header(Header)).
+
+-spec change_shaper(state(), binary(), jid()) -> ok.
change_shaper(StateData, Host, JID) ->
Shaper = acl:match_rule(Host, StateData#state.shaper,
JID),
(StateData#state.sockmod):change_shaper(StateData#state.socket,
Shaper).
+-spec new_id() -> binary().
new_id() -> randoms:get_string().
+-spec cancel_timer(reference()) -> ok.
cancel_timer(Timer) ->
erlang:cancel_timer(Timer),
receive {timeout, Timer, _} -> ok after 0 -> ok end.
-is_key_packet(#xmlel{name = Name, attrs = Attrs,
- children = Els})
- when Name == <<"db:result">> ->
- {key, fxml:get_attr_s(<<"to">>, Attrs),
- fxml:get_attr_s(<<"from">>, Attrs),
- fxml:get_attr_s(<<"id">>, Attrs), fxml:get_cdata(Els)};
-is_key_packet(#xmlel{name = Name, attrs = Attrs,
- children = Els})
- when Name == <<"db:verify">> ->
- {verify, fxml:get_attr_s(<<"to">>, Attrs),
- fxml:get_attr_s(<<"from">>, Attrs),
- fxml:get_attr_s(<<"id">>, Attrs), fxml:get_cdata(Els)};
-is_key_packet(_) -> false.
-
fsm_limit_opts(Opts) ->
case lists:keysearch(max_fsm_queue, 1, Opts) of
{value, {_, N}} when is_integer(N) -> [{max_queue, N}];
@@ -666,10 +611,34 @@ fsm_limit_opts(Opts) ->
end
end.
+-spec decode_element(xmlel() | xmpp_element(), state_name(), state()) -> fsm_transition().
+decode_element(#xmlel{} = El, StateName, StateData) ->
+ Opts = if StateName == stream_established ->
+ [ignore_els];
+ true ->
+ []
+ end,
+ try xmpp:decode(El, ?NS_SERVER, Opts) of
+ Pkt -> ?MODULE:StateName(Pkt, StateData)
+ catch error:{xmpp_codec, Why} ->
+ case xmpp:is_stanza(El) of
+ true ->
+ Lang = xmpp:get_lang(El),
+ Txt = xmpp:format_error(Why),
+ send_error(StateData, El, xmpp:err_bad_request(Txt, Lang));
+ false ->
+ ok
+ end,
+ {next_state, StateName, StateData}
+ end;
+decode_element(Pkt, StateName, StateData) ->
+ ?MODULE:StateName(Pkt, StateData).
+
opt_type(domain_certfile) -> fun iolist_to_binary/1;
opt_type(max_fsm_queue) ->
fun (I) when is_integer(I), I > 0 -> I end;
opt_type(s2s_certfile) -> fun iolist_to_binary/1;
+opt_type(s2s_cafile) -> fun iolist_to_binary/1;
opt_type(s2s_ciphers) -> fun iolist_to_binary/1;
opt_type(s2s_dhfile) -> fun iolist_to_binary/1;
opt_type(s2s_protocol_options) ->
@@ -691,6 +660,6 @@ opt_type(s2s_use_starttls) ->
(required_trusted) -> required_trusted
end;
opt_type(_) ->
- [domain_certfile, max_fsm_queue, s2s_certfile,
+ [domain_certfile, max_fsm_queue, s2s_certfile, s2s_cafile,
s2s_ciphers, s2s_dhfile, s2s_protocol_options,
s2s_tls_compression, s2s_use_starttls].