diff options
author | Evgeniy Khramtsov <ekhramtsov@process-one.net> | 2016-12-30 00:00:36 +0300 |
---|---|---|
committer | Evgeniy Khramtsov <ekhramtsov@process-one.net> | 2016-12-30 00:00:36 +0300 |
commit | e7fe4dc474ed180a4200b2bdefc2ff58b12340c0 (patch) | |
tree | 83a2ca201a8fae1f1ba49a2f2fb541ad64e02f91 /src/xmpp_stream_out.erl | |
parent | Add xmpp_stream_out behaviour and rewrite s2s/SM code (diff) |
More refactoring on session management
Diffstat (limited to 'src/xmpp_stream_out.erl')
-rw-r--r-- | src/xmpp_stream_out.erl | 204 |
1 files changed, 110 insertions, 94 deletions
diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl index fc373fff..08804e43 100644 --- a/src/xmpp_stream_out.erl +++ b/src/xmpp_stream_out.erl @@ -33,6 +33,7 @@ -include_lib("kernel/include/inet.hrl"). -type state() :: map(). +-type noreply() :: {noreply, state(), timeout()}. -type host_port() :: {inet:hostname(), inet:port_number()}. -type ip_port() :: {inet:ip_address(), inet:port_number()}. -type network_error() :: {error, inet:posix() | inet_res:res_error()}. @@ -42,7 +43,8 @@ {tls, term()} | {pkix, binary()} | {auth, atom() | binary() | string()} | - {socket, inet:posix() | closed | timeout}. + {socket, inet:posix() | closed | timeout} | + internal_failure. -callback init(list()) -> {ok, state()} | {stop, term()} | ignore. @@ -107,7 +109,7 @@ close(_, _) -> establish(State) -> process_stream_established(State). --spec set_timeout(state(), non_neg_integer() | infinity) -> state(). +-spec set_timeout(state(), timeout()) -> state(). set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> case Timeout of infinity -> State#{stream_timeout => infinity}; @@ -148,12 +150,15 @@ format_error({tls, Reason}) -> format("TLS failed: ~w", [Reason]); format_error({auth, Reason}) -> format("Authentication failed: ~s", [Reason]); +format_error(internal_failure) -> + <<"Internal server error">>; format_error(Err) -> format("Unrecognized error: ~w", [Err]). %%%=================================================================== %%% gen_server callbacks %%%=================================================================== +-spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore. init([Mod, SockMod, From, To, Opts]) -> Time = p1_time_compat:monotonic_time(milli_seconds), State = #{owner => self(), @@ -183,36 +188,38 @@ init([Mod, SockMod, From, To, Opts]) -> Err end. +-spec handle_call(term(), term(), state()) -> noreply(). handle_call(Call, From, #{mod := Mod} = State) -> noreply(try Mod:handle_call(Call, From, State) catch _:undef -> State end). +-spec handle_cast(term(), state()) -> noreply(). handle_cast(connect, #{remote_server := RemoteServer, sockmod := SockMod, stream_state := connecting} = State) -> - case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of - false -> - noreply(process_stream_close({error, {idna, bad_string}}, State)); - ASCIIName -> - case resolve(binary_to_list(ASCIIName), State) of - {ok, AddrPorts} -> - case connect(AddrPorts, State) of - {ok, Socket, AddrPort} -> - SocketMonitor = SockMod:monitor(Socket), - State1 = State#{ip => AddrPort, - socket => Socket, - socket_monitor => SocketMonitor}, - State2 = State1#{stream_state => wait_for_stream}, - noreply(send_header(State2)); - {error, Why} -> - Err = {error, {socket, Why}}, - noreply(process_stream_close(Err, State)) - end; - {error, Why} -> - noreply(process_stream_close({error, {dns, Why}}, State)) - end - end; + noreply( + case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of + false -> + process_stream_end({idna, bad_string}, State); + ASCIIName -> + case resolve(binary_to_list(ASCIIName), State) of + {ok, AddrPorts} -> + case connect(AddrPorts, State) of + {ok, Socket, AddrPort} -> + SocketMonitor = SockMod:monitor(Socket), + State1 = State#{ip => AddrPort, + socket => Socket, + socket_monitor => SocketMonitor}, + State2 = State1#{stream_state => wait_for_stream}, + send_header(State2); + {error, Why} -> + process_stream_end({socket, Why}, State) + end; + {error, Why} -> + process_stream_end({dns, Why}, State) + end + end); handle_cast(connect, State) -> %% Ignoring connection attempts in other states noreply(State); @@ -225,66 +232,70 @@ handle_cast(Cast, #{mod := Mod} = State) -> catch _:undef -> State end). +-spec handle_info(term(), state()) -> noreply(). handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, #{stream_state := wait_for_stream, xmlns := XMLNS, lang := MyLang} = State) -> El = #xmlel{name = Name, attrs = Attrs}, - try xmpp:decode(El, XMLNS, []) of - #stream_start{} = Pkt -> - noreply(process_stream(Pkt, State)); - _ -> - noreply(send_element(State, xmpp:serr_invalid_xml())) - catch _:{xmpp_codec, Why} -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - Err = xmpp:serr_invalid_xml(Txt, Lang), - noreply(send_element(State, Err)) - end; + noreply( + try xmpp:decode(El, XMLNS, []) of + #stream_start{} = Pkt -> + process_stream(Pkt, State); + _ -> + send_element(State, xmpp:serr_invalid_xml()) + catch _:{xmpp_codec, Why} -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + Err = xmpp:serr_invalid_xml(Txt, Lang), + send_element(State, Err) + end); handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> State1 = send_header(State), - case is_disconnected(State1) of - true -> State1; - false -> - Err = case Reason of - <<"XML stanza is too big">> -> - xmpp:serr_policy_violation(Reason, Lang); - _ -> - xmpp:serr_not_well_formed() - end, - noreply(send_element(State1, Err)) - end; + noreply( + case is_disconnected(State1) of + true -> State1; + false -> + Err = case Reason of + <<"XML stanza is too big">> -> + xmpp:serr_policy_violation(Reason, Lang); + _ -> + xmpp:serr_not_well_formed() + end, + send_element(State1, Err) + end); handle_info({'$gen_event', {xmlstreamelement, El}}, #{xmlns := NS, lang := MyLang, mod := Mod} = State) -> - try xmpp:decode(El, NS, [ignore_els]) of - Pkt -> - State1 = try Mod:handle_recv(El, Pkt, State) - catch _:undef -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> noreply(process_element(Pkt, State1)) - end - catch _:{xmpp_codec, Why} -> - State1 = try Mod:handle_recv(El, undefined, State) - catch _:undef -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang))) - end - end; + noreply( + try xmpp:decode(El, NS, [ignore_els]) of + Pkt -> + State1 = try Mod:handle_recv(El, Pkt, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> process_element(Pkt, State1) + end + catch _:{xmpp_codec, Why} -> + State1 = try Mod:handle_recv(El, undefined, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + send_error(State1, El, xmpp:err_bad_request(Txt, Lang)) + end + end); handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, #{mod := Mod} = State) -> noreply(try Mod:handle_cdata(Data, State) catch _:undef -> State end); handle_info({'$gen_event', {xmlstreamend, _}}, State) -> - noreply(process_stream_end({error, {stream, reset}}, State)); + noreply(process_stream_end({stream, reset}, State)); handle_info({'$gen_event', closed}, State) -> - noreply(process_stream_close({error, {socket, closed}}, State)); + noreply(process_stream_end({socket, closed}, State)); handle_info(timeout, #{mod := Mod} = State) -> Disconnected = is_disconnected(State), noreply(try Mod:handle_timeout(State) @@ -295,12 +306,13 @@ handle_info(timeout, #{mod := Mod} = State) -> end); handle_info({'DOWN', MRef, _Type, _Object, _Info}, #{socket_monitor := MRef} = State) -> - noreply(process_stream_close({error, {socket, closed}}, State)); + noreply(process_stream_end({socket, closed}, State)); handle_info(Info, #{mod := Mod} = State) -> noreply(try Mod:handle_info(Info, State) catch _:undef -> State end). +-spec terminate(term(), state()) -> any(). terminate(Reason, #{mod := Mod} = State) -> case get(already_terminated) of true -> @@ -319,7 +331,7 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== --spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. +-spec noreply(state()) -> noreply(). noreply(#{stream_timeout := infinity} = State) -> {noreply, State, infinity}; noreply(#{stream_timeout := {MSecs, OldTime}} = State) -> @@ -335,15 +347,6 @@ new_id() -> is_disconnected(#{stream_state := StreamState}) -> StreamState == disconnected. --spec process_stream_close(stop_reason(), state()) -> state(). -process_stream_close(_, #{stream_state := disconnected} = State) -> - State; -process_stream_close(Reason, #{mod := Mod} = State) -> - State1 = send_trailer(State), - try Mod:handle_stream_close(Reason, State1) - catch _:undef -> stop(State1) - end. - -spec process_stream_end(stop_reason(), state()) -> state(). process_stream_end(_, #{stream_state := disconnected} = State) -> State; @@ -359,6 +362,8 @@ process_stream(#stream_start{xmlns = XML_NS, #{xmlns := NS} = State) when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> send_element(State, xmpp:serr_invalid_namespace()); +process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> + send_element(State, xmpp:serr_unsupported_version()); process_stream(#stream_start{lang = Lang, id = ID, version = Version} = StreamStart, #{mod := Mod} = State) -> @@ -370,8 +375,10 @@ process_stream(#stream_start{lang = Lang, id = ID, true -> State2; false -> case Version of - {1,0} -> State2#{stream_state => wait_for_features}; - _ -> process_stream_downgrade(StreamStart, State) + {1, _} -> + State2#{stream_state => wait_for_features}; + _ -> + process_stream_downgrade(StreamStart, State2) end end. @@ -387,7 +394,7 @@ process_element(Pkt, #{stream_state := StateName} = State) -> #sasl_failure{} when StateName == wait_for_sasl_response -> process_sasl_failure(Pkt, State); #stream_error{} -> - process_stream_end({error, {stream, Pkt}}, State); + process_stream_end({stream, Pkt}, State); _ when is_record(Pkt, stream_features); is_record(Pkt, starttls_proceed); is_record(Pkt, starttls); @@ -487,14 +494,23 @@ process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) -> stream_encrypted => true}, send_header(State1); {error, Why} -> - process_stream_close({error, {tls, Why}}, State) + process_stream_end({tls, Why}, State) end. -spec process_stream_downgrade(stream_start(), state()) -> state(). -process_stream_downgrade(StreamStart, #{mod := Mod} = State) -> - try Mod:downgrade_stream(StreamStart, State) - catch _:undef -> - send_element(State, xmpp:serr_unsupported_version()) +process_stream_downgrade(StreamStart, + #{mod := Mod, lang := Lang, + stream_encrypted := Encrypted} = State) -> + TLSRequired = is_starttls_required(State), + if not Encrypted and TLSRequired -> + Txt = <<"Use of STARTTLS required">>, + send_element(State, xmpp:err_policy_violation(Txt, Lang)); + true -> + State1 = State#{stream_state => downgraded}, + try Mod:handle_stream_downgraded(StreamStart, State1) + catch _:undef -> + send_element(State1, xmpp:serr_unsupported_version()) + end end. -spec process_cert_verification(state()) -> state(). @@ -509,7 +525,7 @@ process_cert_verification(#{stream_encrypted := true, {ok, _} -> State#{stream_verified => true}; {error, Why, _Peer} -> - process_stream_close({error, {pkix, Why}}, State) + process_stream_end({pkix, Why}, State) end; false -> State#{stream_verified => true} @@ -538,7 +554,7 @@ process_sasl_success(#{mod := Mod, -spec process_sasl_failure(sasl_failure(), state()) -> state(). process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) -> try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State) - catch _:undef -> process_stream_close({error, {auth, Reason}}, State) + catch _:undef -> process_stream_end({auth, Reason}, State) end. -spec process_packet(xmpp_element(), state()) -> state(). @@ -581,7 +597,7 @@ send_header(#{remote_server := RemoteServer, version = {1,0}}), case send_text(State, fxml:element_to_header(Header)) of ok -> State; - {error, Why} -> process_stream_close({error, {socket, Why}}, State) + {error, Why} -> process_stream_end({socket, Why}, State) end. -spec send_element(state(), xmpp_element()) -> state(). @@ -596,11 +612,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> false -> case send_text(State1, Data) of _ when is_record(Pkt, stream_error) -> - process_stream_end({error, {stream, Pkt}}, State1); + process_stream_end({stream, Pkt}, State1); ok -> State1; {error, Why} -> - process_stream_close({error, {socket, Why}}, State1) + process_stream_end({socket, Why}, State1) end end. @@ -626,7 +642,7 @@ send_text(#{sockmod := SockMod, socket := Socket, stream_state := StateName}, Data) when StateName /= disconnected -> SockMod:send(Socket, Data); send_text(_, _) -> - {error, einval}. + {error, closed}. -spec send_trailer(state()) -> state(). send_trailer(State) -> |