From 1e55e018e534aa82541c5f460063a237192b768c Mon Sep 17 00:00:00 2001 From: Evgeniy Khramtsov Date: Mon, 9 Jan 2017 17:02:17 +0300 Subject: Adopt remaining code to support new hooks --- src/xmpp_stream_in.erl | 333 ++++++++++++++++++++++++++++--------------------- 1 file changed, 194 insertions(+), 139 deletions(-) (limited to 'src/xmpp_stream_in.erl') diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index 1ad78d45..b2b3b307 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -20,9 +20,11 @@ %%% %%%------------------------------------------------------------------- -module(xmpp_stream_in). --behaviour(gen_server). +-define(GEN_SERVER, gen_server). +-behaviour(?GEN_SERVER). -protocol({rfc, 6120}). +-protocol({xep, 114, '1.6'}). %% API -export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1, @@ -43,17 +45,18 @@ -include("xmpp.hrl"). -type state() :: map(). -type stop_reason() :: {stream, reset | {in | out, stream_error()}} | - {tls, term()} | + {tls, inet:posix() | atom() | binary()} | {socket, inet:posix() | closed | timeout} | internal_failure. --callback init(list()) -> {ok, state()} | {stop, term()} | ignore. +-callback init(list()) -> {ok, state()} | {error, term()} | ignore. -callback handle_cast(term(), state()) -> state(). -callback handle_call(term(), term(), state()) -> state(). -callback handle_info(term(), state()) -> state(). -callback terminate(term(), state()) -> any(). -callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. --callback handle_stream_start(state()) -> state(). +-callback handle_stream_start(stream_start(), state()) -> state(). +-callback handle_stream_established(state()) -> state(). -callback handle_stream_end(stop_reason(), state()) -> state(). -callback handle_cdata(binary(), state()) -> state(). -callback handle_unauthenticated_packet(xmpp_element(), state()) -> state(). @@ -63,6 +66,7 @@ -callback handle_auth_failure(binary(), binary(), atom(), state()) -> state(). -callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). -callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). +-callback handle_timeout(state()) -> state(). -callback get_password_fun(state()) -> fun(). -callback check_password_fun(state()) -> fun(). -callback check_password_digest_fun(state()) -> fun(). @@ -71,6 +75,8 @@ -callback tls_options(state()) -> [proplists:property()]. -callback tls_required(state()) -> boolean(). -callback tls_verify(state()) -> boolean(). +-callback tls_enabled(state()) -> boolean(). +-callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()]. -callback unauthenticated_stream_features(state()) -> [xmpp_element()]. -callback authenticated_stream_features(state()) -> [xmpp_element()]. @@ -81,7 +87,8 @@ handle_info/2, terminate/2, code_change/3, - handle_stream_start/1, + handle_stream_start/2, + handle_stream_established/1, handle_stream_end/2, handle_cdata/2, handle_authenticated_packet/2, @@ -91,6 +98,7 @@ handle_auth_failure/4, handle_send/3, handle_recv/3, + handle_timeout/1, get_password_fun/1, check_password_fun/1, check_password_digest_fun/1, @@ -99,6 +107,8 @@ tls_options/1, tls_required/1, tls_verify/1, + tls_enabled/1, + sasl_mechanisms/2, unauthenticated_stream_features/1, authenticated_stream_features/1]). @@ -106,19 +116,19 @@ %%% API %%%=================================================================== start(Mod, Args, Opts) -> - gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). start_link(Mod, Args, Opts) -> - gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). call(Ref, Msg, Timeout) -> - gen_server:call(Ref, Msg, Timeout). + ?GEN_SERVER:call(Ref, Msg, Timeout). cast(Ref, Msg) -> - gen_server:cast(Ref, Msg). + ?GEN_SERVER:cast(Ref, Msg). reply(Ref, Reply) -> - gen_server:reply(Ref, Reply). + ?GEN_SERVER:reply(Ref, Reply). -spec stop(pid()) -> ok; (state()) -> no_return(). @@ -135,7 +145,7 @@ stop(_) -> send(Pid, Pkt) when is_pid(Pid) -> cast(Pid, {send, Pkt}); send(#{owner := Owner} = State, Pkt) when Owner == self() -> - send_element(State, Pkt); + send_pkt(State, Pkt); send(_, _) -> erlang:error(badarg). @@ -193,7 +203,7 @@ format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) -> format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) -> format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]); format_error({tls, Reason}) -> - format("TLS failed: ~w", [Reason]); + format("TLS failed: ~s", [format_tls_error(Reason)]); format_error(internal_failure) -> <<"Internal server error">>; format_error(Err) -> @@ -203,13 +213,9 @@ format_error(Err) -> %%% gen_server callbacks %%%=================================================================== init([Module, {SockMod, Socket}, Opts]) -> - XMLSocket = case lists:keyfind(xml_socket, 1, Opts) of - {_, XS} -> XS; - false -> false - end, Encrypted = proplists:get_bool(tls, Opts), SocketMonitor = SockMod:monitor(Socket), - case peername(SockMod, Socket) of + case SockMod:peername(Socket) of {ok, IP} -> Time = p1_time_compat:monotonic_time(milli_seconds), State = #{owner => self(), @@ -227,7 +233,6 @@ init([Module, {SockMod, Socket}, Opts]) -> stream_encrypted => Encrypted, stream_version => {1,0}, stream_authenticated => false, - xml_socket => XMLSocket, xmlns => ?NS_CLIENT, lang => <<"">>, user => <<"">>, @@ -238,18 +243,32 @@ init([Module, {SockMod, Socket}, Opts]) -> case try Module:init([State, Opts]) catch _:undef -> {ok, State} end of - {ok, State1} -> + {ok, State1} when not Encrypted -> {_, State2, Timeout} = noreply(State1), {ok, State2, Timeout}; - Err -> - Err + {ok, State1} when Encrypted -> + TLSOpts = try Module:tls_options(State1) + catch _:undef -> [] + end, + case SockMod:starttls(Socket, TLSOpts) of + {ok, TLSSocket} -> + State2 = State1#{socket => TLSSocket}, + {_, State3, Timeout} = noreply(State2), + {ok, State3, Timeout}; + {error, Reason} -> + {stop, Reason} + end; + {error, Reason} -> + {stop, Reason}; + ignore -> + ignore end; - {error, Reason} -> - {stop, Reason} + {error, _Reason} -> + ignore end. handle_cast({send, Pkt}, State) -> - noreply(send_element(State, Pkt)); + noreply(send_pkt(State, Pkt)); handle_cast(stop, State) -> {stop, normal, State}; handle_cast(Cast, #{mod := Mod} = State) -> @@ -278,7 +297,7 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, State1 = send_header(State), case is_disconnected(State1) of true -> State1; - false -> send_element(State1, xmpp:serr_invalid_xml()) + false -> send_pkt(State1, xmpp:serr_invalid_xml()) end catch _:{xmpp_codec, Why} -> State1 = send_header(State), @@ -288,7 +307,7 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, Txt = xmpp:io_format_error(Why), Lang = select_lang(MyLang, xmpp:get_lang(El)), Err = xmpp:serr_invalid_xml(Txt, Lang), - send_element(State1, Err) + send_pkt(State1, Err) end end); handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> @@ -303,7 +322,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> _ -> xmpp:serr_not_well_formed() end, - send_element(State1, Err) + send_pkt(State1, Err) end); handle_info({'$gen_event', {xmlstreamelement, El}}, #{xmlns := NS, mod := Mod} = State) -> @@ -339,7 +358,7 @@ handle_info(timeout, #{mod := Mod} = State) -> Disconnected = is_disconnected(State), noreply(try Mod:handle_timeout(State) catch _:undef when not Disconnected -> - send_element(State, xmpp:serr_connection_timeout()); + send_pkt(State, xmpp:serr_connection_timeout()); _:undef -> stop(State) end); @@ -385,14 +404,6 @@ new_id() -> is_disconnected(#{stream_state := StreamState}) -> StreamState == disconnected. --spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}| - {error, inet:posix()}. -peername(SockMod, Socket) -> - case SockMod of - gen_tcp -> inet:peername(Socket); - _ -> SockMod:peername(Socket) - end. - -spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state(). process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> case xmpp:is_stanza(El) of @@ -408,12 +419,12 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> Txt = xmpp:io_format_error(Reason), Err = #sasl_failure{reason = 'malformed-request', text = xmpp:mk_text(Txt, MyLang)}, - send_element(State, Err); + send_pkt(State, Err); {<<"starttls">>, ?NS_TLS} -> - send_element(State, #starttls_failure{}); + send_pkt(State, #starttls_failure{}); {<<"compress">>, ?NS_COMPRESS} -> Err = #compress_failure{reason = 'setup-failed'}, - send_element(State, Err); + send_pkt(State, Err); _ -> %% Maybe add something more? State @@ -434,9 +445,9 @@ process_stream(#stream_start{xmlns = XML_NS, stream_xmlns = STREAM_NS}, #{xmlns := NS} = State) when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> - send_element(State, xmpp:serr_invalid_namespace()); + send_pkt(State, xmpp:serr_invalid_namespace()); process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> - send_element(State, xmpp:serr_unsupported_version()); + send_pkt(State, xmpp:serr_unsupported_version()); process_stream(#stream_start{lang = Lang}, #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State) when size(Lang) > 35 -> @@ -445,14 +456,14 @@ process_stream(#stream_start{lang = Lang}, %% language tags MUST allow for language tags of at least 35 characters. %% Do not store long language tag to avoid possible DoS/flood attacks Txt = <<"Too long value of 'xml:lang' attribute">>, - send_element(State, xmpp:serr_policy_violation(Txt, DefaultLang)); + send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang)); process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) -> Txt = <<"Missing 'to' attribute">>, - send_element(State, xmpp:serr_improper_addressing(Txt, Lang)); + send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)); process_stream(#stream_start{to = #jid{luser = U, lresource = R}}, #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> -> Txt = <<"Improper 'to' attribute">>, - send_element(State, xmpp:serr_improper_addressing(Txt, Lang)); + send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)); process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart, #{xmlns := ?NS_COMPONENT, mod := Mod} = State) -> State1 = State#{remote_server => RemoteServer, @@ -509,29 +520,29 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> #starttls{} -> process_starttls_failure(unexpected_starttls_request, State); #sasl_auth{} when StateName == wait_for_starttls -> - send_element(State, #sasl_failure{reason = 'encryption-required'}); + send_pkt(State, #sasl_failure{reason = 'encryption-required'}); #sasl_auth{} when StateName == wait_for_sasl_request -> process_sasl_request(Pkt, State); #sasl_auth{} -> Txt = <<"SASL negotiation is not allowed in this state">>, - send_element(State, #sasl_failure{reason = 'not-authorized', + send_pkt(State, #sasl_failure{reason = 'not-authorized', text = xmpp:mk_text(Txt, Lang)}); #sasl_response{} when StateName == wait_for_starttls -> - send_element(State, #sasl_failure{reason = 'encryption-required'}); + send_pkt(State, #sasl_failure{reason = 'encryption-required'}); #sasl_response{} when StateName == wait_for_sasl_response -> process_sasl_response(Pkt, State); #sasl_response{} -> Txt = <<"SASL negotiation is not allowed in this state">>, - send_element(State, #sasl_failure{reason = 'not-authorized', + send_pkt(State, #sasl_failure{reason = 'not-authorized', text = xmpp:mk_text(Txt, Lang)}); #sasl_abort{} when StateName == wait_for_sasl_response -> process_sasl_abort(State); #sasl_abort{} -> - send_element(State, #sasl_failure{reason = 'aborted'}); + send_pkt(State, #sasl_failure{reason = 'aborted'}); #sasl_success{} -> State; #compress{} when StateName == wait_for_sasl_response -> - send_element(State, #compress_failure{reason = 'setup-failed'}); + send_pkt(State, #compress_failure{reason = 'setup-failed'}); #compress{} -> process_compress(Pkt, State); #handshake{} when StateName == wait_for_handshake -> @@ -570,7 +581,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> {ok, #iq{type = set, sub_els = [_]} = Pkt2} when NS == ?NS_CLIENT -> case xmpp:get_subtag(Pkt2, #xmpp_session{}) of #xmpp_session{} -> - send_element(State, xmpp:make_iq_result(Pkt2)); + send_pkt(State, xmpp:make_iq_result(Pkt2)); _ -> try Mod:handle_authenticated_packet(Pkt2, State) catch _:undef -> @@ -585,7 +596,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> send_error(State, Pkt, Err) end; {error, Err} -> - send_element(State, Err) + send_pkt(State, Err) end. -spec process_bind(xmpp_element(), state()) -> state(). @@ -604,7 +615,7 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt, server := S, resource := NewR} = State1} when NewR /= <<"">> -> Reply = #bind{jid = jid:make(U, S, NewR)}, - State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)), + State2 = send_pkt(State1, xmpp:make_iq_result(Pkt, Reply)), process_stream_established(State2); {error, #stanza_error{}, State1} = Err -> send_error(State1, Pkt, Err) @@ -646,7 +657,7 @@ process_handshake(#handshake{data = Digest}, case is_disconnected(State1) of true -> State1; false -> - State2 = send_element(State1, #handshake{}), + State2 = send_pkt(State1, #handshake{}), process_stream_established(State2) end; false -> @@ -656,7 +667,7 @@ process_handshake(#handshake{data = Digest}, end, case is_disconnected(State1) of true -> State1; - false -> send_element(State1, xmpp:serr_not_authorized()) + false -> send_pkt(State1, xmpp:serr_not_authorized()) end end. @@ -674,7 +685,7 @@ process_stream_established(#{mod := Mod} = State) -> -spec process_compress(compress(), state()) -> state(). process_compress(#compress{}, #{stream_compressed := true} = State) -> - send_element(State, #compress_failure{reason = 'setup-failed'}); + send_pkt(State, #compress_failure{reason = 'setup-failed'}); process_compress(#compress{methods = HisMethods}, #{socket := Socket, sockmod := SockMod, mod := Mod} = State) -> MyMethods = try Mod:compress_methods(State) @@ -683,44 +694,60 @@ process_compress(#compress{methods = HisMethods}, CommonMethods = lists_intersection(MyMethods, HisMethods), case lists:member(<<"zlib">>, CommonMethods) of true -> - BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})), - ZlibSocket = SockMod:compress(Socket, BCompressed), - State#{socket => ZlibSocket, - stream_id => new_id(), - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - stream_compressed => true}; + State1 = send_pkt(State, #compressed{}), + case is_disconnected(State1) of + true -> State1; + false -> + case SockMod:compress(Socket) of + {ok, ZlibSocket} -> + State1#{socket => ZlibSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_compressed => true}; + {error, _} -> + Err = #compress_failure{reason = 'setup-failed'}, + send_pkt(State1, Err) + end + end; false -> - send_element(State, #compress_failure{reason = 'unsupported-method'}) + send_pkt(State, #compress_failure{reason = 'unsupported-method'}) end. -spec process_starttls(state()) -> state(). +process_starttls(#{stream_encrypted := true} = State) -> + process_starttls_failure(already_encrypted, State); process_starttls(#{socket := Socket, sockmod := SockMod, mod := Mod} = State) -> - TLSOpts = try Mod:tls_options(State) - catch _:undef -> [] - end, - case SockMod:starttls(Socket, TLSOpts) of - {ok, TLSSocket} -> - State1 = send_element(State, #starttls_proceed{}), - case is_disconnected(State1) of - true -> State1; - false -> - State1#{socket => TLSSocket, - stream_id => new_id(), - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - stream_encrypted => true} + case is_starttls_available(State) of + true -> + TLSOpts = try Mod:tls_options(State) + catch _:undef -> [] + end, + case SockMod:starttls(Socket, TLSOpts) of + {ok, TLSSocket} -> + State1 = send_pkt(State, #starttls_proceed{}), + case is_disconnected(State1) of + true -> State1; + false -> + State1#{socket => TLSSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_encrypted => true} + end; + {error, Reason} -> + process_starttls_failure(Reason, State) end; - {error, Reason} -> - process_starttls_failure(Reason, State) + false -> + process_starttls_failure(starttls_unsupported, State) end. -spec process_starttls_failure(term(), state()) -> state(). process_starttls_failure(Why, State) -> - State1 = send_element(State, #starttls_failure{}), + State1 = send_pkt(State, #starttls_failure{}), case is_disconnected(State1) of true -> State1; false -> process_stream_end({tls, Why}, State1) @@ -780,17 +807,17 @@ process_sasl_success(Props, ServerOut, mod := Mod, sasl_mech := Mech} = State) -> User = identity(Props), AuthModule = proplists:get_value(auth_module, Props), - State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State) - catch _:undef -> State - end, + State1 = send_pkt(State, #sasl_success{text = ServerOut}), case is_disconnected(State1) of true -> State1; false -> - SockMod:reset_stream(Socket), - State2 = send_element(State1, #sasl_success{text = ServerOut}), + State2 = try Mod:handle_auth_success(User, Mech, AuthModule, State1) + catch _:undef -> State1 + end, case is_disconnected(State2) of true -> State2; false -> + SockMod:reset_stream(Socket), State3 = maps:remove(sasl_state, maps:remove(sasl_mech, State2)), State3#{stream_id => new_id(), @@ -806,19 +833,23 @@ process_sasl_success(Props, ServerOut, process_sasl_continue(ServerOut, NewSASLState, State) -> State1 = State#{sasl_state => NewSASLState, stream_state => wait_for_sasl_response}, - send_element(State1, #sasl_challenge{text = ServerOut}). + send_pkt(State1, #sasl_challenge{text = ServerOut}). -spec process_sasl_failure(atom(), binary(), state()) -> state(). process_sasl_failure(Err, User, #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) -> {Reason, Text} = format_sasl_error(Mech, Err), - State1 = try Mod:handle_auth_failure(User, Mech, Text, State) - catch _:undef -> State - end, - State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)), - State3 = State2#{stream_state => wait_for_sasl_request}, - send_element(State3, #sasl_failure{reason = Reason, - text = xmpp:mk_text(Text, Lang)}). + State1 = send_pkt(State, #sasl_failure{reason = Reason, + text = xmpp:mk_text(Text, Lang)}), + case is_disconnected(State1) of + true -> State1; + false -> + State2 = try Mod:handle_auth_failure(User, Mech, Text, State1) + catch _:undef -> State1 + end, + State3 = maps:remove(sasl_state, maps:remove(sasl_mech, State2)), + State3#{stream_state => wait_for_sasl_request} + end. -spec process_sasl_abort(state()) -> state(). process_sasl_abort(State) -> @@ -835,7 +866,7 @@ send_features(#{stream_version := {1,0}, ++ get_tls_feature(State) ++ get_bind_feature(State) ++ get_session_feature(State) ++ get_other_features(State) end, - send_element(State, #stream_features{sub_els = Features}); + send_pkt(State, #stream_features{sub_els = Features}); send_features(State) -> %% clients and servers from stone age State. @@ -849,10 +880,13 @@ get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod, TLSVerify = try Mod:tls_verify(State) catch _:undef -> false end, - if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> - [<<"EXTERNAL">>|Mechs]; - true -> - Mechs + Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> + [<<"EXTERNAL">>|Mechs]; + true -> + Mechs + end, + try Mod:sasl_mechanisms(Mechs1, State) + catch _:undef -> Mechs1 end. -spec get_sasl_feature(state()) -> [sasl_mechanisms()]. @@ -882,8 +916,13 @@ get_compress_feature(_) -> -spec get_tls_feature(state()) -> [starttls()]. get_tls_feature(#{stream_authenticated := false, stream_encrypted := false} = State) -> - TLSRequired = is_starttls_required(State), - [#starttls{required = TLSRequired}]; + case is_starttls_available(State) of + true -> + TLSRequired = is_starttls_required(State), + [#starttls{required = TLSRequired}]; + false -> + [] + end; get_tls_feature(_) -> []. @@ -913,6 +952,12 @@ get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) -> [] end. +-spec is_starttls_available(state()) -> boolean(). +is_starttls_available(#{mod := Mod} = State) -> + try Mod:tls_enabled(State) + catch _:undef -> true + end. + -spec is_starttls_required(state()) -> boolean(). is_starttls_required(#{mod := Mod} = State) -> try Mod:tls_required(State) @@ -967,13 +1012,14 @@ send_header(#{stream_id := StreamID, lang := MyLang, xmlns := NS, server := DefaultServer} = State, - #stream_start{to = To, lang = HisLang, version = HisVersion}) -> + #stream_start{to = HisTo, from = HisFrom, + lang = HisLang, version = HisVersion}) -> Lang = select_lang(MyLang, HisLang), NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; true -> <<"">> end, - From = case To of - #jid{} -> To; + From = case HisTo of + #jid{} -> HisTo; undefined -> jid:make(DefaultServer) end, Version = case HisVersion of @@ -981,45 +1027,40 @@ send_header(#{stream_id := StreamID, {0,_} -> HisVersion; _ -> MyVersion end, - Header = xmpp:encode(#stream_start{version = Version, - lang = Lang, - xmlns = NS, - stream_xmlns = ?NS_STREAM, - db_xmlns = NS_DB, - id = StreamID, - from = From}), + StreamStart = #stream_start{version = Version, + lang = Lang, + xmlns = NS, + stream_xmlns = ?NS_STREAM, + db_xmlns = NS_DB, + id = StreamID, + to = HisFrom, + from = From}, State1 = State#{lang => Lang, stream_version => Version, stream_header_sent => true}, - case send_text(State1, fxml:element_to_header(Header)) of + case socket_send(State1, StreamStart) of ok -> State1; {error, Why} -> process_stream_end({socket, Why}, State1) end; send_header(State, _) -> State. --spec send_element(state(), xmpp_element()) -> state(). -send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> - El = xmpp:encode(Pkt, NS), - Data = fxml:element_to_binary(El), - Result = send_text(State, Data), +-spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). +send_pkt(#{mod := Mod} = State, Pkt) -> + Result = socket_send(State, Pkt), State1 = try Mod:handle_send(Pkt, Result, State) catch _:undef -> State end, - case is_disconnected(State1) of - true -> State1; - false -> - case Result of - _ when is_record(Pkt, stream_error) -> - process_stream_end({stream, {out, Pkt}}, State1); - ok -> - State1; - {error, Why} -> - process_stream_end({socket, Why}, State1) - end + case Result of + _ when is_record(Pkt, stream_error) -> + process_stream_end({stream, {out, Pkt}}, State1); + ok -> + State1; + {error, Why} -> + process_stream_end({socket, Why}, State1) end. --spec send_error(state(), xmpp_element(), stanza_error()) -> state(). +-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state(). send_error(State, Pkt, Err) -> case xmpp:is_stanza(Pkt) of true -> @@ -1030,7 +1071,7 @@ send_error(State, Pkt, Err) -> <<"error">> -> State; _ -> ErrPkt = xmpp:make_error(Pkt, Err), - send_element(State, ErrPkt) + send_pkt(State, ErrPkt) end; false -> State @@ -1038,15 +1079,23 @@ send_error(State, Pkt, Err) -> -spec send_trailer(state()) -> state(). send_trailer(State) -> - send_text(State, <<"">>), + socket_send(State, trailer), close_socket(State). --spec send_text(state(), binary()) -> ok | {error, inet:posix()}. -send_text(#{socket := Sock, sockmod := SockMod, - stream_state := StateName, - stream_header_sent := true}, Data) when StateName /= disconnected -> - SockMod:send(Sock, Data); -send_text(_, _) -> +-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}. +socket_send(#{socket := Sock, sockmod := SockMod, + stream_state := StateName, + xmlns := NS, + stream_header_sent := true}, Pkt) when StateName /= disconnected -> + case Pkt of + trailer -> + SockMod:send_trailer(Sock); + #stream_start{} -> + SockMod:send_header(Sock, xmpp:encode(Pkt)); + _ -> + SockMod:send_element(Sock, xmpp:encode(Pkt, NS)) + end; +socket_send(_, _) -> {error, closed}. -spec close_socket(state()) -> state(). @@ -1096,6 +1145,12 @@ format_sasl_error(<<"EXTERNAL">>, Err) -> format_sasl_error(Mech, Err) -> cyrsasl:format_error(Mech, Err). +-spec format_tls_error(atom() | binary()) -> list(). +format_tls_error(Reason) when is_atom(Reason) -> + format_inet_error(Reason); +format_tls_error(Reason) -> + Reason. + -spec format(io:format(), list()) -> binary(). format(Fmt, Args) -> iolist_to_binary(io_lib:format(Fmt, Args)). -- cgit v1.2.3