diff options
Diffstat (limited to 'src/xmpp_stream_in.erl')
-rw-r--r-- | src/xmpp_stream_in.erl | 1220 |
1 files changed, 0 insertions, 1220 deletions
diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl deleted file mode 100644 index 31018d434..000000000 --- a/src/xmpp_stream_in.erl +++ /dev/null @@ -1,1220 +0,0 @@ -%%%------------------------------------------------------------------- -%%% Created : 26 Nov 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net> -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2018 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%------------------------------------------------------------------- --module(xmpp_stream_in). --define(GEN_SERVER, p1_server). --behaviour(?GEN_SERVER). - --protocol({rfc, 6120}). --protocol({xep, 114, '1.6'}). - -%% API --export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1, - send/2, close/1, close/2, send_error/3, establish/1, - get_transport/1, change_shaper/2, set_timeout/2, format_error/1]). - -%% gen_server callbacks --export([init/1, handle_cast/2, handle_call/3, handle_info/2, - terminate/2, code_change/3]). - -%%-define(DBGFSM, true). --ifdef(DBGFSM). --define(FSMOPTS, [{debug, [trace]}]). --else. --define(FSMOPTS, []). --endif. - --include("xmpp.hrl"). --type state() :: map(). --type stop_reason() :: {stream, reset | {in | out, stream_error()}} | - {tls, inet:posix() | atom() | binary()} | - {socket, inet:posix() | atom()} | - internal_failure. --export_type([state/0, stop_reason/0]). --callback init(list()) -> {ok, state()} | {error, term()} | ignore. --callback handle_cast(term(), state()) -> state(). --callback handle_call(term(), term(), state()) -> state(). --callback handle_info(term(), state()) -> state(). --callback terminate(term(), state()) -> any(). --callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. --callback handle_stream_start(stream_start(), state()) -> state(). --callback handle_stream_established(state()) -> state(). --callback handle_stream_end(stop_reason(), state()) -> state(). --callback handle_cdata(binary(), state()) -> state(). --callback handle_unauthenticated_packet(xmpp_element(), state()) -> state(). --callback handle_authenticated_packet(xmpp_element(), state()) -> state(). --callback handle_unbinded_packet(xmpp_element(), state()) -> state(). --callback handle_auth_success(binary(), binary(), module(), state()) -> state(). --callback handle_auth_failure(binary(), binary(), binary(), state()) -> state(). --callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). --callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). --callback handle_timeout(state()) -> state(). --callback get_password_fun(state()) -> fun(). --callback check_password_fun(state()) -> fun(). --callback check_password_digest_fun(state()) -> fun(). --callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}. --callback compress_methods(state()) -> [binary()]. --callback tls_options(state()) -> [proplists:property()]. --callback tls_required(state()) -> boolean(). --callback tls_verify(state()) -> boolean(). --callback tls_enabled(state()) -> boolean(). --callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()]. --callback unauthenticated_stream_features(state()) -> [xmpp_element()]. --callback authenticated_stream_features(state()) -> [xmpp_element()]. - -%% All callbacks are optional --optional_callbacks([init/1, - handle_cast/2, - handle_call/3, - handle_info/2, - terminate/2, - code_change/3, - handle_stream_start/2, - handle_stream_established/1, - handle_stream_end/2, - handle_cdata/2, - handle_authenticated_packet/2, - handle_unauthenticated_packet/2, - handle_unbinded_packet/2, - handle_auth_success/4, - handle_auth_failure/4, - handle_send/3, - handle_recv/3, - handle_timeout/1, - get_password_fun/1, - check_password_fun/1, - check_password_digest_fun/1, - bind/2, - compress_methods/1, - tls_options/1, - tls_required/1, - tls_verify/1, - tls_enabled/1, - sasl_mechanisms/2, - unauthenticated_stream_features/1, - authenticated_stream_features/1]). - -%%%=================================================================== -%%% API -%%%=================================================================== -start(Mod, Args, Opts) -> - ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). - -start_link(Mod, Args, Opts) -> - ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). - -call(Ref, Msg, Timeout) -> - ?GEN_SERVER:call(Ref, Msg, Timeout). - -cast(Ref, Msg) -> - ?GEN_SERVER:cast(Ref, Msg). - -reply(Ref, Reply) -> - ?GEN_SERVER:reply(Ref, Reply). - --spec stop(pid()) -> ok; - (state()) -> no_return(). -stop(Pid) when is_pid(Pid) -> - cast(Pid, stop); -stop(#{owner := Owner} = State) when Owner == self() -> - terminate(normal, State), - exit(normal); -stop(_) -> - erlang:error(badarg). - --spec send(pid(), xmpp_element()) -> ok; - (state(), xmpp_element()) -> state(). -send(Pid, Pkt) when is_pid(Pid) -> - cast(Pid, {send, Pkt}); -send(#{owner := Owner} = State, Pkt) when Owner == self() -> - send_pkt(State, Pkt); -send(_, _) -> - erlang:error(badarg). - --spec close(pid()) -> ok; - (state()) -> state(). -close(Pid) when is_pid(Pid) -> - close(Pid, closed); -close(#{owner := Owner} = State) when Owner == self() -> - close_socket(State); -close(_) -> - erlang:error(badarg). - --spec close(pid(), atom()) -> ok. -close(Pid, Reason) -> - cast(Pid, {close, Reason}). - --spec establish(state()) -> state(). -establish(State) -> - process_stream_established(State). - --spec set_timeout(state(), non_neg_integer() | infinity) -> state(). -set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> - case Timeout of - infinity -> State#{stream_timeout => infinity}; - _ -> - Time = p1_time_compat:monotonic_time(milli_seconds), - State#{stream_timeout => {Timeout, Time}} - end; -set_timeout(_, _) -> - erlang:error(badarg). - -get_transport(#{socket := Socket, owner := Owner}) - when Owner == self() -> - xmpp_socket:get_transport(Socket); -get_transport(_) -> - erlang:error(badarg). - --spec change_shaper(state(), ejabberd_shaper:shaper()) -> state(). -change_shaper(#{socket := Socket, owner := Owner} = State, Shaper) - when Owner == self() -> - Socket1 = xmpp_socket:change_shaper(Socket, Shaper), - State#{socket => Socket1}; -change_shaper(_, _) -> - erlang:error(badarg). - --spec format_error(stop_reason()) -> binary(). -format_error({socket, Reason}) -> - format("Connection failed: ~s", [format_inet_error(Reason)]); -format_error({stream, reset}) -> - <<"Stream reset by peer">>; -format_error({stream, {in, #stream_error{} = Err}}) -> - format("Stream closed by peer: ~s", [xmpp:format_stream_error(Err)]); -format_error({stream, {out, #stream_error{} = Err}}) -> - format("Stream closed by us: ~s", [xmpp:format_stream_error(Err)]); -format_error({tls, Reason}) -> - format("TLS failed: ~s", [format_tls_error(Reason)]); -format_error(internal_failure) -> - <<"Internal server error">>; -format_error(Err) -> - format("Unrecognized error: ~w", [Err]). - -%%%=================================================================== -%%% gen_server callbacks -%%%=================================================================== -init([Mod, {_SockMod, Socket}, Opts]) -> - Encrypted = proplists:get_bool(tls, Opts), - SocketMonitor = xmpp_socket:monitor(Socket), - case xmpp_socket:peername(Socket) of - {ok, IP} -> - Time = p1_time_compat:monotonic_time(milli_seconds), - State = #{owner => self(), - mod => Mod, - socket => Socket, - socket_monitor => SocketMonitor, - stream_timeout => {timer:seconds(30), Time}, - stream_direction => in, - stream_id => new_id(), - stream_state => wait_for_stream, - stream_header_sent => false, - stream_restarted => false, - stream_compressed => false, - stream_encrypted => Encrypted, - stream_version => {1,0}, - stream_authenticated => false, - codec_options => [ignore_els], - xmlns => ?NS_CLIENT, - lang => <<"">>, - user => <<"">>, - server => <<"">>, - resource => <<"">>, - lserver => <<"">>, - ip => IP}, - case try Mod:init([State, Opts]) - catch _:undef -> {ok, State} - end of - {ok, State1} when not Encrypted -> - {_, State2, Timeout} = noreply(State1), - {ok, State2, Timeout}; - {ok, State1} when Encrypted -> - TLSOpts = try callback(tls_options, State1) - catch _:{?MODULE, undef} -> [] - end, - case xmpp_socket:starttls(Socket, TLSOpts) of - {ok, TLSSocket} -> - State2 = State1#{socket => TLSSocket}, - {_, State3, Timeout} = noreply(State2), - {ok, State3, Timeout}; - {error, Reason} -> - {stop, Reason} - end; - {error, Reason} -> - {stop, Reason}; - ignore -> - ignore - end; - {error, _Reason} -> - ignore - end. - -handle_cast({send, Pkt}, State) -> - noreply(send_pkt(State, Pkt)); -handle_cast(stop, State) -> - {stop, normal, State}; -handle_cast({close, Reason}, State) -> - State1 = close_socket(State), - noreply( - case is_disconnected(State) of - true -> State1; - false -> process_stream_end({socket, Reason}, State) - end); -handle_cast(Cast, State) -> - noreply(try callback(handle_cast, Cast, State) - catch _:{?MODULE, undef} -> State - end). - -handle_call(Call, From, State) -> - noreply(try callback(handle_call, Call, From, State) - catch _:{?MODULE, undef} -> State - end). - -handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, - #{stream_state := wait_for_stream, - xmlns := XMLNS, lang := MyLang} = State) -> - El = #xmlel{name = Name, attrs = Attrs}, - noreply( - try xmpp:decode(El, XMLNS, []) of - #stream_start{} = Pkt -> - State1 = send_header(State, Pkt), - case is_disconnected(State1) of - true -> State1; - false -> process_stream(Pkt, State1) - end; - _ -> - State1 = send_header(State), - case is_disconnected(State1) of - true -> State1; - false -> send_pkt(State1, xmpp:serr_invalid_xml()) - end - catch _:{xmpp_codec, Why} -> - State1 = send_header(State), - case is_disconnected(State1) of - true -> State1; - false -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - Err = xmpp:serr_invalid_xml(Txt, Lang), - send_pkt(State1, Err) - end - end); -handle_info({'$gen_event', {xmlstreamend, _}}, State) -> - noreply(process_stream_end({stream, reset}, State)); -handle_info({'$gen_event', closed}, State) -> - noreply(process_stream_end({socket, closed}, State)); -handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> - State1 = send_header(State), - noreply( - case is_disconnected(State1) of - true -> State1; - false -> - Err = case Reason of - <<"XML stanza is too big">> -> - xmpp:serr_policy_violation(Reason, Lang); - {_, Txt} -> - xmpp:serr_not_well_formed(Txt, Lang) - end, - send_pkt(State1, Err) - end); -handle_info({'$gen_event', El}, #{stream_state := wait_for_stream} = State) -> - error_logger:warning_msg("unexpected event from XML driver: ~p; " - "xmlstreamstart was expected", [El]), - State1 = send_header(State), - noreply( - case is_disconnected(State1) of - true -> State1; - false -> send_pkt(State1, xmpp:serr_invalid_xml()) - end); -handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS, codec_options := Opts} = State) -> - noreply( - try xmpp:decode(El, NS, Opts) of - Pkt -> - State1 = try callback(handle_recv, El, Pkt, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> process_element(Pkt, State1) - end - catch _:{xmpp_codec, Why} -> - State1 = try callback(handle_recv, El, {error, Why}, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> process_invalid_xml(State1, El, Why) - end - end); -handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, - State) -> - noreply(try callback(handle_cdata, Data, State) - catch _:{?MODULE, undef} -> State - end); -handle_info(timeout, #{lang := Lang} = State) -> - Disconnected = is_disconnected(State), - noreply(try callback(handle_timeout, State) - catch _:{?MODULE, undef} when not Disconnected -> - Txt = <<"Idle connection">>, - send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang)); - _:{?MODULE, undef} -> - stop(State) - end); -handle_info({'DOWN', MRef, _Type, _Object, _Info}, - #{socket_monitor := MRef} = State) -> - noreply(process_stream_end({socket, closed}, State)); -handle_info({tcp, _, Data}, #{socket := Socket} = State) -> - noreply( - case xmpp_socket:recv(Socket, Data) of - {ok, NewSocket} -> - State#{socket => NewSocket}; - {error, Reason} when is_atom(Reason) -> - process_stream_end({socket, Reason}, State); - {error, Reason} -> - %% TODO: make fast_tls return atoms - process_stream_end({tls, Reason}, State) - end); -handle_info({tcp_closed, _}, State) -> - handle_info({'$gen_event', closed}, State); -handle_info({tcp_error, _, Reason}, State) -> - noreply(process_stream_end({socket, Reason}, State)); -handle_info(Info, State) -> - noreply(try callback(handle_info, Info, State) - catch _:{?MODULE, undef} -> State - end). - -terminate(Reason, State) -> - case get(already_terminated) of - true -> - State; - _ -> - put(already_terminated, true), - try callback(terminate, Reason, State) - catch _:{?MODULE, undef} -> ok - end, - send_trailer(State) - end. - -code_change(OldVsn, State, Extra) -> - callback(code_change, OldVsn, State, Extra). - -%%%=================================================================== -%%% Internal functions -%%%=================================================================== --spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. -noreply(#{stream_timeout := infinity} = State) -> - {noreply, State, infinity}; -noreply(#{stream_timeout := {MSecs, StartTime}} = State) -> - CurrentTime = p1_time_compat:monotonic_time(milli_seconds), - Timeout = max(0, MSecs - CurrentTime + StartTime), - {noreply, State, Timeout}. - --spec new_id() -> binary(). -new_id() -> - p1_rand:get_string(). - --spec is_disconnected(state()) -> boolean(). -is_disconnected(#{stream_state := StreamState}) -> - StreamState == disconnected. - --spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state(). -process_invalid_xml(#{lang := MyLang} = State, El, Reason) -> - case xmpp:is_stanza(El) of - true -> - Txt = xmpp:io_format_error(Reason), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - send_error(State, El, xmpp:err_bad_request(Txt, Lang)); - false -> - case {xmpp:get_name(El), xmpp:get_ns(El)} of - {Tag, ?NS_SASL} when Tag == <<"auth">>; - Tag == <<"response">>; - Tag == <<"abort">> -> - Txt = xmpp:io_format_error(Reason), - Err = #sasl_failure{reason = 'malformed-request', - text = xmpp:mk_text(Txt, MyLang)}, - send_pkt(State, Err); - {<<"starttls">>, ?NS_TLS} -> - send_pkt(State, #starttls_failure{}); - {<<"compress">>, ?NS_COMPRESS} -> - Err = #compress_failure{reason = 'setup-failed'}, - send_pkt(State, Err); - _ -> - %% Maybe add something more? - State - end - end. - --spec process_stream_end(stop_reason(), state()) -> state(). -process_stream_end(_, #{stream_state := disconnected} = State) -> - State; -process_stream_end(Reason, State) -> - State1 = State#{stream_timeout => infinity, - stream_state => disconnected}, - try callback(handle_stream_end, Reason, State1) - catch _:{?MODULE, undef} -> stop(State1) - end. - --spec process_stream(stream_start(), state()) -> state(). -process_stream(#stream_start{xmlns = XML_NS, - stream_xmlns = STREAM_NS}, - #{xmlns := NS} = State) - when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> - send_pkt(State, xmpp:serr_invalid_namespace()); -process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> - send_pkt(State, xmpp:serr_unsupported_version()); -process_stream(#stream_start{lang = Lang}, - #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State) - when size(Lang) > 35 -> - %% As stated in BCP47, 4.4.1: - %% Protocols or specifications that specify limited buffer sizes for - %% language tags MUST allow for language tags of at least 35 characters. - %% Do not store long language tag to avoid possible DoS/flood attacks - Txt = <<"Too long value of 'xml:lang' attribute">>, - send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang)); -process_stream(#stream_start{to = undefined, version = Version} = StreamStart, - #{lang := Lang, server := Server, xmlns := NS} = State) -> - if Version < {1,0} andalso NS /= ?NS_COMPONENT -> - %% Work-around for gmail servers - To = jid:make(Server), - process_stream(StreamStart#stream_start{to = To}, State); - true -> - Txt = <<"Missing 'to' attribute">>, - send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)) - end; -process_stream(#stream_start{to = #jid{luser = U, lresource = R}}, - #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> -> - Txt = <<"Improper 'to' attribute">>, - send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang)); -process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart, - #{xmlns := ?NS_COMPONENT} = State) -> - State1 = State#{remote_server => RemoteServer, - stream_state => wait_for_handshake}, - try callback(handle_stream_start, StreamStart, State1) - catch _:{?MODULE, undef} -> State1 - end; -process_stream(#stream_start{to = #jid{server = Server, lserver = LServer}, - from = From} = StreamStart, - #{stream_authenticated := Authenticated, - stream_restarted := StreamWasRestarted, - xmlns := NS, resource := Resource, - stream_encrypted := Encrypted} = State) -> - State1 = if not StreamWasRestarted -> - State#{server => Server, lserver => LServer}; - true -> - State - end, - State2 = case From of - #jid{lserver = RemoteServer} when NS == ?NS_SERVER -> - State1#{remote_server => RemoteServer}; - _ -> - State1 - end, - State3 = try callback(handle_stream_start, StreamStart, State2) - catch _:{?MODULE, undef} -> State2 - end, - case is_disconnected(State3) of - true -> State3; - false -> - State4 = send_features(State3), - case is_disconnected(State4) of - true -> State4; - false -> - TLSRequired = is_starttls_required(State4), - if not Authenticated and (TLSRequired and not Encrypted) -> - State4#{stream_state => wait_for_starttls}; - not Authenticated -> - State4#{stream_state => wait_for_sasl_request}; - (NS == ?NS_CLIENT) and (Resource == <<"">>) -> - State4#{stream_state => wait_for_bind}; - true -> - process_stream_established(State4) - end - end - end. - --spec process_element(xmpp_element(), state()) -> state(). -process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> - case Pkt of - #starttls{} when StateName == wait_for_starttls; - StateName == wait_for_sasl_request -> - process_starttls(State); - #starttls{} -> - process_starttls_failure(unexpected_starttls_request, State); - #sasl_auth{} when StateName == wait_for_starttls -> - send_pkt(State, #sasl_failure{reason = 'encryption-required'}); - #sasl_auth{} when StateName == wait_for_sasl_request -> - process_sasl_request(Pkt, State); - #sasl_auth{} when StateName == wait_for_sasl_response -> - process_sasl_request(Pkt, maps:remove(sasl_state, State)); - #sasl_auth{} -> - Txt = <<"SASL negotiation is not allowed in this state">>, - send_pkt(State, #sasl_failure{reason = 'not-authorized', - text = xmpp:mk_text(Txt, Lang)}); - #sasl_response{} when StateName == wait_for_starttls -> - send_pkt(State, #sasl_failure{reason = 'encryption-required'}); - #sasl_response{} when StateName == wait_for_sasl_response -> - process_sasl_response(Pkt, State); - #sasl_response{} -> - Txt = <<"SASL negotiation is not allowed in this state">>, - send_pkt(State, #sasl_failure{reason = 'not-authorized', - text = xmpp:mk_text(Txt, Lang)}); - #sasl_abort{} when StateName == wait_for_sasl_response -> - process_sasl_abort(State); - #sasl_abort{} -> - send_pkt(State, #sasl_failure{reason = 'aborted'}); - #sasl_success{} -> - State; - #compress{} -> - process_compress(Pkt, State); - #handshake{} when StateName == wait_for_handshake -> - process_handshake(Pkt, State); - #handshake{} -> - State; - #stream_error{} -> - process_stream_end({stream, {in, Pkt}}, State); - _ when StateName == wait_for_sasl_request; - StateName == wait_for_handshake; - StateName == wait_for_sasl_response -> - process_unauthenticated_packet(Pkt, State); - _ when StateName == wait_for_starttls -> - Txt = <<"Use of STARTTLS required">>, - Err = xmpp:serr_policy_violation(Txt, Lang), - send_pkt(State, Err); - _ when StateName == wait_for_bind -> - process_bind(Pkt, State); - _ when StateName == established -> - process_authenticated_packet(Pkt, State) - end. - --spec process_unauthenticated_packet(xmpp_element(), state()) -> state(). -process_unauthenticated_packet(Pkt, State) -> - NewPkt = set_lang(Pkt, State), - try callback(handle_unauthenticated_packet, NewPkt, State) - catch _:{?MODULE, undef} -> - Err = xmpp:serr_not_authorized(), - send(State, Err) - end. - --spec process_authenticated_packet(xmpp_element(), state()) -> state(). -process_authenticated_packet(Pkt, State) -> - Pkt1 = set_lang(Pkt, State), - case set_from_to(Pkt1, State) of - {ok, Pkt2} -> - try callback(handle_authenticated_packet, Pkt2, State) - catch _:{?MODULE, undef} -> - Err = xmpp:err_service_unavailable(), - send_error(State, Pkt, Err) - end; - {error, Err} -> - send_pkt(State, Err) - end. - --spec process_bind(xmpp_element(), state()) -> state(). -process_bind(#iq{type = set, sub_els = [_]} = Pkt, - #{xmlns := ?NS_CLIENT, lang := MyLang} = State) -> - try xmpp:try_subtag(Pkt, #bind{}) of - #bind{resource = R} -> - case callback(bind, R, State) of - {ok, #{user := U, server := S, resource := NewR} = State1} - when NewR /= <<"">> -> - Reply = #bind{jid = jid:make(U, S, NewR)}, - State2 = send_pkt(State1, xmpp:make_iq_result(Pkt, Reply)), - process_stream_established(State2); - {error, #stanza_error{} = Err, State1} -> - send_error(State1, Pkt, Err) - end; - _ -> - try callback(handle_unbinded_packet, Pkt, State) - catch _:{?MODULE, undef} -> - Err = xmpp:err_not_authorized(), - send_error(State, Pkt, Err) - end - catch _:{xmpp_codec, Why} -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(Pkt)), - Err = xmpp:err_bad_request(Txt, Lang), - send_error(State, Pkt, Err) - end; -process_bind(Pkt, State) -> - try callback(handle_unbinded_packet, Pkt, State) - catch _:{?MODULE, undef} -> - Err = xmpp:err_not_authorized(), - send_error(State, Pkt, Err) - end. - --spec process_handshake(handshake(), state()) -> state(). -process_handshake(#handshake{data = Digest}, - #{stream_id := StreamID, - remote_server := RemoteServer} = State) -> - GetPW = try callback(get_password_fun, State) - catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end - end, - AuthRes = case GetPW(<<"">>) of - {false, _} -> - false; - {Password, _} -> - str:sha(<<StreamID/binary, Password/binary>>) == Digest - end, - case AuthRes of - true -> - State1 = try callback(handle_auth_success, - RemoteServer, <<"handshake">>, undefined, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - State2 = send_pkt(State1, #handshake{}), - process_stream_established(State2) - end; - false -> - State1 = try callback(handle_auth_failure, - RemoteServer, <<"handshake">>, <<"not authorized">>, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> send_pkt(State1, xmpp:serr_not_authorized()) - end - end. - --spec process_stream_established(state()) -> state(). -process_stream_established(#{stream_state := StateName} = State) - when StateName == disconnected; StateName == established -> - State; -process_stream_established(State) -> - State1 = State#{stream_authenticated => true, - stream_state => established, - stream_timeout => infinity}, - try callback(handle_stream_established, State1) - catch _:{?MODULE, undef} -> State1 - end. - --spec process_compress(compress(), state()) -> state(). -process_compress(#compress{}, - #{stream_compressed := Compressed, - stream_authenticated := Authenticated} = State) - when Compressed or not Authenticated -> - send_pkt(State, #compress_failure{reason = 'setup-failed'}); -process_compress(#compress{methods = HisMethods}, - #{socket := Socket} = State) -> - MyMethods = try callback(compress_methods, State) - catch _:{?MODULE, undef} -> [] - end, - CommonMethods = lists_intersection(MyMethods, HisMethods), - case lists:member(<<"zlib">>, CommonMethods) of - true -> - case xmpp_socket:compress(Socket) of - {ok, ZlibSocket} -> - State1 = send_pkt(State, #compressed{}), - case is_disconnected(State1) of - true -> State1; - false -> - State1#{socket => ZlibSocket, - stream_id => new_id(), - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - stream_compressed => true} - end; - {error, _} -> - Err = #compress_failure{reason = 'setup-failed'}, - send_pkt(State, Err) - end; - false -> - send_pkt(State, #compress_failure{reason = 'unsupported-method'}) - end. - --spec process_starttls(state()) -> state(). -process_starttls(#{stream_encrypted := true} = State) -> - process_starttls_failure(already_encrypted, State); -process_starttls(#{socket := Socket} = State) -> - case is_starttls_available(State) of - true -> - TLSOpts = try callback(tls_options, State) - catch _:{?MODULE, undef} -> [] - end, - case xmpp_socket:starttls(Socket, TLSOpts) of - {ok, TLSSocket} -> - State1 = send_pkt(State, #starttls_proceed{}), - case is_disconnected(State1) of - true -> State1; - false -> - State1#{socket => TLSSocket, - stream_id => new_id(), - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - stream_encrypted => true} - end; - {error, Reason} -> - process_starttls_failure(Reason, State) - end; - false -> - process_starttls_failure(starttls_unsupported, State) - end. - --spec process_starttls_failure(term(), state()) -> state(). -process_starttls_failure(Why, State) -> - State1 = send_pkt(State, #starttls_failure{}), - case is_disconnected(State1) of - true -> State1; - false -> process_stream_end({tls, Why}, State1) - end. - --spec process_sasl_request(sasl_auth(), state()) -> state(). -process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, - #{lserver := LServer} = State) -> - State1 = State#{sasl_mech => Mech}, - Mechs = get_sasl_mechanisms(State1), - case lists:member(Mech, Mechs) of - true when Mech == <<"EXTERNAL">> -> - Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of - {ok, Peer} -> - {ok, [{auth_module, pkix}, {username, Peer}]}; - {error, Reason, Peer} -> - {error, Reason, Peer} - end, - process_sasl_result(Res, State1); - true -> - GetPW = try callback(get_password_fun, State1) - catch _:{?MODULE, undef} -> fun(_) -> false end - end, - CheckPW = try callback(check_password_fun, State1) - catch _:{?MODULE, undef} -> fun(_, _, _) -> false end - end, - CheckPWDigest = try callback(check_password_digest_fun, State1) - catch _:{?MODULE, undef} -> fun(_, _, _, _, _) -> false end - end, - SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [], - GetPW, CheckPW, CheckPWDigest), - Res = cyrsasl:server_start(SASLState, Mech, ClientIn), - process_sasl_result(Res, State1#{sasl_state => SASLState}); - false -> - process_sasl_result({error, unsupported_mechanism, <<"">>}, State1) - end. - --spec process_sasl_response(sasl_response(), state()) -> state(). -process_sasl_response(#sasl_response{text = ClientIn}, - #{sasl_state := SASLState} = State) -> - SASLResult = cyrsasl:server_step(SASLState, ClientIn), - process_sasl_result(SASLResult, State). - --spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state(). -process_sasl_result({ok, Props}, State) -> - process_sasl_success(Props, <<"">>, State); -process_sasl_result({ok, Props, ServerOut}, State) -> - process_sasl_success(Props, ServerOut, State); -process_sasl_result({continue, ServerOut, NewSASLState}, State) -> - process_sasl_continue(ServerOut, NewSASLState, State); -process_sasl_result({error, Reason, User}, State) -> - process_sasl_failure(Reason, User, State). - --spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state(). -process_sasl_success(Props, ServerOut, - #{socket := Socket, - sasl_mech := Mech} = State) -> - User = identity(Props), - AuthModule = proplists:get_value(auth_module, Props), - Socket1 = xmpp_socket:reset_stream(Socket), - State0 = State#{socket => Socket1}, - State1 = try callback(handle_auth_success, User, Mech, AuthModule, State0) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - State2 = send_pkt(State1, #sasl_success{text = ServerOut}), - case is_disconnected(State2) of - true -> State2; - false -> - State3 = maps:remove(sasl_state, - maps:remove(sasl_mech, State2)), - State3#{stream_id => new_id(), - stream_authenticated => true, - stream_header_sent => false, - stream_restarted => true, - stream_state => wait_for_stream, - user => User} - end - end. - --spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state(). -process_sasl_continue(ServerOut, NewSASLState, State) -> - State1 = State#{sasl_state => NewSASLState, - stream_state => wait_for_sasl_response}, - send_pkt(State1, #sasl_challenge{text = ServerOut}). - --spec process_sasl_failure(atom(), binary(), state()) -> state(). -process_sasl_failure(Err, User, - #{sasl_mech := Mech, lang := Lang} = State) -> - {Reason, Text} = format_sasl_error(Mech, Err), - State1 = try callback(handle_auth_failure, User, Mech, Text, State) - catch _:{?MODULE, undef} -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - State2 = send_pkt(State1, - #sasl_failure{reason = Reason, - text = xmpp:mk_text(Text, Lang)}), - case is_disconnected(State2) of - true -> State2; - false -> - State3 = maps:remove(sasl_state, - maps:remove(sasl_mech, State2)), - State3#{stream_state => wait_for_sasl_request} - end - end. - --spec process_sasl_abort(state()) -> state(). -process_sasl_abort(State) -> - process_sasl_failure(aborted, <<"">>, State). - --spec send_features(state()) -> state(). -send_features(#{stream_version := {1,0}, - stream_encrypted := Encrypted} = State) -> - TLSRequired = is_starttls_required(State), - Features = if TLSRequired and not Encrypted -> - get_tls_feature(State); - true -> - get_sasl_feature(State) ++ get_compress_feature(State) - ++ get_tls_feature(State) ++ get_bind_feature(State) - ++ get_session_feature(State) ++ get_other_features(State) - end, - send_pkt(State, #stream_features{sub_els = Features}); -send_features(State) -> - %% clients and servers from stone age - State. - --spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()]. -get_sasl_mechanisms(#{stream_encrypted := Encrypted, - xmlns := NS, lserver := LServer} = State) -> - Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer); - true -> [] - end, - TLSVerify = try callback(tls_verify, State) - catch _:{?MODULE, undef} -> false - end, - Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> - [<<"EXTERNAL">>|Mechs]; - true -> - Mechs - end, - try callback(sasl_mechanisms, Mechs1, State) - catch _:{?MODULE, undef} -> Mechs1 - end. - --spec get_sasl_feature(state()) -> [sasl_mechanisms()]. -get_sasl_feature(#{stream_authenticated := false, - stream_encrypted := Encrypted} = State) -> - TLSRequired = is_starttls_required(State), - if Encrypted or not TLSRequired -> - Mechs = get_sasl_mechanisms(State), - [#sasl_mechanisms{list = Mechs}]; - true -> - [] - end; -get_sasl_feature(_) -> - []. - --spec get_compress_feature(state()) -> [compression()]. -get_compress_feature(#{stream_compressed := false, - stream_authenticated := true} = State) -> - try callback(compress_methods, State) of - [] -> []; - Ms -> [#compression{methods = Ms}] - catch _:{?MODULE, undef} -> - [] - end; -get_compress_feature(_) -> - []. - --spec get_tls_feature(state()) -> [starttls()]. -get_tls_feature(#{stream_authenticated := false, - stream_encrypted := false} = State) -> - case is_starttls_available(State) of - true -> - TLSRequired = is_starttls_required(State), - [#starttls{required = TLSRequired}]; - false -> - [] - end; -get_tls_feature(_) -> - []. - --spec get_bind_feature(state()) -> [bind()]. -get_bind_feature(#{xmlns := ?NS_CLIENT, - stream_authenticated := true, - resource := <<"">>}) -> - [#bind{}]; -get_bind_feature(_) -> - []. - --spec get_session_feature(state()) -> [xmpp_session()]. -get_session_feature(#{xmlns := ?NS_CLIENT, - stream_authenticated := true, - resource := <<"">>}) -> - [#xmpp_session{optional = true}]; -get_session_feature(_) -> - []. - --spec get_other_features(state()) -> [xmpp_element()]. -get_other_features(#{stream_authenticated := Auth} = State) -> - try - if Auth -> callback(authenticated_stream_features, State); - true -> callback(unauthenticated_stream_features, State) - end - catch _:{?MODULE, undef} -> - [] - end. - --spec is_starttls_available(state()) -> boolean(). -is_starttls_available(State) -> - try callback(tls_enabled, State) - catch _:{?MODULE, undef} -> true - end. - --spec is_starttls_required(state()) -> boolean(). -is_starttls_required(State) -> - try callback(tls_required, State) - catch _:{?MODULE, undef} -> false - end. - --spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} | - {error, stream_error()}. -set_from_to(Pkt, _State) when not ?is_stanza(Pkt) -> - {ok, Pkt}; -set_from_to(Pkt, #{user := U, server := S, resource := R, - lang := Lang, xmlns := ?NS_CLIENT}) -> - JID = jid:make(U, S, R), - From = case xmpp:get_from(Pkt) of - undefined -> JID; - F -> F - end, - if JID#jid.luser == From#jid.luser andalso - JID#jid.lserver == From#jid.lserver andalso - (JID#jid.lresource == From#jid.lresource - orelse From#jid.lresource == <<"">>) -> - To = case xmpp:get_to(Pkt) of - undefined -> jid:make(U, S); - T -> T - end, - {ok, xmpp:set_from_to(Pkt, JID, To)}; - true -> - Txt = <<"Improper 'from' attribute">>, - {error, xmpp:serr_invalid_from(Txt, Lang)} - end; -set_from_to(Pkt, #{lang := Lang}) -> - From = xmpp:get_from(Pkt), - To = xmpp:get_to(Pkt), - if From == undefined -> - Txt = <<"Missing 'from' attribute">>, - {error, xmpp:serr_improper_addressing(Txt, Lang)}; - To == undefined -> - Txt = <<"Missing 'to' attribute">>, - {error, xmpp:serr_improper_addressing(Txt, Lang)}; - true -> - {ok, Pkt} - end. - --spec send_header(state()) -> state(). -send_header(#{stream_version := Version} = State) -> - send_header(State, #stream_start{version = Version}). - --spec send_header(state(), stream_start()) -> state(). -send_header(#{stream_id := StreamID, - stream_version := MyVersion, - stream_header_sent := false, - lang := MyLang, - xmlns := NS} = State, - #stream_start{to = HisTo, from = HisFrom, - lang = HisLang, version = HisVersion}) -> - Lang = select_lang(MyLang, HisLang), - NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; - true -> <<"">> - end, - Version = case HisVersion of - undefined -> undefined; - {0,_} -> HisVersion; - _ -> MyVersion - end, - StreamStart = #stream_start{version = Version, - lang = Lang, - xmlns = NS, - stream_xmlns = ?NS_STREAM, - db_xmlns = NS_DB, - id = StreamID, - to = HisFrom, - from = HisTo}, - State1 = State#{lang => Lang, - stream_version => Version, - stream_header_sent => true}, - case socket_send(State1, StreamStart) of - ok -> State1; - {error, Why} -> process_stream_end({socket, Why}, State1) - end; -send_header(State, _) -> - State. - --spec send_pkt(state(), xmpp_element() | xmlel()) -> state(). -send_pkt(State, Pkt) -> - Result = socket_send(State, Pkt), - State1 = try callback(handle_send, Pkt, Result, State) - catch _:{?MODULE, undef} -> State - end, - case Result of - _ when is_record(Pkt, stream_error) -> - process_stream_end({stream, {out, Pkt}}, State1); - ok -> - State1; - {error, Why} -> - process_stream_end({socket, Why}, State1) - end. - --spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state(). -send_error(State, Pkt, Err) -> - case xmpp:is_stanza(Pkt) of - true -> - case xmpp:get_type(Pkt) of - result -> State; - error -> State; - <<"result">> -> State; - <<"error">> -> State; - _ -> - ErrPkt = xmpp:make_error(Pkt, Err), - send_pkt(State, ErrPkt) - end; - false -> - State - end. - --spec send_trailer(state()) -> state(). -send_trailer(State) -> - socket_send(State, trailer), - close_socket(State). - --spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}. -socket_send(#{socket := Sock, - stream_state := StateName, - xmlns := NS, - stream_header_sent := true}, Pkt) -> - case Pkt of - trailer -> - xmpp_socket:send_trailer(Sock); - #stream_start{} when StateName /= disconnected -> - xmpp_socket:send_header(Sock, xmpp:encode(Pkt)); - _ when StateName /= disconnected -> - xmpp_socket:send_element(Sock, xmpp:encode(Pkt, NS)); - _ -> - {error, closed} - end; -socket_send(_, _) -> - {error, closed}. - --spec close_socket(state()) -> state(). -close_socket(#{socket := Socket} = State) -> - xmpp_socket:close(Socket), - State#{stream_timeout => infinity, - stream_state => disconnected}. - --spec select_lang(binary(), binary()) -> binary(). -select_lang(Lang, <<"">>) -> Lang; -select_lang(_, Lang) -> Lang. - --spec set_lang(xmpp_element(), state()) -> xmpp_element(). -set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) -> - HisLang = xmpp:get_lang(Pkt), - Lang = select_lang(MyLang, HisLang), - xmpp:set_lang(Pkt, Lang); -set_lang(Pkt, _) -> - Pkt. - --spec format_inet_error(atom()) -> string(). -format_inet_error(closed) -> - "connection closed"; -format_inet_error(Reason) -> - case inet:format_error(Reason) of - "unknown POSIX error" -> atom_to_list(Reason); - Txt -> Txt - end. - --spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}. -format_sasl_error(<<"EXTERNAL">>, Err) -> - xmpp_stream_pkix:format_error(Err); -format_sasl_error(Mech, Err) -> - cyrsasl:format_error(Mech, Err). - --spec format_tls_error(atom() | binary()) -> list(). -format_tls_error(Reason) when is_atom(Reason) -> - format_inet_error(Reason); -format_tls_error(Reason) -> - Reason. - --spec format(io:format(), list()) -> binary(). -format(Fmt, Args) -> - iolist_to_binary(io_lib:format(Fmt, Args)). - --spec lists_intersection(list(), list()) -> list(). -lists_intersection(L1, L2) -> - lists:filter( - fun(E) -> - lists:member(E, L2) - end, L1). - --spec identity([cyrsasl:sasl_property()]) -> binary(). -identity(Props) -> - case proplists:get_value(authzid, Props, <<>>) of - <<>> -> proplists:get_value(username, Props, <<>>); - AuthzId -> AuthzId - end. - -%%%=================================================================== -%%% Callbacks -%%%=================================================================== -callback(F, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 1) of - true -> Mod:F(State); - false -> erlang:error({?MODULE, undef}) - end. - -callback(F, Arg1, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 2) of - true -> Mod:F(Arg1, State); - false -> erlang:error({?MODULE, undef}) - end. - -callback(code_change, OldVsn, #{mod := Mod} = State, Extra) -> - %% code_change/3 callback is a special snowflake - case erlang:function_exported(Mod, code_change, 3) of - true -> Mod:code_change(OldVsn, State, Extra); - false -> {ok, State} - end; -callback(F, Arg1, Arg2, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 3) of - true -> Mod:F(Arg1, Arg2, State); - false -> erlang:error({?MODULE, undef}) - end. - -callback(F, Arg1, Arg2, Arg3, #{mod := Mod} = State) -> - case erlang:function_exported(Mod, F, 4) of - true -> Mod:F(Arg1, Arg2, Arg3, State); - false -> erlang:error({?MODULE, undef}) - end. |