diff options
Diffstat (limited to 'src/ejabberd_s2s_in.erl')
-rw-r--r-- | src/ejabberd_s2s_in.erl | 57 |
1 files changed, 37 insertions, 20 deletions
diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index 64a85fc1c..725421d1d 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -14,8 +14,8 @@ %% External exports -export([start/2, start_link/2, - match_domain/2, - socket_type/0]). + become_controller/1, + match_domain/2]). %% gen_fsm callbacks -export([init/1, @@ -37,6 +37,8 @@ -define(DICT, dict). -record(state, {socket, + sockmod, + receiver, streamid, shaper, tls = false, @@ -85,8 +87,8 @@ start(SockData, Opts) -> start_link(SockData, Opts) -> gen_fsm:start_link(ejabberd_s2s_in, [SockData, Opts], ?FSMOPTS). -socket_type() -> - xml_stream. +become_controller(Pid) -> + gen_fsm:send_all_state_event(Pid, become_controller). %%%---------------------------------------------------------------------- %%% Callback functions from gen_fsm @@ -99,12 +101,19 @@ socket_type() -> %% ignore | %% {stop, StopReason} %%---------------------------------------------------------------------- -init([Socket, Opts]) -> - ?INFO_MSG("started: ~p", [Socket]), +init([{SockMod, Socket}, Opts]) -> + ?INFO_MSG("started: ~p", [{SockMod, Socket}]), Shaper = case lists:keysearch(shaper, 1, Opts) of {value, {_, S}} -> S; _ -> none end, + MaxStanzaSize = + case lists:keysearch(max_stanza_size, 1, Opts) of + {value, {_, Size}} -> Size; + _ -> infinity + end, + ReceiverPid = ejabberd_receiver:start( + Socket, SockMod, none, MaxStanzaSize), StartTLS = case ejabberd_config:get_local_option(s2s_use_starttls) of undefined -> false; @@ -120,6 +129,8 @@ init([Socket, Opts]) -> Timer = erlang:start_timer(?S2STIMEOUT, self(), []), {ok, wait_for_stream, #state{socket = Socket, + sockmod = SockMod, + receiver = ReceiverPid, streamid = new_id(), shaper = Shaper, tls = StartTLS, @@ -144,10 +155,9 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> SASL = if StateData#state.tls_enabled -> - case ejabberd_socket:get_peer_certificate( - StateData#state.socket) of + case tls:get_peer_certificate(StateData#state.socket) of {ok, _Cert} -> - case ejabberd_socket:get_verify_result( + case tls:get_verify_result( StateData#state.socket) of 0 -> [{xmlelement, "mechanisms", @@ -204,7 +214,7 @@ wait_for_feature_request({xmlstreamelement, El}, StateData) -> {xmlelement, Name, Attrs, Els} = El, TLS = StateData#state.tls, TLSEnabled = StateData#state.tls_enabled, - SockMod = ejabberd_socket:get_sockmod(StateData#state.socket), + SockMod = StateData#state.sockmod, case {xml:get_attr_s("xmlns", Attrs), Name} of {?NS_TLS, "starttls"} when TLS == true, TLSEnabled == false, @@ -212,11 +222,13 @@ wait_for_feature_request({xmlstreamelement, El}, StateData) -> ?INFO_MSG("starttls", []), Socket = StateData#state.socket, TLSOpts = StateData#state.tls_options, - TLSSocket = ejabberd_socket:starttls(Socket, TLSOpts), + {ok, TLSSocket} = tls:tcp_to_tls(Socket, TLSOpts), + ejabberd_receiver:starttls(StateData#state.receiver, TLSSocket), send_element(StateData, {xmlelement, "proceed", [{"xmlns", ?NS_TLS}], []}), {next_state, wait_for_stream, - StateData#state{socket = TLSSocket, + StateData#state{sockmod = tls, + socket = TLSSocket, streamid = new_id(), tls_enabled = true }}; @@ -227,10 +239,9 @@ wait_for_feature_request({xmlstreamelement, El}, StateData) -> Auth = jlib:decode_base64(xml:get_cdata(Els)), AuthDomain = jlib:nameprep(Auth), AuthRes = - case ejabberd_socket:get_peer_certificate( - StateData#state.socket) of + case tls:get_peer_certificate(StateData#state.socket) of {ok, Cert} -> - case ejabberd_socket:get_verify_result( + case tls:get_verify_result( StateData#state.socket) of 0 -> case AuthDomain of @@ -256,8 +267,8 @@ wait_for_feature_request({xmlstreamelement, El}, StateData) -> end, if AuthRes -> - ejabberd_socket:reset_stream( - StateData#state.socket), + ejabberd_receiver:reset_stream( + StateData#state.receiver), send_element(StateData, {xmlelement, "success", [{"xmlns", ?NS_SASL}], []}), @@ -456,6 +467,12 @@ stream_established(closed, StateData) -> %% {next_state, NextStateName, NextStateData, Timeout} | %% {stop, Reason, NewStateData} %%---------------------------------------------------------------------- +handle_event(become_controller, StateName, StateData) -> + ok = (StateData#state.sockmod):controlling_process( + StateData#state.socket, + StateData#state.receiver), + ejabberd_receiver:become_controller(StateData#state.receiver), + {next_state, StateName, StateData}; handle_event(_Event, StateName, StateData) -> {next_state, StateName, StateData}. @@ -500,7 +517,7 @@ handle_info(_, StateName, StateData) -> %%---------------------------------------------------------------------- terminate(Reason, _StateName, StateData) -> ?INFO_MSG("terminated: ~p", [Reason]), - ejabberd_socket:close(StateData#state.socket), + ejabberd_receiver:close(StateData#state.receiver), ok. %%%---------------------------------------------------------------------- @@ -508,7 +525,7 @@ terminate(Reason, _StateName, StateData) -> %%%---------------------------------------------------------------------- send_text(StateData, Text) -> - ejabberd_socket:send(StateData#state.socket, Text). + (StateData#state.sockmod):send(StateData#state.socket, Text). send_element(StateData, El) -> send_text(StateData, xml:element_to_string(El)). @@ -516,7 +533,7 @@ send_element(StateData, El) -> change_shaper(StateData, Host, JID) -> Shaper = acl:match_rule(Host, StateData#state.shaper, JID), - ejabberd_socket:change_shaper(StateData#state.socket, Shaper). + ejabberd_receiver:change_shaper(StateData#state.receiver, Shaper). new_id() -> |