diff options
Diffstat (limited to 'src/xmpp_stream_in.erl')
-rw-r--r-- | src/xmpp_stream_in.erl | 843 |
1 files changed, 577 insertions, 266 deletions
diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index 1307f9da4..e9c1b3339 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -25,58 +25,81 @@ -protocol({rfc, 6120}). %% API --export([start/3, call/3, cast/2, reply/2, send/2, send_error/3, - get_transport/1, change_shaper/2]). +-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1, + send/2, close/1, close/2, send_error/3, establish/1, + get_transport/1, change_shaper/2, set_timeout/2, format_error/1]). %% gen_server callbacks -export([init/1, handle_cast/2, handle_call/3, handle_info/2, terminate/2, code_change/3]). +%%-define(DBGFSM, true). +-ifdef(DBGFSM). +-define(FSMOPTS, [{debug, [trace]}]). +-else. +-define(FSMOPTS, []). +-endif. + -include("xmpp.hrl"). -type state() :: map(). --type next_state() :: {noreply, state()} | {stop, term(), state()}. +-type stop_reason() :: {stream, reset | stream_error()} | + {tls, term()} | + {socket, inet:posix() | closed | timeout}. -callback init(list()) -> {ok, state()} | {stop, term()} | ignore. --callback handle_stream_start(state()) -> next_state(). --callback handle_stream_end(state()) -> next_state(). --callback handle_stream_close(state()) -> next_state(). --callback handle_cdata(binary(), state()) -> next_state(). --callback handle_unauthenticated_packet(xmpp_element(), state()) -> next_state(). --callback handle_authenticated_packet(xmpp_element(), state()) -> next_state(). --callback handle_unbinded_packet(xmpp_element(), state()) -> next_state(). --callback handle_auth_success(binary(), binary(), module(), state()) -> next_state(). --callback handle_auth_failure(binary(), binary(), atom(), state()) -> next_state(). --callback handle_send(ok | {error, atom()}, - xmpp_element(), fxml:xmlel(), binary(), state()) -> next_state(). --callback init_sasl(state()) -> cyrsasl:sasl_state(). +-callback handle_cast(term(), state()) -> state(). +-callback handle_call(term(), term(), state()) -> state(). +-callback handle_info(term(), state()) -> state(). +-callback terminate(term(), state()) -> any(). +-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. +-callback handle_stream_start(state()) -> state(). +-callback handle_stream_end(stop_reason(), state()) -> state(). +-callback handle_stream_close(stop_reason(), state()) -> state(). +-callback handle_cdata(binary(), state()) -> state(). +-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state(). +-callback handle_authenticated_packet(xmpp_element(), state()) -> state(). +-callback handle_unbinded_packet(xmpp_element(), state()) -> state(). +-callback handle_auth_success(binary(), binary(), module(), state()) -> state(). +-callback handle_auth_failure(binary(), binary(), atom(), state()) -> state(). +-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). +-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). +-callback get_password_fun(state()) -> fun(). +-callback check_password_fun(state()) -> fun(). +-callback check_password_digest_fun(state()) -> fun(). -callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}. --callback handshake(binary(), state()) -> {ok, state()} | {error, stream_error(), state()}. -callback compress_methods(state()) -> [binary()]. -callback tls_options(state()) -> [proplists:property()]. -callback tls_required(state()) -> boolean(). --callback sasl_mechanisms(state()) -> [binary()]. +-callback tls_verify(state()) -> boolean(). -callback unauthenticated_stream_features(state()) -> [xmpp_element()]. -callback authenticated_stream_features(state()) -> [xmpp_element()]. %% All callbacks are optional -optional_callbacks([init/1, + handle_cast/2, + handle_call/3, + handle_info/2, + terminate/2, + code_change/3, handle_stream_start/1, - handle_stream_end/1, - handle_stream_close/1, + handle_stream_end/2, + handle_stream_close/2, handle_cdata/2, handle_authenticated_packet/2, handle_unauthenticated_packet/2, handle_unbinded_packet/2, handle_auth_success/4, handle_auth_failure/4, - handle_send/5, - init_sasl/1, + handle_send/3, + handle_recv/3, + get_password_fun/1, + check_password_fun/1, + check_password_digest_fun/1, bind/2, - handshake/2, compress_methods/1, tls_options/1, tls_required/1, - sasl_mechanisms/1, + tls_verify/1, unauthenticated_stream_features/1, authenticated_stream_features/1]). @@ -84,7 +107,10 @@ %%% API %%%=================================================================== start(Mod, Args, Opts) -> - gen_server:start(?MODULE, [Mod|Args], Opts). + gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + +start_link(Mod, Args, Opts) -> + gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). call(Ref, Msg, Timeout) -> gen_server:call(Ref, Msg, Timeout). @@ -95,16 +121,80 @@ cast(Ref, Msg) -> reply(Ref, Reply) -> gen_server:reply(Ref, Reply). --spec send(state(), xmpp_element()) -> next_state(). -send(State, Pkt) -> - send_element(State, Pkt). +-spec stop(pid()) -> ok; + (state()) -> no_return(). +stop(Pid) when is_pid(Pid) -> + cast(Pid, stop); +stop(#{owner := Owner} = State) when Owner == self() -> + terminate(normal, State), + exit(normal); +stop(_) -> + erlang:error(badarg). -get_transport(#{sockmod := SockMod, socket := Socket}) -> - SockMod:get_transport(Socket). +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Pid, Pkt) when is_pid(Pid) -> + cast(Pid, {send, Pkt}); +send(#{owner := Owner} = State, Pkt) when Owner == self() -> + send_element(State, Pkt); +send(_, _) -> + erlang:error(badarg). + +-spec close(pid()) -> ok; + (state()) -> state(). +close(Ref) -> + close(Ref, true). + +-spec close(pid(), boolean()) -> ok; + (state(), boolean()) -> state(). +close(Pid, SendTrailer) when is_pid(Pid) -> + cast(Pid, {close, SendTrailer}); +close(#{owner := Owner} = State, SendTrailer) when Owner == self() -> + if SendTrailer -> send_trailer(State); + true -> close_socket(State) + end; +close(_, _) -> + erlang:error(badarg). + +-spec establish(state()) -> state(). +establish(State) -> + process_stream_established(State). + +-spec set_timeout(state(), non_neg_integer() | infinity) -> state(). +set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> + case Timeout of + infinity -> State#{stream_timeout => infinity}; + _ -> + Time = p1_time_compat:monotonic_time(milli_seconds), + State#{stream_timeout => {Timeout, Time}} + end; +set_timeout(_, _) -> + erlang:error(badarg). + +get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner}) + when Owner == self() -> + SockMod:get_transport(Socket); +get_transport(_) -> + erlang:error(badarg). -spec change_shaper(state(), shaper:shaper()) -> ok. -change_shaper(#{sockmod := SockMod, socket := Socket}, Shaper) -> - SockMod:change_shaper(Socket, Shaper). +change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper) + when Owner == self() -> + SockMod:change_shaper(Socket, Shaper); +change_shaper(_, _) -> + erlang:error(badarg). + +-spec format_error(stop_reason()) -> binary(). +format_error({socket, Reason}) -> + format("Connection failed: ~s", [format_inet_error(Reason)]); +format_error({stream, reset}) -> + <<"Stream reset by peer">>; +format_error({stream, #stream_error{reason = Reason, text = Txt}}) -> + format("Stream failed: ~s", [format_stream_error(Reason, Txt)]); +format_error({tls, Reason}) -> + format("TLS failed: ~w", [Reason]); +format_error(Err) -> + format("Unrecognized error: ~w", [Err]). %%%=================================================================== %%% gen_server callbacks @@ -114,19 +204,24 @@ init([Module, {SockMod, Socket}, Opts]) -> {_, XS} -> XS; false -> false end, - TLSEnabled = proplists:get_bool(tls, Opts), + Encrypted = proplists:get_bool(tls, Opts), SocketMonitor = SockMod:monitor(Socket), case peername(SockMod, Socket) of {ok, IP} -> - State = #{mod => Module, + Time = p1_time_compat:monotonic_time(milli_seconds), + State = #{owner => self(), + mod => Module, socket => Socket, sockmod => SockMod, socket_monitor => SocketMonitor, + stream_timeout => {timer:seconds(30), Time}, + stream_direction => in, stream_id => new_id(), stream_state => wait_for_stream, + stream_header_sent => false, stream_restarted => false, stream_compressed => false, - stream_tlsed => TLSEnabled, + stream_encrypted => Encrypted, stream_version => {1,0}, stream_authenticated => false, xml_socket => XMLSocket, @@ -137,97 +232,133 @@ init([Module, {SockMod, Socket}, Opts]) -> resource => <<"">>, lserver => <<"">>, ip => IP}, - try Module:init([State, Opts]) - catch _:undef -> {ok, State} + case try Module:init([State, Opts]) + catch _:undef -> {ok, State} + end of + {ok, State1} -> + {_, State2, Timeout} = noreply(State1), + {ok, State2, Timeout}; + Err -> + Err end; {error, Reason} -> {stop, Reason} end. +handle_cast({send, Pkt}, State) -> + noreply(send_element(State, Pkt)); +handle_cast(stop, State) -> + {stop, normal, State}; handle_cast(Cast, #{mod := Mod} = State) -> - try Mod:handle_cast(Cast, State) - catch _:undef -> {noreply, State} - end. + noreply(try Mod:handle_cast(Cast, State) + catch _:undef -> State + end). handle_call(Call, From, #{mod := Mod} = State) -> - try Mod:handle_call(Call, From, State) - catch _:undef -> {reply, ok, State} - end. + noreply(try Mod:handle_call(Call, From, State) + catch _:undef -> State + end). handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, - #{stream_state := wait_for_stream, xmlns := XMLNS} = State) -> - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}, XMLNS, []) of + #{stream_state := wait_for_stream, + xmlns := XMLNS, lang := MyLang} = State) -> + El = #xmlel{name = Name, attrs = Attrs}, + try xmpp:decode(El, XMLNS, []) of #stream_start{} = Pkt -> - case send_header(State, Pkt) of - {noreply, State1} -> - process_stream(Pkt, State1); - Err -> - Err + State1 = send_header(State, Pkt), + case is_disconnected(State1) of + true -> State1; + false -> noreply(process_stream(Pkt, State1)) end; _ -> - case send_header(State) of - {noreply, State1} -> - send_element(State1, xmpp:serr_invalid_xml()); - Err -> - Err + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> noreply(send_element(State1, xmpp:serr_invalid_xml())) end catch _:{xmpp_codec, Why} -> - case send_header(State) of - {noreply, State1} -> process_invalid_xml(Why, State1); - Err -> Err + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + Err = xmpp:serr_invalid_xml(Txt, Lang), + noreply(send_element(State1, Err)) end end; -handle_info({'$gen_event', {xmlstreamend, _}}, #{mod := Mod} = State) -> - try Mod:handle_stream_end(State) - catch _:undef -> {stop, normal, State} - end; handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> - case send_header(State) of - {noreply, State1} -> + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> Err = case Reason of <<"XML stanza is too big">> -> xmpp:serr_policy_violation(Reason, Lang); _ -> xmpp:serr_not_well_formed() end, - send_element(State1, Err); - Err -> - Err + noreply(send_element(State1, Err)) end; handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS} = State) -> + #{xmlns := NS, lang := MyLang, mod := Mod} = State) -> try xmpp:decode(El, NS, [ignore_els]) of Pkt -> - process_element(Pkt, State) + State1 = try Mod:handle_recv(El, Pkt, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> noreply(process_element(Pkt, State1)) + end catch _:{xmpp_codec, Why} -> - process_invalid_xml(Why, State) + State1 = try Mod:handle_recv(El, {error, Why}, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang))) + end end; handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, #{mod := Mod} = State) -> - try Mod:handle_cdata(Data, State) - catch _:undef -> {noreply, State} - end; -handle_info(closed, #{mod := Mod} = State) -> - try Mod:handle_stream_close(State) - catch _:undef -> {stop, normal, State} - end; + noreply(try Mod:handle_cdata(Data, State) + catch _:undef -> State + end); +handle_info({'$gen_event', {xmlstreamend, _}}, State) -> + noreply(process_stream_end({error, {stream, reset}}, State)); +handle_info({'$gen_event', closed}, State) -> + noreply(process_stream_close({error, {socket, closed}}, State)); +handle_info(timeout, #{mod := Mod} = State) -> + Disconnected = is_disconnected(State), + noreply(try Mod:handle_timeout(State) + catch _:undef when not Disconnected -> + send_element(State, xmpp:serr_connection_timeout()); + _:undef -> + stop(State) + end); handle_info({'DOWN', MRef, _Type, _Object, _Info}, - #{socket_monitor := MRef, mod := Mod} = State) -> - try Mod:handle_stream_close(State) - catch _:undef -> {stop, normal, State} - end; + #{socket_monitor := MRef} = State) -> + noreply(process_stream_close({error, {socket, closed}}, State)); handle_info(Info, #{mod := Mod} = State) -> - try Mod:handle_info(Info, State) - catch _:undef -> {noreply, State} - end. + noreply(try Mod:handle_info(Info, State) + catch _:undef -> State + end). -terminate(Reason, #{mod := Mod, socket := Socket, - sockmod := SockMod} = State) -> - try Mod:terminate(Reason, State) - catch _:undef -> ok - end, - send_text(State, <<"</stream:stream>">>), - SockMod:close(Socket). +terminate(Reason, #{mod := Mod} = State) -> + case get(already_terminated) of + true -> + State; + _ -> + put(already_terminated, true), + try Mod:terminate(Reason, State) + catch _:undef -> ok + end, + send_trailer(State) + end. code_change(OldVsn, #{mod := Mod} = State, Extra) -> Mod:code_change(OldVsn, State, Extra). @@ -235,20 +366,49 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== +-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. +noreply(#{stream_timeout := infinity} = State) -> + {noreply, State, infinity}; +noreply(#{stream_timeout := {MSecs, StartTime}} = State) -> + CurrentTime = p1_time_compat:monotonic_time(milli_seconds), + Timeout = max(0, MSecs - CurrentTime + StartTime), + {noreply, State, Timeout}. + -spec new_id() -> binary(). new_id() -> randoms:get_string(). +-spec is_disconnected(state()) -> boolean(). +is_disconnected(#{stream_state := StreamState}) -> + StreamState == disconnected. + +-spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}| + {error, inet:posix()}. peername(SockMod, Socket) -> case SockMod of gen_tcp -> inet:peername(Socket); _ -> SockMod:peername(Socket) end. -process_invalid_xml(Reason, #{lang := Lang} = State) -> - Txt = xmpp:io_format_error(Reason), - send_element(State, xmpp:serr_invalid_xml(Txt, Lang)). +-spec process_stream_close(stop_reason(), state()) -> state(). +process_stream_close(_, #{stream_state := disconnected} = State) -> + State; +process_stream_close(Reason, #{mod := Mod} = State) -> + State1 = send_trailer(State), + try Mod:handle_stream_close(Reason, State1) + catch _:undef -> stop(State1) + end. + +-spec process_stream_end(stop_reason(), state()) -> state(). +process_stream_end(_, #{stream_state := disconnected} = State) -> + State; +process_stream_end(Reason, #{mod := Mod} = State) -> + State1 = send_trailer(State), + try Mod:handle_stream_end(Reason, State1) + catch _:undef -> stop(State1) + end. +-spec process_stream(stream_start(), state()) -> state(). process_stream(#stream_start{xmlns = XML_NS, stream_xmlns = STREAM_NS}, #{xmlns := NS} = State) @@ -268,73 +428,67 @@ process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) -> send_element(State, xmpp:serr_improper_addressing(Txt, Lang)); process_stream(#stream_start{from = undefined, version = {1,0}}, #{lang := Lang, xmlns := ?NS_SERVER, - stream_tlsed := true} = State) -> + stream_encrypted := true} = State) -> Txt = <<"Missing 'from' attribute">>, send_element(State, xmpp:serr_invalid_from(Txt, Lang)); process_stream(#stream_start{to = #jid{luser = U, lresource = R}}, #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> -> Txt = <<"Improper 'to' attribute">>, send_element(State, xmpp:serr_improper_addressing(Txt, Lang)); -process_stream(#stream_start{to = #jid{lserver = RemoteServer}}, +process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart, #{xmlns := ?NS_COMPONENT, mod := Mod} = State) -> - State1 = State#{remote_server => RemoteServer}, - case try Mod:handle_stream_start(State1) - catch _:undef -> {noreply, State1} - end of - {noreply, State2} -> - {noreply, State2#{stream_state => wait_for_handshake}}; - Err -> - Err + State1 = State#{remote_server => RemoteServer, + stream_state => wait_for_handshake}, + try Mod:handle_stream_start(StreamStart, State1) + catch _:undef -> State1 end; process_stream(#stream_start{to = #jid{server = Server, lserver = LServer}, - from = From}, + from = From} = StreamStart, #{stream_authenticated := Authenticated, stream_restarted := StreamWasRestarted, mod := Mod, xmlns := NS, resource := Resource, - stream_tlsed := TLSEnabled} = State) -> - case if not StreamWasRestarted -> - State1 = State#{server => Server, lserver => LServer}, - try Mod:handle_stream_start(State1) - catch _:undef -> {noreply, State1} - end; - true -> - {noreply, State} - end of - {noreply, State2} -> - State3 = if NS == ?NS_SERVER andalso TLSEnabled -> - State2#{remote_server => From#jid.lserver}; - true -> - State2 - end, - case send_features(State3) of - {noreply, State4} -> + stream_encrypted := Encrypted} = State) -> + State1 = if not StreamWasRestarted -> + State#{server => Server, lserver => LServer}; + true -> + State + end, + State2 = if NS == ?NS_SERVER andalso Encrypted -> + State1#{remote_server => From#jid.lserver}; + true -> + State1 + end, + State3 = try Mod:handle_stream_start(StreamStart, State2) + catch _:undef -> State2 + end, + case is_disconnected(State3) of + true -> State3; + false -> + State4 = send_features(State3), + case is_disconnected(State4) of + true -> State4; + false -> TLSRequired = is_starttls_required(State4), - NewStreamState = - if not Authenticated and - (not TLSEnabled and TLSRequired) -> - wait_for_starttls; - not Authenticated -> - wait_for_sasl_request; - (NS == ?NS_CLIENT) and (Resource == <<"">>) -> - wait_for_bind; - true -> - session_established - end, - {noreply, State4#{stream_state => NewStreamState}}; - Err -> - Err - end; - Err -> - Err + if not Authenticated and (TLSRequired and not Encrypted) -> + State4#{stream_state => wait_for_starttls}; + not Authenticated -> + State4#{stream_state => wait_for_sasl_request}; + (NS == ?NS_CLIENT) and (Resource == <<"">>) -> + State4#{stream_state => wait_for_bind}; + true -> + process_stream_established(State4) + end + end end. +-spec process_element(xmpp_element(), state()) -> state(). process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> case Pkt of #starttls{} when StateName == wait_for_starttls; StateName == wait_for_sasl_request -> process_starttls(State); #starttls{} -> - send_element(State, #starttls_failure{}); + process_starttls_failure(unexpected_starttls_request, State); #sasl_auth{} when StateName == wait_for_starttls -> send_element(State, #sasl_failure{reason = 'encryption-required'}); #sasl_auth{} when StateName == wait_for_sasl_request -> @@ -356,7 +510,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> #sasl_abort{} -> send_element(State, #sasl_failure{reason = 'aborted'}); #sasl_success{} -> - {noreply, State}; + State; #compress{} when StateName == wait_for_sasl_response -> send_element(State, #compress_failure{reason = 'setup-failed'}); #compress{} -> @@ -364,7 +518,9 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> #handshake{} when StateName == wait_for_handshake -> process_handshake(Pkt, State); #handshake{} -> - {noreply, State}; + State; + #stream_error{} -> + process_stream_end({error, {stream, Pkt}}, State); _ when StateName == wait_for_sasl_request; StateName == wait_for_handshake; StateName == wait_for_sasl_response -> @@ -375,10 +531,11 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> send_error(State, Pkt, Err); _ when StateName == wait_for_bind -> process_bind(Pkt, State); - _ when StateName == session_established -> + _ when StateName == established -> process_authenticated_packet(Pkt, State) end. +-spec process_unauthenticated_packet(xmpp_element(), state()) -> state(). process_unauthenticated_packet(Pkt, #{mod := Mod} = State) -> NewPkt = set_lang(Pkt, State), try Mod:handle_unauthenticated_packet(NewPkt, State) @@ -387,6 +544,7 @@ process_unauthenticated_packet(Pkt, #{mod := Mod} = State) -> send_error(State, Pkt, Err) end. +-spec process_authenticated_packet(xmpp_element(), state()) -> state(). process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> Pkt1 = set_lang(Pkt, State), case set_from_to(Pkt1, State) of @@ -411,6 +569,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> send_element(State, Err) end. +-spec process_bind(xmpp_element(), state()) -> state(). process_bind(#iq{type = set, sub_els = [_]} = Pkt, #{xmlns := ?NS_CLIENT, mod := Mod, lang := Lang} = State) -> case xmpp:get_subtag(Pkt, #bind{}) of @@ -426,8 +585,8 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt, server := S, resource := NewR} = State1} when NewR /= <<"">> -> Reply = #bind{jid = jid:make(U, S, NewR)}, - State2 = State1#{stream_state => session_established}, - send_element(State2, xmpp:make_iq_result(Pkt, Reply)); + State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)), + process_stream_established(State2); {error, #stanza_error{}, State1} = Err -> send_error(State1, Pkt, Err) end @@ -446,16 +605,55 @@ process_bind(Pkt, #{mod := Mod} = State) -> send_error(State, Pkt, Err) end. -process_handshake(#handshake{data = Data}, #{mod := Mod} = State) -> - case Mod:handshake(Data, State) of - {ok, State1} -> - State2 = State1#{stream_state => session_established, - stream_authenticated => true}, - send_element(State2, #handshake{}); - {error, #stream_error{} = Err, State1} -> - send_element(State1, Err) +-spec process_handshake(handshake(), state()) -> state(). +process_handshake(#handshake{data = Digest}, + #{mod := Mod, stream_id := StreamID, + remote_server := RemoteServer} = State) -> + GetPW = try Mod:get_password_fun(State) + catch _:undef -> fun(_) -> {false, undefined} end + end, + AuthRes = case GetPW(<<"">>) of + {false, _} -> + false; + {Password, _} -> + p1_sha:sha(<<StreamID/binary, Password/binary>>) == Digest + end, + case AuthRes of + true -> + State1 = try Mod:handle_auth_success( + RemoteServer, <<"handshake">>, undefined, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + State2 = send_element(State1, #handshake{}), + process_stream_established(State2) + end; + false -> + State1 = try Mod:handle_auth_failure( + RemoteServer, <<"handshake">>, 'not-authorized', State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> send_element(State1, xmpp:serr_not_authorized()) + end + end. + +-spec process_stream_established(state()) -> state(). +process_stream_established(#{stream_state := StateName} = State) + when StateName == disconnected; StateName == established -> + State; +process_stream_established(#{mod := Mod} = State) -> + State1 = State#{stream_authenticated := true, + stream_state => established, + stream_timeout => infinity}, + try Mod:handle_stream_established(State1) + catch _:undef -> State1 end. +-spec process_compress(compress(), state()) -> state(). process_compress(#compress{}, #{stream_compressed := true} = State) -> send_element(State, #compress_failure{reason = 'setup-failed'}); process_compress(#compress{methods = HisMethods}, @@ -468,16 +666,17 @@ process_compress(#compress{methods = HisMethods}, true -> BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})), ZlibSocket = SockMod:compress(Socket, BCompressed), - State1 = State#{socket => ZlibSocket, - stream_id => new_id(), - stream_restarted => true, - stream_state => wait_for_stream, - stream_compressed => true}, - {noreply, State1}; + State#{socket => ZlibSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_compressed => true}; false -> send_element(State, #compress_failure{reason = 'unsupported-method'}) end. +-spec process_starttls(state()) -> state(). process_starttls(#{socket := Socket, sockmod := SockMod, mod := Mod} = State) -> TLSOpts = try Mod:tls_options(State) @@ -485,38 +684,69 @@ process_starttls(#{socket := Socket, end, case SockMod:starttls(Socket, TLSOpts) of {ok, TLSSocket} -> - case send_element(State, #starttls_proceed{}) of - {noreply, State1} -> - {noreply, State1#{socket => TLSSocket, - stream_id => new_id(), - stream_restarted => true, - stream_state => wait_for_stream, - stream_tlsed => true}}; - Err -> - Err + State1 = send_element(State, #starttls_proceed{}), + case is_disconnected(State1) of + true -> State1; + false -> + State1#{socket => TLSSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_encrypted => true} end; - {error, _Reason} -> - send_element(State, #starttls_failure{}) + {error, Reason} -> + process_starttls_failure(Reason, State) end. -process_sasl_request(#sasl_auth{mechanism = <<"EXTERNAL">>}, - #{stream_tlsed := false} = State) -> - process_sasl_failure('encryption-required', <<"">>, State); -process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, - #{mod := Mod} = State) -> - try Mod:init_sasl(State) of - SASLState -> - SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn), - process_sasl_result(SASLResult, State) - catch _:undef -> - process_sasl_failure('temporary-auth-failure', <<"">>, State) +-spec process_starttls_failure(term(), state()) -> state(). +process_starttls_failure(Why, State) -> + State1 = send_element(State, #starttls_failure{}), + case is_disconnected(State1) of + true -> State1; + false -> process_stream_end({error, {tls, Why}}, State1) end. +-spec process_sasl_request(sasl_auth(), state()) -> state(). +process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, + #{mod := Mod, lserver := LServer} = State) -> + GetPW = try Mod:get_password_fun(State) + catch _:undef -> fun(_) -> false end + end, + CheckPW = try Mod:check_password_fun(State) + catch _:undef -> fun(_, _, _) -> false end + end, + CheckPWDigest = try Mod:check_password_digest_fun(State) + catch _:undef -> fun(_, _, _, _, _) -> false end + end, + SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [], + GetPW, CheckPW, CheckPWDigest), + State1 = State#{sasl_state => SASLState, sasl_mech => Mech}, + Mechs = get_sasl_mechanisms(State1), + SASLResult = case lists:member(Mech, Mechs) of + true when Mech == <<"EXTERNAL">> -> + case xmpp_stream_pkix:authenticate(State1, ClientIn) of + {ok, Peer} -> + {ok, [{auth_module, pkix}, + {username, Peer}]}; + {error, _Reason, Peer} -> + %% TODO: return meaningful error + {error, 'not-authorized', Peer} + end; + true -> + cyrsasl:server_start(SASLState, Mech, ClientIn); + false -> + {error, 'invalid-mechanism'} + end, + process_sasl_result(SASLResult, State1). + +-spec process_sasl_response(sasl_response(), state()) -> state(). process_sasl_response(#sasl_response{text = ClientIn}, #{sasl_state := SASLState} = State) -> SASLResult = cyrsasl:server_step(SASLState, ClientIn), process_sasl_result(SASLResult, State). +-spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state(). process_sasl_result({ok, Props}, State) -> process_sasl_success(Props, <<"">>, State); process_sasl_result({ok, Props, ServerOut}, State) -> @@ -528,58 +758,59 @@ process_sasl_result({error, Reason, User}, State) -> process_sasl_result({error, Reason}, State) -> process_sasl_failure(Reason, <<"">>, State). +-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state(). process_sasl_success(Props, ServerOut, #{socket := Socket, sockmod := SockMod, - mod := Mod, sasl_state := SASLState} = State) -> - Mech = cyrsasl:get_mech(SASLState), + mod := Mod, sasl_mech := Mech} = State) -> User = identity(Props), AuthModule = proplists:get_value(auth_module, Props), - case try Mod:handle_auth_success(User, Mech, AuthModule, State) - catch _:undef -> {noreply, State} - end of - {noreply, State1} -> + State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> SockMod:reset_stream(Socket), - case send_element(State1, #sasl_success{text = ServerOut}) of - {noreply, State2} -> - State3 = maps:remove(sasl_state, State2), - {noreply, State3#{stream_id => new_id(), - stream_authenticated => true, - stream_restarted => true, - stream_state => wait_for_stream, - user => User}}; - Err -> - Err - end; - Err -> - Err + State2 = send_element(State1, #sasl_success{text = ServerOut}), + case is_disconnected(State2) of + true -> State2; + false -> + State3 = maps:remove(sasl_state, + maps:remove(sasl_mech, State2)), + State3#{stream_id => new_id(), + stream_authenticated => true, + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + user => User} + end end. +-spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state(). process_sasl_continue(ServerOut, NewSASLState, State) -> - send_element(State, #sasl_challenge{text = ServerOut}), - {noreply, State#{sasl_state => NewSASLState, - stream_state => wait_for_sasl_response}}. + State1 = State#{sasl_state => NewSASLState, + stream_state => wait_for_sasl_response}, + send_element(State1, #sasl_challenge{text = ServerOut}). +-spec process_sasl_failure(atom(), binary(), state()) -> state(). process_sasl_failure(Reason, User, - #{mod := Mod, sasl_state := SASLState} = State) -> - Mech = cyrsasl:get_mech(SASLState), - case try Mod:handle_auth_failure(User, Mech, Reason, State) - catch _:undef -> {noreply, State} - end of - {noreply, State1} -> - State2 = maps:remove(sasl_state, State1), - State3 = State2#{stream_state => wait_for_sasl_request}, - send_element(State3, #sasl_failure{reason = Reason}); - Err -> - Err - end. + #{mod := Mod, sasl_mech := Mech} = State) -> + State1 = try Mod:handle_auth_failure(User, Mech, Reason, State) + catch _:undef -> State + end, + State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)), + State3 = State2#{stream_state => wait_for_sasl_request}, + send_element(State3, #sasl_failure{reason = Reason}). +-spec process_sasl_abort(state()) -> state(). process_sasl_abort(State) -> process_sasl_failure('aborted', <<"">>, State). +-spec send_features(state()) -> state(). send_features(#{stream_version := {1,0}, - stream_tlsed := TLSEnabled} = State) -> + stream_encrypted := Encrypted} = State) -> TLSRequired = is_starttls_required(State), - Features = if TLSRequired and not TLSEnabled -> + Features = if TLSRequired and not Encrypted -> get_tls_feature(State); true -> get_sasl_feature(State) ++ get_compress_feature(State) @@ -588,26 +819,38 @@ send_features(#{stream_version := {1,0}, end, send_element(State, #stream_features{sub_els = Features}); send_features(State) -> - %% clients from stone age - {noreply, State}. + %% clients and servers from stone age + State. +-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()]. +get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod, + xmlns := NS, lserver := LServer} = State) -> + Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer); + true -> [] + end, + TLSVerify = try Mod:tls_verify(State) + catch _:undef -> false + end, + if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> + [<<"EXTERNAL">>|Mechs]; + true -> + Mechs + end. + +-spec get_sasl_feature(state()) -> [sasl_mechanisms()]. get_sasl_feature(#{stream_authenticated := false, - mod := Mod, - stream_tlsed := TLSEnabled} = State) -> + stream_encrypted := Encrypted} = State) -> TLSRequired = is_starttls_required(State), - if TLSEnabled or not TLSRequired -> - try Mod:sasl_mechanisms(State) of - [] -> []; - List -> [#sasl_mechanisms{list = List}] - catch _:undef -> - [] - end; + if Encrypted or not TLSRequired -> + Mechs = get_sasl_mechanisms(State), + [#sasl_mechanisms{list = Mechs}]; true -> [] end; get_sasl_feature(_) -> []. +-spec get_compress_feature(state()) -> [compression()]. get_compress_feature(#{stream_compressed := false, mod := Mod} = State) -> try Mod:compress_methods(State) of [] -> []; @@ -618,23 +861,31 @@ get_compress_feature(#{stream_compressed := false, mod := Mod} = State) -> get_compress_feature(_) -> []. +-spec get_tls_feature(state()) -> [starttls()]. get_tls_feature(#{stream_authenticated := false, - stream_tlsed := false} = State) -> + stream_encrypted := false} = State) -> TLSRequired = is_starttls_required(State), [#starttls{required = TLSRequired}]; get_tls_feature(_) -> []. -get_bind_feature(#{stream_authenticated := true, resource := <<"">>}) -> +-spec get_bind_feature(state()) -> [bind()]. +get_bind_feature(#{xmlns := ?NS_CLIENT, + stream_authenticated := true, + resource := <<"">>}) -> [#bind{}]; get_bind_feature(_) -> []. -get_session_feature(#{stream_authenticated := true, resource := <<"">>}) -> +-spec get_session_feature(state()) -> [xmpp_session()]. +get_session_feature(#{xmlns := ?NS_CLIENT, + stream_authenticated := true, + resource := <<"">>}) -> [#xmpp_session{optional = true}]; get_session_feature(_) -> []. +-spec get_other_features(state()) -> [xmpp_element()]. get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) -> try if Auth -> Mod:authenticated_stream_features(State); @@ -644,15 +895,18 @@ get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) -> [] end. +-spec is_starttls_required(state()) -> boolean(). is_starttls_required(#{mod := Mod} = State) -> try Mod:tls_required(State) catch _:undef -> false end. +-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} | + {error, stream_error()}. set_from_to(Pkt, _State) when not ?is_stanza(Pkt) -> {ok, Pkt}; set_from_to(Pkt, #{user := U, server := S, resource := R, - xmlns := ?NS_CLIENT}) -> + lang := Lang, xmlns := ?NS_CLIENT}) -> JID = jid:make(U, S, R), From = case xmpp:get_from(Pkt) of undefined -> JID; @@ -668,7 +922,8 @@ set_from_to(Pkt, #{user := U, server := S, resource := R, end, {ok, xmpp:set_from_to(Pkt, JID, To)}; true -> - {error, xmpp:serr_invalid_from()} + Txt = <<"Improper 'from' attribute">>, + {error, xmpp:serr_invalid_from(Txt, Lang)} end; set_from_to(Pkt, #{lang := Lang}) -> From = xmpp:get_from(Pkt), @@ -683,17 +938,22 @@ set_from_to(Pkt, #{lang := Lang}) -> {ok, Pkt} end. +-spec send_header(state()) -> state(). send_header(State) -> send_header(State, #stream_start{}). -send_header(#{stream_state := wait_for_stream, - stream_id := StreamID, +-spec send_header(state(), stream_start()) -> state(). +send_header(#{stream_id := StreamID, stream_version := MyVersion, + stream_header_sent := false, lang := MyLang, xmlns := NS, server := DefaultServer} = State, #stream_start{to = To, lang = HisLang, version = HisVersion}) -> - Lang = choose_lang(MyLang, HisLang), + Lang = select_lang(MyLang, HisLang), + NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; + true -> <<"">> + end, From = case To of #jid{} -> To; undefined -> jid:make(DefaultServer) @@ -706,63 +966,114 @@ send_header(#{stream_state := wait_for_stream, lang = Lang, xmlns = NS, stream_xmlns = ?NS_STREAM, + db_xmlns = NS_DB, id = StreamID, from = From}), - State1 = State#{lang => Lang}, + State1 = State#{lang => Lang, stream_header_sent => true}, case send_text(State1, fxml:element_to_header(Header)) of - ok -> {noreply, State1}; - {error, _} -> {stop, normal, State1} + ok -> State1; + {error, Why} -> process_stream_close({error, {socket, Why}}, State1) end; send_header(State, _) -> - {noreply, State}. + State. +-spec send_element(state(), xmpp_element()) -> state(). send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> El = xmpp:encode(Pkt, NS), Data = fxml:element_to_binary(El), - case send_text(State, Data) of - ok when is_record(Pkt, stream_error) -> - {stop, normal, State}; - ok when is_record(Pkt, starttls_failure) -> - {stop, normal, State}; - Res -> - try Mod:handle_send(Res, Pkt, El, Data, State) - catch _:undef when Res == ok -> - {noreply, State}; - _:undef -> - {stop, normal, State} - end + Result = send_text(State, Data), + State1 = try Mod:handle_send(Pkt, Result, State) + catch _:undef -> State + end, + case Result of + _ when is_record(Pkt, stream_error) -> + process_stream_end({error, {stream, Pkt}}, State1); + ok -> + State1; + {error, Why} -> + process_stream_close({error, {socket, Why}}, State1) end. -send_error(State, Pkt, Err) when ?is_stanza(Pkt) -> - case xmpp:get_type(Pkt) of - result -> {noreply, State}; - error -> {noreply, State}; - _ -> - ErrPkt = xmpp:make_error(Pkt, Err), - send_element(State, ErrPkt) - end; -send_error(State, _, _) -> - {noreply, State}. +-spec send_error(state(), xmpp_element(), stanza_error()) -> state(). +send_error(State, Pkt, Err) -> + case xmpp:is_stanza(Pkt) of + true -> + case xmpp:get_type(Pkt) of + result -> State; + error -> State; + <<"result">> -> State; + <<"error">> -> State; + _ -> + ErrPkt = xmpp:make_error(Pkt, Err), + send_element(State, ErrPkt) + end; + false -> + State + end. + +-spec send_trailer(state()) -> state(). +send_trailer(State) -> + send_text(State, <<"</stream:stream>">>), + close_socket(State). -send_text(#{socket := Sock, sockmod := SockMod}, Data) -> - SockMod:send(Sock, Data). +-spec send_text(state(), binary()) -> ok | {error, inet:posix()}. +send_text(#{socket := Sock, sockmod := SockMod, + stream_state := StateName, + stream_header_sent := true}, Data) when StateName /= disconnected -> + SockMod:send(Sock, Data); +send_text(_, _) -> + {error, einval}. -choose_lang(Lang, <<"">>) -> Lang; -choose_lang(_, Lang) -> Lang. +-spec close_socket(state()) -> state(). +close_socket(#{sockmod := SockMod, socket := Socket} = State) -> + SockMod:close(Socket), + State#{stream_timeout => infinity, + stream_state => disconnected}. +-spec select_lang(binary(), binary()) -> binary(). +select_lang(Lang, <<"">>) -> Lang; +select_lang(_, Lang) -> Lang. + +-spec set_lang(xmpp_element(), state()) -> xmpp_element(). set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) -> HisLang = xmpp:get_lang(Pkt), - Lang = choose_lang(MyLang, HisLang), + Lang = select_lang(MyLang, HisLang), xmpp:set_lang(Pkt, Lang); set_lang(Pkt, _) -> Pkt. +-spec format_inet_error(atom()) -> string(). +format_inet_error(Reason) -> + case inet:format_error(Reason) of + "unknown POSIX error" -> atom_to_list(Reason); + Txt -> Txt + end. + +-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string(). +format_stream_error(Reason, Txt) -> + Slogan = case Reason of + #'see-other-host'{} -> "see-other-host"; + _ -> atom_to_list(Reason) + end, + case Txt of + undefined -> Slogan; + #text{data = <<"">>} -> Slogan; + #text{data = Data} -> + binary_to_list(Data) ++ " (" ++ Slogan ++ ")" + end. + +-spec format(io:format(), list()) -> binary(). +format(Fmt, Args) -> + iolist_to_binary(io_lib:format(Fmt, Args)). + +-spec lists_intersection(list(), list()) -> list(). lists_intersection(L1, L2) -> lists:filter( fun(E) -> lists:member(E, L2) end, L1). +-spec identity([cyrsasl:sasl_property()]) -> binary(). identity(Props) -> case proplists:get_value(authzid, Props, <<>>) of <<>> -> proplists:get_value(username, Props, <<>>); |