diff options
author | Evgeniy Khramtsov <ekhramtsov@process-one.net> | 2016-09-23 12:30:33 +0300 |
---|---|---|
committer | Evgeniy Khramtsov <ekhramtsov@process-one.net> | 2016-09-23 12:30:33 +0300 |
commit | 53209b9ab1c154334eafacd3ca9aebe965380d50 (patch) | |
tree | 88c5a4b5168d8e0e3813fb19f0c26cce0680da03 /src/ejabberd_s2s_in.erl | |
parent | Add tests for external component (diff) |
Add tests for s2s code
Diffstat (limited to 'src/ejabberd_s2s_in.erl')
-rw-r--r-- | src/ejabberd_s2s_in.erl | 101 |
1 files changed, 64 insertions, 37 deletions
diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index fd560a45..6d1791d0 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -168,21 +168,26 @@ wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM} when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM -> - send_header(StateData, <<" version='1.0'">>), + send_header(StateData, {1,0}), send_element(StateData, xmpp:serr_invalid_namespace()), {stop, normal, StateData}; #stream_start{to = #jid{lserver = Server}, - from = #jid{lserver = From}, - version = <<"1.0">>} + from = From, version = {1,0}} when StateData#state.tls and not StateData#state.authenticated -> - send_header(StateData, <<" version='1.0'">>), + send_header(StateData, {1,0}), Auth = if StateData#state.tls_enabled -> - {Result, Message} = - ejabberd_s2s:check_peer_certificate( - StateData#state.sockmod, - StateData#state.socket, - From), - {Result, From, Message}; + case From of + #jid{} -> + {Result, Message} = + ejabberd_s2s:check_peer_certificate( + StateData#state.sockmod, + StateData#state.socket, + From#jid.lserver), + {Result, From#jid.lserver, Message}; + undefined -> + {error, <<"(unknown)">>, + <<"Got no valid 'from' attribute">>} + end; true -> {no_verify, <<"(unknown)">>, <<"TLS not (yet) enabled">>} end, @@ -225,8 +230,8 @@ wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> NewStateData#state{server = Server}} end; #stream_start{to = #jid{lserver = Server}, - version = <<"1.0">>} when StateData#state.authenticated -> - send_header(StateData, <<" version='1.0'">>), + version = {1,0}} when StateData#state.authenticated -> + send_header(StateData, {1,0}), send_element(StateData, #stream_features{ sub_els = ejabberd_hooks:run_fold( @@ -236,24 +241,28 @@ wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> #stream_start{db_xmlns = ?NS_SERVER_DIALBACK} when (StateData#state.tls_required and StateData#state.tls_enabled) or (not StateData#state.tls_required) -> - send_header(StateData, <<"">>), + send_header(StateData, undefined), {next_state, stream_established, StateData}; #stream_start{} -> - send_header(StateData, <<" version='1.0'">>), + send_header(StateData, {1,0}), send_element(StateData, xmpp:serr_undefined_condition()), - {stop, normal, StateData} + {stop, normal, StateData}; + _ -> + send_header(StateData, {1,0}), + send_element(StateData, xmpp:serr_invalid_xml()), + {stop, normal, StateData} catch _:{xmpp_codec, Why} -> Txt = xmpp:format_error(Why), - send_header(StateData, <<" version='1.0'">>), - send_element(StateData, xmpp:serr_not_well_formed(Txt, ?MYLANG)), + send_header(StateData, {1,0}), + send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), {stop, normal, StateData} end; wait_for_stream({xmlstreamerror, _}, StateData) -> - send_header(StateData, <<"">>), + send_header(StateData, {1,0}), send_element(StateData, xmpp:serr_not_well_formed()), {stop, normal, StateData}; wait_for_stream(timeout, StateData) -> - send_header(StateData, <<"">>), + send_header(StateData, {1,0}), send_element(StateData, xmpp:serr_connection_timeout()), {stop, normal, StateData}; wait_for_stream(closed, StateData) -> @@ -277,13 +286,21 @@ wait_for_feature_request(#starttls{}, StateData#state.tls_options, {certfile, CertFile}) end, + TLSOpts2 = case ejabberd_config:get_option( + {s2s_cafile, StateData#state.server}, + fun iolist_to_binary/1) of + undefined -> TLSOpts1; + CAFile -> + lists:keystore(cafile, 1, TLSOpts1, + {cafile, CAFile}) + end, TLSOpts = case ejabberd_config:get_option( {s2s_tls_compression, StateData#state.server}, fun(true) -> true; (false) -> false end, false) of - true -> lists:delete(compression_none, TLSOpts1); - false -> [compression_none | TLSOpts1] + true -> lists:delete(compression_none, TLSOpts2); + false -> [compression_none | TLSOpts2] end, TLSSocket = (StateData#state.sockmod):starttls( Socket, TLSOpts, @@ -293,8 +310,7 @@ wait_for_feature_request(#starttls{}, StateData#state{socket = TLSSocket, streamid = new_id(), tls_enabled = true, tls_options = TLSOpts}}; _ -> - Txt = <<"Unsupported TLS transport">>, - send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)), + send_element(StateData, #starttls_failure{}), {stop, normal, StateData} end; wait_for_feature_request(#sasl_auth{mechanism = Mech}, @@ -313,7 +329,10 @@ wait_for_feature_request(#sasl_auth{mechanism = Mech}, StateData#state{streamid = new_id(), authenticated = true}}; true -> - send_element(StateData, #sasl_failure{}), + Txt = xmpp:mk_text(<<"Denied by ACL">>, ?MYLANG), + send_element(StateData, + #sasl_failure{reason = 'not-authorized', + text = Txt}), {stop, normal, StateData} end; _ -> @@ -495,7 +514,7 @@ handle_info({send_text, Text}, StateName, StateData) -> handle_info({timeout, Timer, _}, StateName, #state{timer = Timer} = StateData) -> if StateName == wait_for_stream -> - send_header(StateData, <<"">>); + send_header(StateData, undefined); true -> ok end, @@ -555,15 +574,15 @@ send_error(StateData, Stanza, Error) -> send_trailer(StateData) -> send_text(StateData, <<"</stream:stream>">>). --spec send_header(state(), binary()) -> ok. +-spec send_header(state(), undefined | {integer(), integer()}) -> ok. send_header(StateData, Version) -> - send_text(StateData, - <<"<?xml version='1.0'?><stream:stream " - "xmlns:stream='http://etherx.jabber.org/stream" - "s' xmlns='jabber:server' xmlns:db='jabber:ser" - "ver:dialback' id='", - (StateData#state.streamid)/binary, "'", Version/binary, - ">">>). + Header = xmpp:encode( + #stream_start{xmlns = ?NS_SERVER, + stream_xmlns = ?NS_STREAM, + db_xmlns = ?NS_SERVER_DIALBACK, + id = StateData#state.streamid, + version = Version}), + send_text(StateData, fxml:element_to_header(Header)). -spec change_shaper(state(), binary(), jid()) -> ok. change_shaper(StateData, Host, JID) -> @@ -606,9 +625,14 @@ fsm_limit_opts(Opts) -> end end. --spec decode_element(xmlel(), state_name(), state()) -> fsm_transition(). +-spec decode_element(xmlel() | xmpp_element(), state_name(), state()) -> fsm_transition(). decode_element(#xmlel{} = El, StateName, StateData) -> - try xmpp:decode(El) of + Opts = if StateName == stream_established -> + [ignore_els]; + true -> + [] + end, + try xmpp:decode(El, Opts) of Pkt -> ?MODULE:StateName(Pkt, StateData) catch error:{xmpp_codec, Why} -> case xmpp:is_stanza(El) of @@ -620,12 +644,15 @@ decode_element(#xmlel{} = El, StateName, StateData) -> ok end, {next_state, StateName, StateData} - end. + end; +decode_element(Pkt, StateName, StateData) -> + ?MODULE:StateName(Pkt, StateData). opt_type(domain_certfile) -> fun iolist_to_binary/1; opt_type(max_fsm_queue) -> fun (I) when is_integer(I), I > 0 -> I end; opt_type(s2s_certfile) -> fun iolist_to_binary/1; +opt_type(s2s_cafile) -> fun iolist_to_binary/1; opt_type(s2s_ciphers) -> fun iolist_to_binary/1; opt_type(s2s_dhfile) -> fun iolist_to_binary/1; opt_type(s2s_protocol_options) -> @@ -647,6 +674,6 @@ opt_type(s2s_use_starttls) -> (required_trusted) -> required_trusted end; opt_type(_) -> - [domain_certfile, max_fsm_queue, s2s_certfile, + [domain_certfile, max_fsm_queue, s2s_certfile, s2s_cafile, s2s_ciphers, s2s_dhfile, s2s_protocol_options, s2s_tls_compression, s2s_use_starttls]. |