summaryrefslogtreecommitdiff
path: root/src/xmpp_stream_out.erl
diff options
context:
space:
mode:
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>2016-12-30 00:00:36 +0300
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>2016-12-30 00:00:36 +0300
commite7fe4dc474ed180a4200b2bdefc2ff58b12340c0 (patch)
tree83a2ca201a8fae1f1ba49a2f2fb541ad64e02f91 /src/xmpp_stream_out.erl
parentAdd xmpp_stream_out behaviour and rewrite s2s/SM code (diff)
More refactoring on session management
Diffstat (limited to 'src/xmpp_stream_out.erl')
-rw-r--r--src/xmpp_stream_out.erl204
1 files changed, 110 insertions, 94 deletions
diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl
index fc373fff..08804e43 100644
--- a/src/xmpp_stream_out.erl
+++ b/src/xmpp_stream_out.erl
@@ -33,6 +33,7 @@
-include_lib("kernel/include/inet.hrl").
-type state() :: map().
+-type noreply() :: {noreply, state(), timeout()}.
-type host_port() :: {inet:hostname(), inet:port_number()}.
-type ip_port() :: {inet:ip_address(), inet:port_number()}.
-type network_error() :: {error, inet:posix() | inet_res:res_error()}.
@@ -42,7 +43,8 @@
{tls, term()} |
{pkix, binary()} |
{auth, atom() | binary() | string()} |
- {socket, inet:posix() | closed | timeout}.
+ {socket, inet:posix() | closed | timeout} |
+ internal_failure.
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
@@ -107,7 +109,7 @@ close(_, _) ->
establish(State) ->
process_stream_established(State).
--spec set_timeout(state(), non_neg_integer() | infinity) -> state().
+-spec set_timeout(state(), timeout()) -> state().
set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
case Timeout of
infinity -> State#{stream_timeout => infinity};
@@ -148,12 +150,15 @@ format_error({tls, Reason}) ->
format("TLS failed: ~w", [Reason]);
format_error({auth, Reason}) ->
format("Authentication failed: ~s", [Reason]);
+format_error(internal_failure) ->
+ <<"Internal server error">>;
format_error(Err) ->
format("Unrecognized error: ~w", [Err]).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
+-spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore.
init([Mod, SockMod, From, To, Opts]) ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State = #{owner => self(),
@@ -183,36 +188,38 @@ init([Mod, SockMod, From, To, Opts]) ->
Err
end.
+-spec handle_call(term(), term(), state()) -> noreply().
handle_call(Call, From, #{mod := Mod} = State) ->
noreply(try Mod:handle_call(Call, From, State)
catch _:undef -> State
end).
+-spec handle_cast(term(), state()) -> noreply().
handle_cast(connect, #{remote_server := RemoteServer,
sockmod := SockMod,
stream_state := connecting} = State) ->
- case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
- false ->
- noreply(process_stream_close({error, {idna, bad_string}}, State));
- ASCIIName ->
- case resolve(binary_to_list(ASCIIName), State) of
- {ok, AddrPorts} ->
- case connect(AddrPorts, State) of
- {ok, Socket, AddrPort} ->
- SocketMonitor = SockMod:monitor(Socket),
- State1 = State#{ip => AddrPort,
- socket => Socket,
- socket_monitor => SocketMonitor},
- State2 = State1#{stream_state => wait_for_stream},
- noreply(send_header(State2));
- {error, Why} ->
- Err = {error, {socket, Why}},
- noreply(process_stream_close(Err, State))
- end;
- {error, Why} ->
- noreply(process_stream_close({error, {dns, Why}}, State))
- end
- end;
+ noreply(
+ case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
+ false ->
+ process_stream_end({idna, bad_string}, State);
+ ASCIIName ->
+ case resolve(binary_to_list(ASCIIName), State) of
+ {ok, AddrPorts} ->
+ case connect(AddrPorts, State) of
+ {ok, Socket, AddrPort} ->
+ SocketMonitor = SockMod:monitor(Socket),
+ State1 = State#{ip => AddrPort,
+ socket => Socket,
+ socket_monitor => SocketMonitor},
+ State2 = State1#{stream_state => wait_for_stream},
+ send_header(State2);
+ {error, Why} ->
+ process_stream_end({socket, Why}, State)
+ end;
+ {error, Why} ->
+ process_stream_end({dns, Why}, State)
+ end
+ end);
handle_cast(connect, State) ->
%% Ignoring connection attempts in other states
noreply(State);
@@ -225,66 +232,70 @@ handle_cast(Cast, #{mod := Mod} = State) ->
catch _:undef -> State
end).
+-spec handle_info(term(), state()) -> noreply().
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
#{stream_state := wait_for_stream,
xmlns := XMLNS, lang := MyLang} = State) ->
El = #xmlel{name = Name, attrs = Attrs},
- try xmpp:decode(El, XMLNS, []) of
- #stream_start{} = Pkt ->
- noreply(process_stream(Pkt, State));
- _ ->
- noreply(send_element(State, xmpp:serr_invalid_xml()))
- catch _:{xmpp_codec, Why} ->
- Txt = xmpp:io_format_error(Why),
- Lang = select_lang(MyLang, xmpp:get_lang(El)),
- Err = xmpp:serr_invalid_xml(Txt, Lang),
- noreply(send_element(State, Err))
- end;
+ noreply(
+ try xmpp:decode(El, XMLNS, []) of
+ #stream_start{} = Pkt ->
+ process_stream(Pkt, State);
+ _ ->
+ send_element(State, xmpp:serr_invalid_xml())
+ catch _:{xmpp_codec, Why} ->
+ Txt = xmpp:io_format_error(Why),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ Err = xmpp:serr_invalid_xml(Txt, Lang),
+ send_element(State, Err)
+ end);
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
State1 = send_header(State),
- case is_disconnected(State1) of
- true -> State1;
- false ->
- Err = case Reason of
- <<"XML stanza is too big">> ->
- xmpp:serr_policy_violation(Reason, Lang);
- _ ->
- xmpp:serr_not_well_formed()
- end,
- noreply(send_element(State1, Err))
- end;
+ noreply(
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ Err = case Reason of
+ <<"XML stanza is too big">> ->
+ xmpp:serr_policy_violation(Reason, Lang);
+ _ ->
+ xmpp:serr_not_well_formed()
+ end,
+ send_element(State1, Err)
+ end);
handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
- try xmpp:decode(El, NS, [ignore_els]) of
- Pkt ->
- State1 = try Mod:handle_recv(El, Pkt, State)
- catch _:undef -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false -> noreply(process_element(Pkt, State1))
- end
- catch _:{xmpp_codec, Why} ->
- State1 = try Mod:handle_recv(El, undefined, State)
- catch _:undef -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false ->
- Txt = xmpp:io_format_error(Why),
- Lang = select_lang(MyLang, xmpp:get_lang(El)),
- noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
- end
- end;
+ noreply(
+ try xmpp:decode(El, NS, [ignore_els]) of
+ Pkt ->
+ State1 = try Mod:handle_recv(El, Pkt, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> process_element(Pkt, State1)
+ end
+ catch _:{xmpp_codec, Why} ->
+ State1 = try Mod:handle_recv(El, undefined, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ Txt = xmpp:io_format_error(Why),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
+ end
+ end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) ->
noreply(try Mod:handle_cdata(Data, State)
catch _:undef -> State
end);
handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
- noreply(process_stream_end({error, {stream, reset}}, State));
+ noreply(process_stream_end({stream, reset}, State));
handle_info({'$gen_event', closed}, State) ->
- noreply(process_stream_close({error, {socket, closed}}, State));
+ noreply(process_stream_end({socket, closed}, State));
handle_info(timeout, #{mod := Mod} = State) ->
Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State)
@@ -295,12 +306,13 @@ handle_info(timeout, #{mod := Mod} = State) ->
end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
#{socket_monitor := MRef} = State) ->
- noreply(process_stream_close({error, {socket, closed}}, State));
+ noreply(process_stream_end({socket, closed}, State));
handle_info(Info, #{mod := Mod} = State) ->
noreply(try Mod:handle_info(Info, State)
catch _:undef -> State
end).
+-spec terminate(term(), state()) -> any().
terminate(Reason, #{mod := Mod} = State) ->
case get(already_terminated) of
true ->
@@ -319,7 +331,7 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) ->
%%%===================================================================
%%% Internal functions
%%%===================================================================
--spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
+-spec noreply(state()) -> noreply().
noreply(#{stream_timeout := infinity} = State) ->
{noreply, State, infinity};
noreply(#{stream_timeout := {MSecs, OldTime}} = State) ->
@@ -335,15 +347,6 @@ new_id() ->
is_disconnected(#{stream_state := StreamState}) ->
StreamState == disconnected.
--spec process_stream_close(stop_reason(), state()) -> state().
-process_stream_close(_, #{stream_state := disconnected} = State) ->
- State;
-process_stream_close(Reason, #{mod := Mod} = State) ->
- State1 = send_trailer(State),
- try Mod:handle_stream_close(Reason, State1)
- catch _:undef -> stop(State1)
- end.
-
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
@@ -359,6 +362,8 @@ process_stream(#stream_start{xmlns = XML_NS,
#{xmlns := NS} = State)
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
send_element(State, xmpp:serr_invalid_namespace());
+process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
+ send_element(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang, id = ID,
version = Version} = StreamStart,
#{mod := Mod} = State) ->
@@ -370,8 +375,10 @@ process_stream(#stream_start{lang = Lang, id = ID,
true -> State2;
false ->
case Version of
- {1,0} -> State2#{stream_state => wait_for_features};
- _ -> process_stream_downgrade(StreamStart, State)
+ {1, _} ->
+ State2#{stream_state => wait_for_features};
+ _ ->
+ process_stream_downgrade(StreamStart, State2)
end
end.
@@ -387,7 +394,7 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
#sasl_failure{} when StateName == wait_for_sasl_response ->
process_sasl_failure(Pkt, State);
#stream_error{} ->
- process_stream_end({error, {stream, Pkt}}, State);
+ process_stream_end({stream, Pkt}, State);
_ when is_record(Pkt, stream_features);
is_record(Pkt, starttls_proceed);
is_record(Pkt, starttls);
@@ -487,14 +494,23 @@ process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) ->
stream_encrypted => true},
send_header(State1);
{error, Why} ->
- process_stream_close({error, {tls, Why}}, State)
+ process_stream_end({tls, Why}, State)
end.
-spec process_stream_downgrade(stream_start(), state()) -> state().
-process_stream_downgrade(StreamStart, #{mod := Mod} = State) ->
- try Mod:downgrade_stream(StreamStart, State)
- catch _:undef ->
- send_element(State, xmpp:serr_unsupported_version())
+process_stream_downgrade(StreamStart,
+ #{mod := Mod, lang := Lang,
+ stream_encrypted := Encrypted} = State) ->
+ TLSRequired = is_starttls_required(State),
+ if not Encrypted and TLSRequired ->
+ Txt = <<"Use of STARTTLS required">>,
+ send_element(State, xmpp:err_policy_violation(Txt, Lang));
+ true ->
+ State1 = State#{stream_state => downgraded},
+ try Mod:handle_stream_downgraded(StreamStart, State1)
+ catch _:undef ->
+ send_element(State1, xmpp:serr_unsupported_version())
+ end
end.
-spec process_cert_verification(state()) -> state().
@@ -509,7 +525,7 @@ process_cert_verification(#{stream_encrypted := true,
{ok, _} ->
State#{stream_verified => true};
{error, Why, _Peer} ->
- process_stream_close({error, {pkix, Why}}, State)
+ process_stream_end({pkix, Why}, State)
end;
false ->
State#{stream_verified => true}
@@ -538,7 +554,7 @@ process_sasl_success(#{mod := Mod,
-spec process_sasl_failure(sasl_failure(), state()) -> state().
process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) ->
try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State)
- catch _:undef -> process_stream_close({error, {auth, Reason}}, State)
+ catch _:undef -> process_stream_end({auth, Reason}, State)
end.
-spec process_packet(xmpp_element(), state()) -> state().
@@ -581,7 +597,7 @@ send_header(#{remote_server := RemoteServer,
version = {1,0}}),
case send_text(State, fxml:element_to_header(Header)) of
ok -> State;
- {error, Why} -> process_stream_close({error, {socket, Why}}, State)
+ {error, Why} -> process_stream_end({socket, Why}, State)
end.
-spec send_element(state(), xmpp_element()) -> state().
@@ -596,11 +612,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
false ->
case send_text(State1, Data) of
_ when is_record(Pkt, stream_error) ->
- process_stream_end({error, {stream, Pkt}}, State1);
+ process_stream_end({stream, Pkt}, State1);
ok ->
State1;
{error, Why} ->
- process_stream_close({error, {socket, Why}}, State1)
+ process_stream_end({socket, Why}, State1)
end
end.
@@ -626,7 +642,7 @@ send_text(#{sockmod := SockMod, socket := Socket,
stream_state := StateName}, Data) when StateName /= disconnected ->
SockMod:send(Socket, Data);
send_text(_, _) ->
- {error, einval}.
+ {error, closed}.
-spec send_trailer(state()) -> state().
send_trailer(State) ->