aboutsummaryrefslogtreecommitdiff
path: root/src/xmpp_stream_in.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/xmpp_stream_in.erl')
-rw-r--r--src/xmpp_stream_in.erl843
1 files changed, 577 insertions, 266 deletions
diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl
index 1307f9da4..e9c1b3339 100644
--- a/src/xmpp_stream_in.erl
+++ b/src/xmpp_stream_in.erl
@@ -25,58 +25,81 @@
-protocol({rfc, 6120}).
%% API
--export([start/3, call/3, cast/2, reply/2, send/2, send_error/3,
- get_transport/1, change_shaper/2]).
+-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1,
+ send/2, close/1, close/2, send_error/3, establish/1,
+ get_transport/1, change_shaper/2, set_timeout/2, format_error/1]).
%% gen_server callbacks
-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
terminate/2, code_change/3]).
+%%-define(DBGFSM, true).
+-ifdef(DBGFSM).
+-define(FSMOPTS, [{debug, [trace]}]).
+-else.
+-define(FSMOPTS, []).
+-endif.
+
-include("xmpp.hrl").
-type state() :: map().
--type next_state() :: {noreply, state()} | {stop, term(), state()}.
+-type stop_reason() :: {stream, reset | stream_error()} |
+ {tls, term()} |
+ {socket, inet:posix() | closed | timeout}.
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
--callback handle_stream_start(state()) -> next_state().
--callback handle_stream_end(state()) -> next_state().
--callback handle_stream_close(state()) -> next_state().
--callback handle_cdata(binary(), state()) -> next_state().
--callback handle_unauthenticated_packet(xmpp_element(), state()) -> next_state().
--callback handle_authenticated_packet(xmpp_element(), state()) -> next_state().
--callback handle_unbinded_packet(xmpp_element(), state()) -> next_state().
--callback handle_auth_success(binary(), binary(), module(), state()) -> next_state().
--callback handle_auth_failure(binary(), binary(), atom(), state()) -> next_state().
--callback handle_send(ok | {error, atom()},
- xmpp_element(), fxml:xmlel(), binary(), state()) -> next_state().
--callback init_sasl(state()) -> cyrsasl:sasl_state().
+-callback handle_cast(term(), state()) -> state().
+-callback handle_call(term(), term(), state()) -> state().
+-callback handle_info(term(), state()) -> state().
+-callback terminate(term(), state()) -> any().
+-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
+-callback handle_stream_start(state()) -> state().
+-callback handle_stream_end(stop_reason(), state()) -> state().
+-callback handle_stream_close(stop_reason(), state()) -> state().
+-callback handle_cdata(binary(), state()) -> state().
+-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
+-callback handle_authenticated_packet(xmpp_element(), state()) -> state().
+-callback handle_unbinded_packet(xmpp_element(), state()) -> state().
+-callback handle_auth_success(binary(), binary(), module(), state()) -> state().
+-callback handle_auth_failure(binary(), binary(), atom(), state()) -> state().
+-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
+-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
+-callback get_password_fun(state()) -> fun().
+-callback check_password_fun(state()) -> fun().
+-callback check_password_digest_fun(state()) -> fun().
-callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}.
--callback handshake(binary(), state()) -> {ok, state()} | {error, stream_error(), state()}.
-callback compress_methods(state()) -> [binary()].
-callback tls_options(state()) -> [proplists:property()].
-callback tls_required(state()) -> boolean().
--callback sasl_mechanisms(state()) -> [binary()].
+-callback tls_verify(state()) -> boolean().
-callback unauthenticated_stream_features(state()) -> [xmpp_element()].
-callback authenticated_stream_features(state()) -> [xmpp_element()].
%% All callbacks are optional
-optional_callbacks([init/1,
+ handle_cast/2,
+ handle_call/3,
+ handle_info/2,
+ terminate/2,
+ code_change/3,
handle_stream_start/1,
- handle_stream_end/1,
- handle_stream_close/1,
+ handle_stream_end/2,
+ handle_stream_close/2,
handle_cdata/2,
handle_authenticated_packet/2,
handle_unauthenticated_packet/2,
handle_unbinded_packet/2,
handle_auth_success/4,
handle_auth_failure/4,
- handle_send/5,
- init_sasl/1,
+ handle_send/3,
+ handle_recv/3,
+ get_password_fun/1,
+ check_password_fun/1,
+ check_password_digest_fun/1,
bind/2,
- handshake/2,
compress_methods/1,
tls_options/1,
tls_required/1,
- sasl_mechanisms/1,
+ tls_verify/1,
unauthenticated_stream_features/1,
authenticated_stream_features/1]).
@@ -84,7 +107,10 @@
%%% API
%%%===================================================================
start(Mod, Args, Opts) ->
- gen_server:start(?MODULE, [Mod|Args], Opts).
+ gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+
+start_link(Mod, Args, Opts) ->
+ gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
call(Ref, Msg, Timeout) ->
gen_server:call(Ref, Msg, Timeout).
@@ -95,16 +121,80 @@ cast(Ref, Msg) ->
reply(Ref, Reply) ->
gen_server:reply(Ref, Reply).
--spec send(state(), xmpp_element()) -> next_state().
-send(State, Pkt) ->
- send_element(State, Pkt).
+-spec stop(pid()) -> ok;
+ (state()) -> no_return().
+stop(Pid) when is_pid(Pid) ->
+ cast(Pid, stop);
+stop(#{owner := Owner} = State) when Owner == self() ->
+ terminate(normal, State),
+ exit(normal);
+stop(_) ->
+ erlang:error(badarg).
-get_transport(#{sockmod := SockMod, socket := Socket}) ->
- SockMod:get_transport(Socket).
+-spec send(pid(), xmpp_element()) -> ok;
+ (state(), xmpp_element()) -> state().
+send(Pid, Pkt) when is_pid(Pid) ->
+ cast(Pid, {send, Pkt});
+send(#{owner := Owner} = State, Pkt) when Owner == self() ->
+ send_element(State, Pkt);
+send(_, _) ->
+ erlang:error(badarg).
+
+-spec close(pid()) -> ok;
+ (state()) -> state().
+close(Ref) ->
+ close(Ref, true).
+
+-spec close(pid(), boolean()) -> ok;
+ (state(), boolean()) -> state().
+close(Pid, SendTrailer) when is_pid(Pid) ->
+ cast(Pid, {close, SendTrailer});
+close(#{owner := Owner} = State, SendTrailer) when Owner == self() ->
+ if SendTrailer -> send_trailer(State);
+ true -> close_socket(State)
+ end;
+close(_, _) ->
+ erlang:error(badarg).
+
+-spec establish(state()) -> state().
+establish(State) ->
+ process_stream_established(State).
+
+-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
+set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
+ case Timeout of
+ infinity -> State#{stream_timeout => infinity};
+ _ ->
+ Time = p1_time_compat:monotonic_time(milli_seconds),
+ State#{stream_timeout => {Timeout, Time}}
+ end;
+set_timeout(_, _) ->
+ erlang:error(badarg).
+
+get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner})
+ when Owner == self() ->
+ SockMod:get_transport(Socket);
+get_transport(_) ->
+ erlang:error(badarg).
-spec change_shaper(state(), shaper:shaper()) -> ok.
-change_shaper(#{sockmod := SockMod, socket := Socket}, Shaper) ->
- SockMod:change_shaper(Socket, Shaper).
+change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper)
+ when Owner == self() ->
+ SockMod:change_shaper(Socket, Shaper);
+change_shaper(_, _) ->
+ erlang:error(badarg).
+
+-spec format_error(stop_reason()) -> binary().
+format_error({socket, Reason}) ->
+ format("Connection failed: ~s", [format_inet_error(Reason)]);
+format_error({stream, reset}) ->
+ <<"Stream reset by peer">>;
+format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
+ format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({tls, Reason}) ->
+ format("TLS failed: ~w", [Reason]);
+format_error(Err) ->
+ format("Unrecognized error: ~w", [Err]).
%%%===================================================================
%%% gen_server callbacks
@@ -114,19 +204,24 @@ init([Module, {SockMod, Socket}, Opts]) ->
{_, XS} -> XS;
false -> false
end,
- TLSEnabled = proplists:get_bool(tls, Opts),
+ Encrypted = proplists:get_bool(tls, Opts),
SocketMonitor = SockMod:monitor(Socket),
case peername(SockMod, Socket) of
{ok, IP} ->
- State = #{mod => Module,
+ Time = p1_time_compat:monotonic_time(milli_seconds),
+ State = #{owner => self(),
+ mod => Module,
socket => Socket,
sockmod => SockMod,
socket_monitor => SocketMonitor,
+ stream_timeout => {timer:seconds(30), Time},
+ stream_direction => in,
stream_id => new_id(),
stream_state => wait_for_stream,
+ stream_header_sent => false,
stream_restarted => false,
stream_compressed => false,
- stream_tlsed => TLSEnabled,
+ stream_encrypted => Encrypted,
stream_version => {1,0},
stream_authenticated => false,
xml_socket => XMLSocket,
@@ -137,97 +232,133 @@ init([Module, {SockMod, Socket}, Opts]) ->
resource => <<"">>,
lserver => <<"">>,
ip => IP},
- try Module:init([State, Opts])
- catch _:undef -> {ok, State}
+ case try Module:init([State, Opts])
+ catch _:undef -> {ok, State}
+ end of
+ {ok, State1} ->
+ {_, State2, Timeout} = noreply(State1),
+ {ok, State2, Timeout};
+ Err ->
+ Err
end;
{error, Reason} ->
{stop, Reason}
end.
+handle_cast({send, Pkt}, State) ->
+ noreply(send_element(State, Pkt));
+handle_cast(stop, State) ->
+ {stop, normal, State};
handle_cast(Cast, #{mod := Mod} = State) ->
- try Mod:handle_cast(Cast, State)
- catch _:undef -> {noreply, State}
- end.
+ noreply(try Mod:handle_cast(Cast, State)
+ catch _:undef -> State
+ end).
handle_call(Call, From, #{mod := Mod} = State) ->
- try Mod:handle_call(Call, From, State)
- catch _:undef -> {reply, ok, State}
- end.
+ noreply(try Mod:handle_call(Call, From, State)
+ catch _:undef -> State
+ end).
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
- #{stream_state := wait_for_stream, xmlns := XMLNS} = State) ->
- try xmpp:decode(#xmlel{name = Name, attrs = Attrs}, XMLNS, []) of
+ #{stream_state := wait_for_stream,
+ xmlns := XMLNS, lang := MyLang} = State) ->
+ El = #xmlel{name = Name, attrs = Attrs},
+ try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt ->
- case send_header(State, Pkt) of
- {noreply, State1} ->
- process_stream(Pkt, State1);
- Err ->
- Err
+ State1 = send_header(State, Pkt),
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> noreply(process_stream(Pkt, State1))
end;
_ ->
- case send_header(State) of
- {noreply, State1} ->
- send_element(State1, xmpp:serr_invalid_xml());
- Err ->
- Err
+ State1 = send_header(State),
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> noreply(send_element(State1, xmpp:serr_invalid_xml()))
end
catch _:{xmpp_codec, Why} ->
- case send_header(State) of
- {noreply, State1} -> process_invalid_xml(Why, State1);
- Err -> Err
+ State1 = send_header(State),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ Txt = xmpp:io_format_error(Why),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ Err = xmpp:serr_invalid_xml(Txt, Lang),
+ noreply(send_element(State1, Err))
end
end;
-handle_info({'$gen_event', {xmlstreamend, _}}, #{mod := Mod} = State) ->
- try Mod:handle_stream_end(State)
- catch _:undef -> {stop, normal, State}
- end;
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
- case send_header(State) of
- {noreply, State1} ->
+ State1 = send_header(State),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
Err = case Reason of
<<"XML stanza is too big">> ->
xmpp:serr_policy_violation(Reason, Lang);
_ ->
xmpp:serr_not_well_formed()
end,
- send_element(State1, Err);
- Err ->
- Err
+ noreply(send_element(State1, Err))
end;
handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS} = State) ->
+ #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
- process_element(Pkt, State)
+ State1 = try Mod:handle_recv(El, Pkt, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> noreply(process_element(Pkt, State1))
+ end
catch _:{xmpp_codec, Why} ->
- process_invalid_xml(Why, State)
+ State1 = try Mod:handle_recv(El, {error, Why}, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ Txt = xmpp:io_format_error(Why),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
+ end
end;
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) ->
- try Mod:handle_cdata(Data, State)
- catch _:undef -> {noreply, State}
- end;
-handle_info(closed, #{mod := Mod} = State) ->
- try Mod:handle_stream_close(State)
- catch _:undef -> {stop, normal, State}
- end;
+ noreply(try Mod:handle_cdata(Data, State)
+ catch _:undef -> State
+ end);
+handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
+ noreply(process_stream_end({error, {stream, reset}}, State));
+handle_info({'$gen_event', closed}, State) ->
+ noreply(process_stream_close({error, {socket, closed}}, State));
+handle_info(timeout, #{mod := Mod} = State) ->
+ Disconnected = is_disconnected(State),
+ noreply(try Mod:handle_timeout(State)
+ catch _:undef when not Disconnected ->
+ send_element(State, xmpp:serr_connection_timeout());
+ _:undef ->
+ stop(State)
+ end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
- #{socket_monitor := MRef, mod := Mod} = State) ->
- try Mod:handle_stream_close(State)
- catch _:undef -> {stop, normal, State}
- end;
+ #{socket_monitor := MRef} = State) ->
+ noreply(process_stream_close({error, {socket, closed}}, State));
handle_info(Info, #{mod := Mod} = State) ->
- try Mod:handle_info(Info, State)
- catch _:undef -> {noreply, State}
- end.
+ noreply(try Mod:handle_info(Info, State)
+ catch _:undef -> State
+ end).
-terminate(Reason, #{mod := Mod, socket := Socket,
- sockmod := SockMod} = State) ->
- try Mod:terminate(Reason, State)
- catch _:undef -> ok
- end,
- send_text(State, <<"</stream:stream>">>),
- SockMod:close(Socket).
+terminate(Reason, #{mod := Mod} = State) ->
+ case get(already_terminated) of
+ true ->
+ State;
+ _ ->
+ put(already_terminated, true),
+ try Mod:terminate(Reason, State)
+ catch _:undef -> ok
+ end,
+ send_trailer(State)
+ end.
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
Mod:code_change(OldVsn, State, Extra).
@@ -235,20 +366,49 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) ->
%%%===================================================================
%%% Internal functions
%%%===================================================================
+-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
+noreply(#{stream_timeout := infinity} = State) ->
+ {noreply, State, infinity};
+noreply(#{stream_timeout := {MSecs, StartTime}} = State) ->
+ CurrentTime = p1_time_compat:monotonic_time(milli_seconds),
+ Timeout = max(0, MSecs - CurrentTime + StartTime),
+ {noreply, State, Timeout}.
+
-spec new_id() -> binary().
new_id() ->
randoms:get_string().
+-spec is_disconnected(state()) -> boolean().
+is_disconnected(#{stream_state := StreamState}) ->
+ StreamState == disconnected.
+
+-spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}|
+ {error, inet:posix()}.
peername(SockMod, Socket) ->
case SockMod of
gen_tcp -> inet:peername(Socket);
_ -> SockMod:peername(Socket)
end.
-process_invalid_xml(Reason, #{lang := Lang} = State) ->
- Txt = xmpp:io_format_error(Reason),
- send_element(State, xmpp:serr_invalid_xml(Txt, Lang)).
+-spec process_stream_close(stop_reason(), state()) -> state().
+process_stream_close(_, #{stream_state := disconnected} = State) ->
+ State;
+process_stream_close(Reason, #{mod := Mod} = State) ->
+ State1 = send_trailer(State),
+ try Mod:handle_stream_close(Reason, State1)
+ catch _:undef -> stop(State1)
+ end.
+
+-spec process_stream_end(stop_reason(), state()) -> state().
+process_stream_end(_, #{stream_state := disconnected} = State) ->
+ State;
+process_stream_end(Reason, #{mod := Mod} = State) ->
+ State1 = send_trailer(State),
+ try Mod:handle_stream_end(Reason, State1)
+ catch _:undef -> stop(State1)
+ end.
+-spec process_stream(stream_start(), state()) -> state().
process_stream(#stream_start{xmlns = XML_NS,
stream_xmlns = STREAM_NS},
#{xmlns := NS} = State)
@@ -268,73 +428,67 @@ process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) ->
send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
process_stream(#stream_start{from = undefined, version = {1,0}},
#{lang := Lang, xmlns := ?NS_SERVER,
- stream_tlsed := true} = State) ->
+ stream_encrypted := true} = State) ->
Txt = <<"Missing 'from' attribute">>,
send_element(State, xmpp:serr_invalid_from(Txt, Lang));
process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
#{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
Txt = <<"Improper 'to' attribute">>,
send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
-process_stream(#stream_start{to = #jid{lserver = RemoteServer}},
+process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
#{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
- State1 = State#{remote_server => RemoteServer},
- case try Mod:handle_stream_start(State1)
- catch _:undef -> {noreply, State1}
- end of
- {noreply, State2} ->
- {noreply, State2#{stream_state => wait_for_handshake}};
- Err ->
- Err
+ State1 = State#{remote_server => RemoteServer,
+ stream_state => wait_for_handshake},
+ try Mod:handle_stream_start(StreamStart, State1)
+ catch _:undef -> State1
end;
process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
- from = From},
+ from = From} = StreamStart,
#{stream_authenticated := Authenticated,
stream_restarted := StreamWasRestarted,
mod := Mod, xmlns := NS, resource := Resource,
- stream_tlsed := TLSEnabled} = State) ->
- case if not StreamWasRestarted ->
- State1 = State#{server => Server, lserver => LServer},
- try Mod:handle_stream_start(State1)
- catch _:undef -> {noreply, State1}
- end;
- true ->
- {noreply, State}
- end of
- {noreply, State2} ->
- State3 = if NS == ?NS_SERVER andalso TLSEnabled ->
- State2#{remote_server => From#jid.lserver};
- true ->
- State2
- end,
- case send_features(State3) of
- {noreply, State4} ->
+ stream_encrypted := Encrypted} = State) ->
+ State1 = if not StreamWasRestarted ->
+ State#{server => Server, lserver => LServer};
+ true ->
+ State
+ end,
+ State2 = if NS == ?NS_SERVER andalso Encrypted ->
+ State1#{remote_server => From#jid.lserver};
+ true ->
+ State1
+ end,
+ State3 = try Mod:handle_stream_start(StreamStart, State2)
+ catch _:undef -> State2
+ end,
+ case is_disconnected(State3) of
+ true -> State3;
+ false ->
+ State4 = send_features(State3),
+ case is_disconnected(State4) of
+ true -> State4;
+ false ->
TLSRequired = is_starttls_required(State4),
- NewStreamState =
- if not Authenticated and
- (not TLSEnabled and TLSRequired) ->
- wait_for_starttls;
- not Authenticated ->
- wait_for_sasl_request;
- (NS == ?NS_CLIENT) and (Resource == <<"">>) ->
- wait_for_bind;
- true ->
- session_established
- end,
- {noreply, State4#{stream_state => NewStreamState}};
- Err ->
- Err
- end;
- Err ->
- Err
+ if not Authenticated and (TLSRequired and not Encrypted) ->
+ State4#{stream_state => wait_for_starttls};
+ not Authenticated ->
+ State4#{stream_state => wait_for_sasl_request};
+ (NS == ?NS_CLIENT) and (Resource == <<"">>) ->
+ State4#{stream_state => wait_for_bind};
+ true ->
+ process_stream_established(State4)
+ end
+ end
end.
+-spec process_element(xmpp_element(), state()) -> state().
process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
case Pkt of
#starttls{} when StateName == wait_for_starttls;
StateName == wait_for_sasl_request ->
process_starttls(State);
#starttls{} ->
- send_element(State, #starttls_failure{});
+ process_starttls_failure(unexpected_starttls_request, State);
#sasl_auth{} when StateName == wait_for_starttls ->
send_element(State, #sasl_failure{reason = 'encryption-required'});
#sasl_auth{} when StateName == wait_for_sasl_request ->
@@ -356,7 +510,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
#sasl_abort{} ->
send_element(State, #sasl_failure{reason = 'aborted'});
#sasl_success{} ->
- {noreply, State};
+ State;
#compress{} when StateName == wait_for_sasl_response ->
send_element(State, #compress_failure{reason = 'setup-failed'});
#compress{} ->
@@ -364,7 +518,9 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
#handshake{} when StateName == wait_for_handshake ->
process_handshake(Pkt, State);
#handshake{} ->
- {noreply, State};
+ State;
+ #stream_error{} ->
+ process_stream_end({error, {stream, Pkt}}, State);
_ when StateName == wait_for_sasl_request;
StateName == wait_for_handshake;
StateName == wait_for_sasl_response ->
@@ -375,10 +531,11 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
send_error(State, Pkt, Err);
_ when StateName == wait_for_bind ->
process_bind(Pkt, State);
- _ when StateName == session_established ->
+ _ when StateName == established ->
process_authenticated_packet(Pkt, State)
end.
+-spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
NewPkt = set_lang(Pkt, State),
try Mod:handle_unauthenticated_packet(NewPkt, State)
@@ -387,6 +544,7 @@ process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
send_error(State, Pkt, Err)
end.
+-spec process_authenticated_packet(xmpp_element(), state()) -> state().
process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
Pkt1 = set_lang(Pkt, State),
case set_from_to(Pkt1, State) of
@@ -411,6 +569,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
send_element(State, Err)
end.
+-spec process_bind(xmpp_element(), state()) -> state().
process_bind(#iq{type = set, sub_els = [_]} = Pkt,
#{xmlns := ?NS_CLIENT, mod := Mod, lang := Lang} = State) ->
case xmpp:get_subtag(Pkt, #bind{}) of
@@ -426,8 +585,8 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
server := S,
resource := NewR} = State1} when NewR /= <<"">> ->
Reply = #bind{jid = jid:make(U, S, NewR)},
- State2 = State1#{stream_state => session_established},
- send_element(State2, xmpp:make_iq_result(Pkt, Reply));
+ State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)),
+ process_stream_established(State2);
{error, #stanza_error{}, State1} = Err ->
send_error(State1, Pkt, Err)
end
@@ -446,16 +605,55 @@ process_bind(Pkt, #{mod := Mod} = State) ->
send_error(State, Pkt, Err)
end.
-process_handshake(#handshake{data = Data}, #{mod := Mod} = State) ->
- case Mod:handshake(Data, State) of
- {ok, State1} ->
- State2 = State1#{stream_state => session_established,
- stream_authenticated => true},
- send_element(State2, #handshake{});
- {error, #stream_error{} = Err, State1} ->
- send_element(State1, Err)
+-spec process_handshake(handshake(), state()) -> state().
+process_handshake(#handshake{data = Digest},
+ #{mod := Mod, stream_id := StreamID,
+ remote_server := RemoteServer} = State) ->
+ GetPW = try Mod:get_password_fun(State)
+ catch _:undef -> fun(_) -> {false, undefined} end
+ end,
+ AuthRes = case GetPW(<<"">>) of
+ {false, _} ->
+ false;
+ {Password, _} ->
+ p1_sha:sha(<<StreamID/binary, Password/binary>>) == Digest
+ end,
+ case AuthRes of
+ true ->
+ State1 = try Mod:handle_auth_success(
+ RemoteServer, <<"handshake">>, undefined, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ State2 = send_element(State1, #handshake{}),
+ process_stream_established(State2)
+ end;
+ false ->
+ State1 = try Mod:handle_auth_failure(
+ RemoteServer, <<"handshake">>, 'not-authorized', State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> send_element(State1, xmpp:serr_not_authorized())
+ end
+ end.
+
+-spec process_stream_established(state()) -> state().
+process_stream_established(#{stream_state := StateName} = State)
+ when StateName == disconnected; StateName == established ->
+ State;
+process_stream_established(#{mod := Mod} = State) ->
+ State1 = State#{stream_authenticated := true,
+ stream_state => established,
+ stream_timeout => infinity},
+ try Mod:handle_stream_established(State1)
+ catch _:undef -> State1
end.
+-spec process_compress(compress(), state()) -> state().
process_compress(#compress{}, #{stream_compressed := true} = State) ->
send_element(State, #compress_failure{reason = 'setup-failed'});
process_compress(#compress{methods = HisMethods},
@@ -468,16 +666,17 @@ process_compress(#compress{methods = HisMethods},
true ->
BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})),
ZlibSocket = SockMod:compress(Socket, BCompressed),
- State1 = State#{socket => ZlibSocket,
- stream_id => new_id(),
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_compressed => true},
- {noreply, State1};
+ State#{socket => ZlibSocket,
+ stream_id => new_id(),
+ stream_header_sent => false,
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ stream_compressed => true};
false ->
send_element(State, #compress_failure{reason = 'unsupported-method'})
end.
+-spec process_starttls(state()) -> state().
process_starttls(#{socket := Socket,
sockmod := SockMod, mod := Mod} = State) ->
TLSOpts = try Mod:tls_options(State)
@@ -485,38 +684,69 @@ process_starttls(#{socket := Socket,
end,
case SockMod:starttls(Socket, TLSOpts) of
{ok, TLSSocket} ->
- case send_element(State, #starttls_proceed{}) of
- {noreply, State1} ->
- {noreply, State1#{socket => TLSSocket,
- stream_id => new_id(),
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_tlsed => true}};
- Err ->
- Err
+ State1 = send_element(State, #starttls_proceed{}),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ State1#{socket => TLSSocket,
+ stream_id => new_id(),
+ stream_header_sent => false,
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ stream_encrypted => true}
end;
- {error, _Reason} ->
- send_element(State, #starttls_failure{})
+ {error, Reason} ->
+ process_starttls_failure(Reason, State)
end.
-process_sasl_request(#sasl_auth{mechanism = <<"EXTERNAL">>},
- #{stream_tlsed := false} = State) ->
- process_sasl_failure('encryption-required', <<"">>, State);
-process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
- #{mod := Mod} = State) ->
- try Mod:init_sasl(State) of
- SASLState ->
- SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn),
- process_sasl_result(SASLResult, State)
- catch _:undef ->
- process_sasl_failure('temporary-auth-failure', <<"">>, State)
+-spec process_starttls_failure(term(), state()) -> state().
+process_starttls_failure(Why, State) ->
+ State1 = send_element(State, #starttls_failure{}),
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> process_stream_end({error, {tls, Why}}, State1)
end.
+-spec process_sasl_request(sasl_auth(), state()) -> state().
+process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
+ #{mod := Mod, lserver := LServer} = State) ->
+ GetPW = try Mod:get_password_fun(State)
+ catch _:undef -> fun(_) -> false end
+ end,
+ CheckPW = try Mod:check_password_fun(State)
+ catch _:undef -> fun(_, _, _) -> false end
+ end,
+ CheckPWDigest = try Mod:check_password_digest_fun(State)
+ catch _:undef -> fun(_, _, _, _, _) -> false end
+ end,
+ SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
+ GetPW, CheckPW, CheckPWDigest),
+ State1 = State#{sasl_state => SASLState, sasl_mech => Mech},
+ Mechs = get_sasl_mechanisms(State1),
+ SASLResult = case lists:member(Mech, Mechs) of
+ true when Mech == <<"EXTERNAL">> ->
+ case xmpp_stream_pkix:authenticate(State1, ClientIn) of
+ {ok, Peer} ->
+ {ok, [{auth_module, pkix},
+ {username, Peer}]};
+ {error, _Reason, Peer} ->
+ %% TODO: return meaningful error
+ {error, 'not-authorized', Peer}
+ end;
+ true ->
+ cyrsasl:server_start(SASLState, Mech, ClientIn);
+ false ->
+ {error, 'invalid-mechanism'}
+ end,
+ process_sasl_result(SASLResult, State1).
+
+-spec process_sasl_response(sasl_response(), state()) -> state().
process_sasl_response(#sasl_response{text = ClientIn},
#{sasl_state := SASLState} = State) ->
SASLResult = cyrsasl:server_step(SASLState, ClientIn),
process_sasl_result(SASLResult, State).
+-spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state().
process_sasl_result({ok, Props}, State) ->
process_sasl_success(Props, <<"">>, State);
process_sasl_result({ok, Props, ServerOut}, State) ->
@@ -528,58 +758,59 @@ process_sasl_result({error, Reason, User}, State) ->
process_sasl_result({error, Reason}, State) ->
process_sasl_failure(Reason, <<"">>, State).
+-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
process_sasl_success(Props, ServerOut,
#{socket := Socket, sockmod := SockMod,
- mod := Mod, sasl_state := SASLState} = State) ->
- Mech = cyrsasl:get_mech(SASLState),
+ mod := Mod, sasl_mech := Mech} = State) ->
User = identity(Props),
AuthModule = proplists:get_value(auth_module, Props),
- case try Mod:handle_auth_success(User, Mech, AuthModule, State)
- catch _:undef -> {noreply, State}
- end of
- {noreply, State1} ->
+ State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
SockMod:reset_stream(Socket),
- case send_element(State1, #sasl_success{text = ServerOut}) of
- {noreply, State2} ->
- State3 = maps:remove(sasl_state, State2),
- {noreply, State3#{stream_id => new_id(),
- stream_authenticated => true,
- stream_restarted => true,
- stream_state => wait_for_stream,
- user => User}};
- Err ->
- Err
- end;
- Err ->
- Err
+ State2 = send_element(State1, #sasl_success{text = ServerOut}),
+ case is_disconnected(State2) of
+ true -> State2;
+ false ->
+ State3 = maps:remove(sasl_state,
+ maps:remove(sasl_mech, State2)),
+ State3#{stream_id => new_id(),
+ stream_authenticated => true,
+ stream_header_sent => false,
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ user => User}
+ end
end.
+-spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state().
process_sasl_continue(ServerOut, NewSASLState, State) ->
- send_element(State, #sasl_challenge{text = ServerOut}),
- {noreply, State#{sasl_state => NewSASLState,
- stream_state => wait_for_sasl_response}}.
+ State1 = State#{sasl_state => NewSASLState,
+ stream_state => wait_for_sasl_response},
+ send_element(State1, #sasl_challenge{text = ServerOut}).
+-spec process_sasl_failure(atom(), binary(), state()) -> state().
process_sasl_failure(Reason, User,
- #{mod := Mod, sasl_state := SASLState} = State) ->
- Mech = cyrsasl:get_mech(SASLState),
- case try Mod:handle_auth_failure(User, Mech, Reason, State)
- catch _:undef -> {noreply, State}
- end of
- {noreply, State1} ->
- State2 = maps:remove(sasl_state, State1),
- State3 = State2#{stream_state => wait_for_sasl_request},
- send_element(State3, #sasl_failure{reason = Reason});
- Err ->
- Err
- end.
+ #{mod := Mod, sasl_mech := Mech} = State) ->
+ State1 = try Mod:handle_auth_failure(User, Mech, Reason, State)
+ catch _:undef -> State
+ end,
+ State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)),
+ State3 = State2#{stream_state => wait_for_sasl_request},
+ send_element(State3, #sasl_failure{reason = Reason}).
+-spec process_sasl_abort(state()) -> state().
process_sasl_abort(State) ->
process_sasl_failure('aborted', <<"">>, State).
+-spec send_features(state()) -> state().
send_features(#{stream_version := {1,0},
- stream_tlsed := TLSEnabled} = State) ->
+ stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
- Features = if TLSRequired and not TLSEnabled ->
+ Features = if TLSRequired and not Encrypted ->
get_tls_feature(State);
true ->
get_sasl_feature(State) ++ get_compress_feature(State)
@@ -588,26 +819,38 @@ send_features(#{stream_version := {1,0},
end,
send_element(State, #stream_features{sub_els = Features});
send_features(State) ->
- %% clients from stone age
- {noreply, State}.
+ %% clients and servers from stone age
+ State.
+-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
+get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
+ xmlns := NS, lserver := LServer} = State) ->
+ Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer);
+ true -> []
+ end,
+ TLSVerify = try Mod:tls_verify(State)
+ catch _:undef -> false
+ end,
+ if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
+ [<<"EXTERNAL">>|Mechs];
+ true ->
+ Mechs
+ end.
+
+-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
get_sasl_feature(#{stream_authenticated := false,
- mod := Mod,
- stream_tlsed := TLSEnabled} = State) ->
+ stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
- if TLSEnabled or not TLSRequired ->
- try Mod:sasl_mechanisms(State) of
- [] -> [];
- List -> [#sasl_mechanisms{list = List}]
- catch _:undef ->
- []
- end;
+ if Encrypted or not TLSRequired ->
+ Mechs = get_sasl_mechanisms(State),
+ [#sasl_mechanisms{list = Mechs}];
true ->
[]
end;
get_sasl_feature(_) ->
[].
+-spec get_compress_feature(state()) -> [compression()].
get_compress_feature(#{stream_compressed := false, mod := Mod} = State) ->
try Mod:compress_methods(State) of
[] -> [];
@@ -618,23 +861,31 @@ get_compress_feature(#{stream_compressed := false, mod := Mod} = State) ->
get_compress_feature(_) ->
[].
+-spec get_tls_feature(state()) -> [starttls()].
get_tls_feature(#{stream_authenticated := false,
- stream_tlsed := false} = State) ->
+ stream_encrypted := false} = State) ->
TLSRequired = is_starttls_required(State),
[#starttls{required = TLSRequired}];
get_tls_feature(_) ->
[].
-get_bind_feature(#{stream_authenticated := true, resource := <<"">>}) ->
+-spec get_bind_feature(state()) -> [bind()].
+get_bind_feature(#{xmlns := ?NS_CLIENT,
+ stream_authenticated := true,
+ resource := <<"">>}) ->
[#bind{}];
get_bind_feature(_) ->
[].
-get_session_feature(#{stream_authenticated := true, resource := <<"">>}) ->
+-spec get_session_feature(state()) -> [xmpp_session()].
+get_session_feature(#{xmlns := ?NS_CLIENT,
+ stream_authenticated := true,
+ resource := <<"">>}) ->
[#xmpp_session{optional = true}];
get_session_feature(_) ->
[].
+-spec get_other_features(state()) -> [xmpp_element()].
get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
try
if Auth -> Mod:authenticated_stream_features(State);
@@ -644,15 +895,18 @@ get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
[]
end.
+-spec is_starttls_required(state()) -> boolean().
is_starttls_required(#{mod := Mod} = State) ->
try Mod:tls_required(State)
catch _:undef -> false
end.
+-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
+ {error, stream_error()}.
set_from_to(Pkt, _State) when not ?is_stanza(Pkt) ->
{ok, Pkt};
set_from_to(Pkt, #{user := U, server := S, resource := R,
- xmlns := ?NS_CLIENT}) ->
+ lang := Lang, xmlns := ?NS_CLIENT}) ->
JID = jid:make(U, S, R),
From = case xmpp:get_from(Pkt) of
undefined -> JID;
@@ -668,7 +922,8 @@ set_from_to(Pkt, #{user := U, server := S, resource := R,
end,
{ok, xmpp:set_from_to(Pkt, JID, To)};
true ->
- {error, xmpp:serr_invalid_from()}
+ Txt = <<"Improper 'from' attribute">>,
+ {error, xmpp:serr_invalid_from(Txt, Lang)}
end;
set_from_to(Pkt, #{lang := Lang}) ->
From = xmpp:get_from(Pkt),
@@ -683,17 +938,22 @@ set_from_to(Pkt, #{lang := Lang}) ->
{ok, Pkt}
end.
+-spec send_header(state()) -> state().
send_header(State) ->
send_header(State, #stream_start{}).
-send_header(#{stream_state := wait_for_stream,
- stream_id := StreamID,
+-spec send_header(state(), stream_start()) -> state().
+send_header(#{stream_id := StreamID,
stream_version := MyVersion,
+ stream_header_sent := false,
lang := MyLang,
xmlns := NS,
server := DefaultServer} = State,
#stream_start{to = To, lang = HisLang, version = HisVersion}) ->
- Lang = choose_lang(MyLang, HisLang),
+ Lang = select_lang(MyLang, HisLang),
+ NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
+ true -> <<"">>
+ end,
From = case To of
#jid{} -> To;
undefined -> jid:make(DefaultServer)
@@ -706,63 +966,114 @@ send_header(#{stream_state := wait_for_stream,
lang = Lang,
xmlns = NS,
stream_xmlns = ?NS_STREAM,
+ db_xmlns = NS_DB,
id = StreamID,
from = From}),
- State1 = State#{lang => Lang},
+ State1 = State#{lang => Lang, stream_header_sent => true},
case send_text(State1, fxml:element_to_header(Header)) of
- ok -> {noreply, State1};
- {error, _} -> {stop, normal, State1}
+ ok -> State1;
+ {error, Why} -> process_stream_close({error, {socket, Why}}, State1)
end;
send_header(State, _) ->
- {noreply, State}.
+ State.
+-spec send_element(state(), xmpp_element()) -> state().
send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
El = xmpp:encode(Pkt, NS),
Data = fxml:element_to_binary(El),
- case send_text(State, Data) of
- ok when is_record(Pkt, stream_error) ->
- {stop, normal, State};
- ok when is_record(Pkt, starttls_failure) ->
- {stop, normal, State};
- Res ->
- try Mod:handle_send(Res, Pkt, El, Data, State)
- catch _:undef when Res == ok ->
- {noreply, State};
- _:undef ->
- {stop, normal, State}
- end
+ Result = send_text(State, Data),
+ State1 = try Mod:handle_send(Pkt, Result, State)
+ catch _:undef -> State
+ end,
+ case Result of
+ _ when is_record(Pkt, stream_error) ->
+ process_stream_end({error, {stream, Pkt}}, State1);
+ ok ->
+ State1;
+ {error, Why} ->
+ process_stream_close({error, {socket, Why}}, State1)
end.
-send_error(State, Pkt, Err) when ?is_stanza(Pkt) ->
- case xmpp:get_type(Pkt) of
- result -> {noreply, State};
- error -> {noreply, State};
- _ ->
- ErrPkt = xmpp:make_error(Pkt, Err),
- send_element(State, ErrPkt)
- end;
-send_error(State, _, _) ->
- {noreply, State}.
+-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
+send_error(State, Pkt, Err) ->
+ case xmpp:is_stanza(Pkt) of
+ true ->
+ case xmpp:get_type(Pkt) of
+ result -> State;
+ error -> State;
+ <<"result">> -> State;
+ <<"error">> -> State;
+ _ ->
+ ErrPkt = xmpp:make_error(Pkt, Err),
+ send_element(State, ErrPkt)
+ end;
+ false ->
+ State
+ end.
+
+-spec send_trailer(state()) -> state().
+send_trailer(State) ->
+ send_text(State, <<"</stream:stream>">>),
+ close_socket(State).
-send_text(#{socket := Sock, sockmod := SockMod}, Data) ->
- SockMod:send(Sock, Data).
+-spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
+send_text(#{socket := Sock, sockmod := SockMod,
+ stream_state := StateName,
+ stream_header_sent := true}, Data) when StateName /= disconnected ->
+ SockMod:send(Sock, Data);
+send_text(_, _) ->
+ {error, einval}.
-choose_lang(Lang, <<"">>) -> Lang;
-choose_lang(_, Lang) -> Lang.
+-spec close_socket(state()) -> state().
+close_socket(#{sockmod := SockMod, socket := Socket} = State) ->
+ SockMod:close(Socket),
+ State#{stream_timeout => infinity,
+ stream_state => disconnected}.
+-spec select_lang(binary(), binary()) -> binary().
+select_lang(Lang, <<"">>) -> Lang;
+select_lang(_, Lang) -> Lang.
+
+-spec set_lang(xmpp_element(), state()) -> xmpp_element().
set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) ->
HisLang = xmpp:get_lang(Pkt),
- Lang = choose_lang(MyLang, HisLang),
+ Lang = select_lang(MyLang, HisLang),
xmpp:set_lang(Pkt, Lang);
set_lang(Pkt, _) ->
Pkt.
+-spec format_inet_error(atom()) -> string().
+format_inet_error(Reason) ->
+ case inet:format_error(Reason) of
+ "unknown POSIX error" -> atom_to_list(Reason);
+ Txt -> Txt
+ end.
+
+-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
+format_stream_error(Reason, Txt) ->
+ Slogan = case Reason of
+ #'see-other-host'{} -> "see-other-host";
+ _ -> atom_to_list(Reason)
+ end,
+ case Txt of
+ undefined -> Slogan;
+ #text{data = <<"">>} -> Slogan;
+ #text{data = Data} ->
+ binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
+ end.
+
+-spec format(io:format(), list()) -> binary().
+format(Fmt, Args) ->
+ iolist_to_binary(io_lib:format(Fmt, Args)).
+
+-spec lists_intersection(list(), list()) -> list().
lists_intersection(L1, L2) ->
lists:filter(
fun(E) ->
lists:member(E, L2)
end, L1).
+-spec identity([cyrsasl:sasl_property()]) -> binary().
identity(Props) ->
case proplists:get_value(authzid, Props, <<>>) of
<<>> -> proplists:get_value(username, Props, <<>>);