summaryrefslogtreecommitdiff
path: root/src/xmpp_stream_in.erl
diff options
context:
space:
mode:
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>2017-01-09 17:02:17 +0300
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>2017-01-09 17:02:17 +0300
commit1e55e018e534aa82541c5f460063a237192b768c (patch)
tree9584ed46fe2b18770343399254b0ba15ff591e51 /src/xmpp_stream_in.erl
parentGet rid of "jlib.hrl" header in some files (diff)
Adopt remaining code to support new hooks
Diffstat (limited to 'src/xmpp_stream_in.erl')
-rw-r--r--src/xmpp_stream_in.erl333
1 files changed, 194 insertions, 139 deletions
diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl
index 1ad78d45..b2b3b307 100644
--- a/src/xmpp_stream_in.erl
+++ b/src/xmpp_stream_in.erl
@@ -20,9 +20,11 @@
%%%
%%%-------------------------------------------------------------------
-module(xmpp_stream_in).
--behaviour(gen_server).
+-define(GEN_SERVER, gen_server).
+-behaviour(?GEN_SERVER).
-protocol({rfc, 6120}).
+-protocol({xep, 114, '1.6'}).
%% API
-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1,
@@ -43,17 +45,18 @@
-include("xmpp.hrl").
-type state() :: map().
-type stop_reason() :: {stream, reset | {in | out, stream_error()}} |
- {tls, term()} |
+ {tls, inet:posix() | atom() | binary()} |
{socket, inet:posix() | closed | timeout} |
internal_failure.
--callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
+-callback init(list()) -> {ok, state()} | {error, term()} | ignore.
-callback handle_cast(term(), state()) -> state().
-callback handle_call(term(), term(), state()) -> state().
-callback handle_info(term(), state()) -> state().
-callback terminate(term(), state()) -> any().
-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
--callback handle_stream_start(state()) -> state().
+-callback handle_stream_start(stream_start(), state()) -> state().
+-callback handle_stream_established(state()) -> state().
-callback handle_stream_end(stop_reason(), state()) -> state().
-callback handle_cdata(binary(), state()) -> state().
-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
@@ -63,6 +66,7 @@
-callback handle_auth_failure(binary(), binary(), atom(), state()) -> state().
-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
+-callback handle_timeout(state()) -> state().
-callback get_password_fun(state()) -> fun().
-callback check_password_fun(state()) -> fun().
-callback check_password_digest_fun(state()) -> fun().
@@ -71,6 +75,8 @@
-callback tls_options(state()) -> [proplists:property()].
-callback tls_required(state()) -> boolean().
-callback tls_verify(state()) -> boolean().
+-callback tls_enabled(state()) -> boolean().
+-callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()].
-callback unauthenticated_stream_features(state()) -> [xmpp_element()].
-callback authenticated_stream_features(state()) -> [xmpp_element()].
@@ -81,7 +87,8 @@
handle_info/2,
terminate/2,
code_change/3,
- handle_stream_start/1,
+ handle_stream_start/2,
+ handle_stream_established/1,
handle_stream_end/2,
handle_cdata/2,
handle_authenticated_packet/2,
@@ -91,6 +98,7 @@
handle_auth_failure/4,
handle_send/3,
handle_recv/3,
+ handle_timeout/1,
get_password_fun/1,
check_password_fun/1,
check_password_digest_fun/1,
@@ -99,6 +107,8 @@
tls_options/1,
tls_required/1,
tls_verify/1,
+ tls_enabled/1,
+ sasl_mechanisms/2,
unauthenticated_stream_features/1,
authenticated_stream_features/1]).
@@ -106,19 +116,19 @@
%%% API
%%%===================================================================
start(Mod, Args, Opts) ->
- gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+ ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
start_link(Mod, Args, Opts) ->
- gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+ ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
call(Ref, Msg, Timeout) ->
- gen_server:call(Ref, Msg, Timeout).
+ ?GEN_SERVER:call(Ref, Msg, Timeout).
cast(Ref, Msg) ->
- gen_server:cast(Ref, Msg).
+ ?GEN_SERVER:cast(Ref, Msg).
reply(Ref, Reply) ->
- gen_server:reply(Ref, Reply).
+ ?GEN_SERVER:reply(Ref, Reply).
-spec stop(pid()) -> ok;
(state()) -> no_return().
@@ -135,7 +145,7 @@ stop(_) ->
send(Pid, Pkt) when is_pid(Pid) ->
cast(Pid, {send, Pkt});
send(#{owner := Owner} = State, Pkt) when Owner == self() ->
- send_element(State, Pkt);
+ send_pkt(State, Pkt);
send(_, _) ->
erlang:error(badarg).
@@ -193,7 +203,7 @@ format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
format_error({tls, Reason}) ->
- format("TLS failed: ~w", [Reason]);
+ format("TLS failed: ~s", [format_tls_error(Reason)]);
format_error(internal_failure) ->
<<"Internal server error">>;
format_error(Err) ->
@@ -203,13 +213,9 @@ format_error(Err) ->
%%% gen_server callbacks
%%%===================================================================
init([Module, {SockMod, Socket}, Opts]) ->
- XMLSocket = case lists:keyfind(xml_socket, 1, Opts) of
- {_, XS} -> XS;
- false -> false
- end,
Encrypted = proplists:get_bool(tls, Opts),
SocketMonitor = SockMod:monitor(Socket),
- case peername(SockMod, Socket) of
+ case SockMod:peername(Socket) of
{ok, IP} ->
Time = p1_time_compat:monotonic_time(milli_seconds),
State = #{owner => self(),
@@ -227,7 +233,6 @@ init([Module, {SockMod, Socket}, Opts]) ->
stream_encrypted => Encrypted,
stream_version => {1,0},
stream_authenticated => false,
- xml_socket => XMLSocket,
xmlns => ?NS_CLIENT,
lang => <<"">>,
user => <<"">>,
@@ -238,18 +243,32 @@ init([Module, {SockMod, Socket}, Opts]) ->
case try Module:init([State, Opts])
catch _:undef -> {ok, State}
end of
- {ok, State1} ->
+ {ok, State1} when not Encrypted ->
{_, State2, Timeout} = noreply(State1),
{ok, State2, Timeout};
- Err ->
- Err
+ {ok, State1} when Encrypted ->
+ TLSOpts = try Module:tls_options(State1)
+ catch _:undef -> []
+ end,
+ case SockMod:starttls(Socket, TLSOpts) of
+ {ok, TLSSocket} ->
+ State2 = State1#{socket => TLSSocket},
+ {_, State3, Timeout} = noreply(State2),
+ {ok, State3, Timeout};
+ {error, Reason} ->
+ {stop, Reason}
+ end;
+ {error, Reason} ->
+ {stop, Reason};
+ ignore ->
+ ignore
end;
- {error, Reason} ->
- {stop, Reason}
+ {error, _Reason} ->
+ ignore
end.
handle_cast({send, Pkt}, State) ->
- noreply(send_element(State, Pkt));
+ noreply(send_pkt(State, Pkt));
handle_cast(stop, State) ->
{stop, normal, State};
handle_cast(Cast, #{mod := Mod} = State) ->
@@ -278,7 +297,7 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
State1 = send_header(State),
case is_disconnected(State1) of
true -> State1;
- false -> send_element(State1, xmpp:serr_invalid_xml())
+ false -> send_pkt(State1, xmpp:serr_invalid_xml())
end
catch _:{xmpp_codec, Why} ->
State1 = send_header(State),
@@ -288,7 +307,7 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
Txt = xmpp:io_format_error(Why),
Lang = select_lang(MyLang, xmpp:get_lang(El)),
Err = xmpp:serr_invalid_xml(Txt, Lang),
- send_element(State1, Err)
+ send_pkt(State1, Err)
end
end);
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
@@ -303,7 +322,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
_ ->
xmpp:serr_not_well_formed()
end,
- send_element(State1, Err)
+ send_pkt(State1, Err)
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
#{xmlns := NS, mod := Mod} = State) ->
@@ -339,7 +358,7 @@ handle_info(timeout, #{mod := Mod} = State) ->
Disconnected = is_disconnected(State),
noreply(try Mod:handle_timeout(State)
catch _:undef when not Disconnected ->
- send_element(State, xmpp:serr_connection_timeout());
+ send_pkt(State, xmpp:serr_connection_timeout());
_:undef ->
stop(State)
end);
@@ -385,14 +404,6 @@ new_id() ->
is_disconnected(#{stream_state := StreamState}) ->
StreamState == disconnected.
--spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}|
- {error, inet:posix()}.
-peername(SockMod, Socket) ->
- case SockMod of
- gen_tcp -> inet:peername(Socket);
- _ -> SockMod:peername(Socket)
- end.
-
-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
case xmpp:is_stanza(El) of
@@ -408,12 +419,12 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
Txt = xmpp:io_format_error(Reason),
Err = #sasl_failure{reason = 'malformed-request',
text = xmpp:mk_text(Txt, MyLang)},
- send_element(State, Err);
+ send_pkt(State, Err);
{<<"starttls">>, ?NS_TLS} ->
- send_element(State, #starttls_failure{});
+ send_pkt(State, #starttls_failure{});
{<<"compress">>, ?NS_COMPRESS} ->
Err = #compress_failure{reason = 'setup-failed'},
- send_element(State, Err);
+ send_pkt(State, Err);
_ ->
%% Maybe add something more?
State
@@ -434,9 +445,9 @@ process_stream(#stream_start{xmlns = XML_NS,
stream_xmlns = STREAM_NS},
#{xmlns := NS} = State)
when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
- send_element(State, xmpp:serr_invalid_namespace());
+ send_pkt(State, xmpp:serr_invalid_namespace());
process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
- send_element(State, xmpp:serr_unsupported_version());
+ send_pkt(State, xmpp:serr_unsupported_version());
process_stream(#stream_start{lang = Lang},
#{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
when size(Lang) > 35 ->
@@ -445,14 +456,14 @@ process_stream(#stream_start{lang = Lang},
%% language tags MUST allow for language tags of at least 35 characters.
%% Do not store long language tag to avoid possible DoS/flood attacks
Txt = <<"Too long value of 'xml:lang' attribute">>,
- send_element(State, xmpp:serr_policy_violation(Txt, DefaultLang));
+ send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang));
process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) ->
Txt = <<"Missing 'to' attribute">>,
- send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
+ send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
#{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
Txt = <<"Improper 'to' attribute">>,
- send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
+ send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
#{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
State1 = State#{remote_server => RemoteServer,
@@ -509,29 +520,29 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
#starttls{} ->
process_starttls_failure(unexpected_starttls_request, State);
#sasl_auth{} when StateName == wait_for_starttls ->
- send_element(State, #sasl_failure{reason = 'encryption-required'});
+ send_pkt(State, #sasl_failure{reason = 'encryption-required'});
#sasl_auth{} when StateName == wait_for_sasl_request ->
process_sasl_request(Pkt, State);
#sasl_auth{} ->
Txt = <<"SASL negotiation is not allowed in this state">>,
- send_element(State, #sasl_failure{reason = 'not-authorized',
+ send_pkt(State, #sasl_failure{reason = 'not-authorized',
text = xmpp:mk_text(Txt, Lang)});
#sasl_response{} when StateName == wait_for_starttls ->
- send_element(State, #sasl_failure{reason = 'encryption-required'});
+ send_pkt(State, #sasl_failure{reason = 'encryption-required'});
#sasl_response{} when StateName == wait_for_sasl_response ->
process_sasl_response(Pkt, State);
#sasl_response{} ->
Txt = <<"SASL negotiation is not allowed in this state">>,
- send_element(State, #sasl_failure{reason = 'not-authorized',
+ send_pkt(State, #sasl_failure{reason = 'not-authorized',
text = xmpp:mk_text(Txt, Lang)});
#sasl_abort{} when StateName == wait_for_sasl_response ->
process_sasl_abort(State);
#sasl_abort{} ->
- send_element(State, #sasl_failure{reason = 'aborted'});
+ send_pkt(State, #sasl_failure{reason = 'aborted'});
#sasl_success{} ->
State;
#compress{} when StateName == wait_for_sasl_response ->
- send_element(State, #compress_failure{reason = 'setup-failed'});
+ send_pkt(State, #compress_failure{reason = 'setup-failed'});
#compress{} ->
process_compress(Pkt, State);
#handshake{} when StateName == wait_for_handshake ->
@@ -570,7 +581,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
{ok, #iq{type = set, sub_els = [_]} = Pkt2} when NS == ?NS_CLIENT ->
case xmpp:get_subtag(Pkt2, #xmpp_session{}) of
#xmpp_session{} ->
- send_element(State, xmpp:make_iq_result(Pkt2));
+ send_pkt(State, xmpp:make_iq_result(Pkt2));
_ ->
try Mod:handle_authenticated_packet(Pkt2, State)
catch _:undef ->
@@ -585,7 +596,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
send_error(State, Pkt, Err)
end;
{error, Err} ->
- send_element(State, Err)
+ send_pkt(State, Err)
end.
-spec process_bind(xmpp_element(), state()) -> state().
@@ -604,7 +615,7 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
server := S,
resource := NewR} = State1} when NewR /= <<"">> ->
Reply = #bind{jid = jid:make(U, S, NewR)},
- State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)),
+ State2 = send_pkt(State1, xmpp:make_iq_result(Pkt, Reply)),
process_stream_established(State2);
{error, #stanza_error{}, State1} = Err ->
send_error(State1, Pkt, Err)
@@ -646,7 +657,7 @@ process_handshake(#handshake{data = Digest},
case is_disconnected(State1) of
true -> State1;
false ->
- State2 = send_element(State1, #handshake{}),
+ State2 = send_pkt(State1, #handshake{}),
process_stream_established(State2)
end;
false ->
@@ -656,7 +667,7 @@ process_handshake(#handshake{data = Digest},
end,
case is_disconnected(State1) of
true -> State1;
- false -> send_element(State1, xmpp:serr_not_authorized())
+ false -> send_pkt(State1, xmpp:serr_not_authorized())
end
end.
@@ -674,7 +685,7 @@ process_stream_established(#{mod := Mod} = State) ->
-spec process_compress(compress(), state()) -> state().
process_compress(#compress{}, #{stream_compressed := true} = State) ->
- send_element(State, #compress_failure{reason = 'setup-failed'});
+ send_pkt(State, #compress_failure{reason = 'setup-failed'});
process_compress(#compress{methods = HisMethods},
#{socket := Socket, sockmod := SockMod, mod := Mod} = State) ->
MyMethods = try Mod:compress_methods(State)
@@ -683,44 +694,60 @@ process_compress(#compress{methods = HisMethods},
CommonMethods = lists_intersection(MyMethods, HisMethods),
case lists:member(<<"zlib">>, CommonMethods) of
true ->
- BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})),
- ZlibSocket = SockMod:compress(Socket, BCompressed),
- State#{socket => ZlibSocket,
- stream_id => new_id(),
- stream_header_sent => false,
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_compressed => true};
+ State1 = send_pkt(State, #compressed{}),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ case SockMod:compress(Socket) of
+ {ok, ZlibSocket} ->
+ State1#{socket => ZlibSocket,
+ stream_id => new_id(),
+ stream_header_sent => false,
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ stream_compressed => true};
+ {error, _} ->
+ Err = #compress_failure{reason = 'setup-failed'},
+ send_pkt(State1, Err)
+ end
+ end;
false ->
- send_element(State, #compress_failure{reason = 'unsupported-method'})
+ send_pkt(State, #compress_failure{reason = 'unsupported-method'})
end.
-spec process_starttls(state()) -> state().
+process_starttls(#{stream_encrypted := true} = State) ->
+ process_starttls_failure(already_encrypted, State);
process_starttls(#{socket := Socket,
sockmod := SockMod, mod := Mod} = State) ->
- TLSOpts = try Mod:tls_options(State)
- catch _:undef -> []
- end,
- case SockMod:starttls(Socket, TLSOpts) of
- {ok, TLSSocket} ->
- State1 = send_element(State, #starttls_proceed{}),
- case is_disconnected(State1) of
- true -> State1;
- false ->
- State1#{socket => TLSSocket,
- stream_id => new_id(),
- stream_header_sent => false,
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_encrypted => true}
+ case is_starttls_available(State) of
+ true ->
+ TLSOpts = try Mod:tls_options(State)
+ catch _:undef -> []
+ end,
+ case SockMod:starttls(Socket, TLSOpts) of
+ {ok, TLSSocket} ->
+ State1 = send_pkt(State, #starttls_proceed{}),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ State1#{socket => TLSSocket,
+ stream_id => new_id(),
+ stream_header_sent => false,
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ stream_encrypted => true}
+ end;
+ {error, Reason} ->
+ process_starttls_failure(Reason, State)
end;
- {error, Reason} ->
- process_starttls_failure(Reason, State)
+ false ->
+ process_starttls_failure(starttls_unsupported, State)
end.
-spec process_starttls_failure(term(), state()) -> state().
process_starttls_failure(Why, State) ->
- State1 = send_element(State, #starttls_failure{}),
+ State1 = send_pkt(State, #starttls_failure{}),
case is_disconnected(State1) of
true -> State1;
false -> process_stream_end({tls, Why}, State1)
@@ -780,17 +807,17 @@ process_sasl_success(Props, ServerOut,
mod := Mod, sasl_mech := Mech} = State) ->
User = identity(Props),
AuthModule = proplists:get_value(auth_module, Props),
- State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State)
- catch _:undef -> State
- end,
+ State1 = send_pkt(State, #sasl_success{text = ServerOut}),
case is_disconnected(State1) of
true -> State1;
false ->
- SockMod:reset_stream(Socket),
- State2 = send_element(State1, #sasl_success{text = ServerOut}),
+ State2 = try Mod:handle_auth_success(User, Mech, AuthModule, State1)
+ catch _:undef -> State1
+ end,
case is_disconnected(State2) of
true -> State2;
false ->
+ SockMod:reset_stream(Socket),
State3 = maps:remove(sasl_state,
maps:remove(sasl_mech, State2)),
State3#{stream_id => new_id(),
@@ -806,19 +833,23 @@ process_sasl_success(Props, ServerOut,
process_sasl_continue(ServerOut, NewSASLState, State) ->
State1 = State#{sasl_state => NewSASLState,
stream_state => wait_for_sasl_response},
- send_element(State1, #sasl_challenge{text = ServerOut}).
+ send_pkt(State1, #sasl_challenge{text = ServerOut}).
-spec process_sasl_failure(atom(), binary(), state()) -> state().
process_sasl_failure(Err, User,
#{mod := Mod, sasl_mech := Mech, lang := Lang} = State) ->
{Reason, Text} = format_sasl_error(Mech, Err),
- State1 = try Mod:handle_auth_failure(User, Mech, Text, State)
- catch _:undef -> State
- end,
- State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)),
- State3 = State2#{stream_state => wait_for_sasl_request},
- send_element(State3, #sasl_failure{reason = Reason,
- text = xmpp:mk_text(Text, Lang)}).
+ State1 = send_pkt(State, #sasl_failure{reason = Reason,
+ text = xmpp:mk_text(Text, Lang)}),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ State2 = try Mod:handle_auth_failure(User, Mech, Text, State1)
+ catch _:undef -> State1
+ end,
+ State3 = maps:remove(sasl_state, maps:remove(sasl_mech, State2)),
+ State3#{stream_state => wait_for_sasl_request}
+ end.
-spec process_sasl_abort(state()) -> state().
process_sasl_abort(State) ->
@@ -835,7 +866,7 @@ send_features(#{stream_version := {1,0},
++ get_tls_feature(State) ++ get_bind_feature(State)
++ get_session_feature(State) ++ get_other_features(State)
end,
- send_element(State, #stream_features{sub_els = Features});
+ send_pkt(State, #stream_features{sub_els = Features});
send_features(State) ->
%% clients and servers from stone age
State.
@@ -849,10 +880,13 @@ get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
TLSVerify = try Mod:tls_verify(State)
catch _:undef -> false
end,
- if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
- [<<"EXTERNAL">>|Mechs];
- true ->
- Mechs
+ Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
+ [<<"EXTERNAL">>|Mechs];
+ true ->
+ Mechs
+ end,
+ try Mod:sasl_mechanisms(Mechs1, State)
+ catch _:undef -> Mechs1
end.
-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
@@ -882,8 +916,13 @@ get_compress_feature(_) ->
-spec get_tls_feature(state()) -> [starttls()].
get_tls_feature(#{stream_authenticated := false,
stream_encrypted := false} = State) ->
- TLSRequired = is_starttls_required(State),
- [#starttls{required = TLSRequired}];
+ case is_starttls_available(State) of
+ true ->
+ TLSRequired = is_starttls_required(State),
+ [#starttls{required = TLSRequired}];
+ false ->
+ []
+ end;
get_tls_feature(_) ->
[].
@@ -913,6 +952,12 @@ get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
[]
end.
+-spec is_starttls_available(state()) -> boolean().
+is_starttls_available(#{mod := Mod} = State) ->
+ try Mod:tls_enabled(State)
+ catch _:undef -> true
+ end.
+
-spec is_starttls_required(state()) -> boolean().
is_starttls_required(#{mod := Mod} = State) ->
try Mod:tls_required(State)
@@ -967,13 +1012,14 @@ send_header(#{stream_id := StreamID,
lang := MyLang,
xmlns := NS,
server := DefaultServer} = State,
- #stream_start{to = To, lang = HisLang, version = HisVersion}) ->
+ #stream_start{to = HisTo, from = HisFrom,
+ lang = HisLang, version = HisVersion}) ->
Lang = select_lang(MyLang, HisLang),
NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
true -> <<"">>
end,
- From = case To of
- #jid{} -> To;
+ From = case HisTo of
+ #jid{} -> HisTo;
undefined -> jid:make(DefaultServer)
end,
Version = case HisVersion of
@@ -981,45 +1027,40 @@ send_header(#{stream_id := StreamID,
{0,_} -> HisVersion;
_ -> MyVersion
end,
- Header = xmpp:encode(#stream_start{version = Version,
- lang = Lang,
- xmlns = NS,
- stream_xmlns = ?NS_STREAM,
- db_xmlns = NS_DB,
- id = StreamID,
- from = From}),
+ StreamStart = #stream_start{version = Version,
+ lang = Lang,
+ xmlns = NS,
+ stream_xmlns = ?NS_STREAM,
+ db_xmlns = NS_DB,
+ id = StreamID,
+ to = HisFrom,
+ from = From},
State1 = State#{lang => Lang,
stream_version => Version,
stream_header_sent => true},
- case send_text(State1, fxml:element_to_header(Header)) of
+ case socket_send(State1, StreamStart) of
ok -> State1;
{error, Why} -> process_stream_end({socket, Why}, State1)
end;
send_header(State, _) ->
State.
--spec send_element(state(), xmpp_element()) -> state().
-send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
- El = xmpp:encode(Pkt, NS),
- Data = fxml:element_to_binary(El),
- Result = send_text(State, Data),
+-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
+send_pkt(#{mod := Mod} = State, Pkt) ->
+ Result = socket_send(State, Pkt),
State1 = try Mod:handle_send(Pkt, Result, State)
catch _:undef -> State
end,
- case is_disconnected(State1) of
- true -> State1;
- false ->
- case Result of
- _ when is_record(Pkt, stream_error) ->
- process_stream_end({stream, {out, Pkt}}, State1);
- ok ->
- State1;
- {error, Why} ->
- process_stream_end({socket, Why}, State1)
- end
+ case Result of
+ _ when is_record(Pkt, stream_error) ->
+ process_stream_end({stream, {out, Pkt}}, State1);
+ ok ->
+ State1;
+ {error, Why} ->
+ process_stream_end({socket, Why}, State1)
end.
--spec send_error(state(), xmpp_element(), stanza_error()) -> state().
+-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state().
send_error(State, Pkt, Err) ->
case xmpp:is_stanza(Pkt) of
true ->
@@ -1030,7 +1071,7 @@ send_error(State, Pkt, Err) ->
<<"error">> -> State;
_ ->
ErrPkt = xmpp:make_error(Pkt, Err),
- send_element(State, ErrPkt)
+ send_pkt(State, ErrPkt)
end;
false ->
State
@@ -1038,15 +1079,23 @@ send_error(State, Pkt, Err) ->
-spec send_trailer(state()) -> state().
send_trailer(State) ->
- send_text(State, <<"</stream:stream>">>),
+ socket_send(State, trailer),
close_socket(State).
--spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
-send_text(#{socket := Sock, sockmod := SockMod,
- stream_state := StateName,
- stream_header_sent := true}, Data) when StateName /= disconnected ->
- SockMod:send(Sock, Data);
-send_text(_, _) ->
+-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}.
+socket_send(#{socket := Sock, sockmod := SockMod,
+ stream_state := StateName,
+ xmlns := NS,
+ stream_header_sent := true}, Pkt) when StateName /= disconnected ->
+ case Pkt of
+ trailer ->
+ SockMod:send_trailer(Sock);
+ #stream_start{} ->
+ SockMod:send_header(Sock, xmpp:encode(Pkt));
+ _ ->
+ SockMod:send_element(Sock, xmpp:encode(Pkt, NS))
+ end;
+socket_send(_, _) ->
{error, closed}.
-spec close_socket(state()) -> state().
@@ -1096,6 +1145,12 @@ format_sasl_error(<<"EXTERNAL">>, Err) ->
format_sasl_error(Mech, Err) ->
cyrsasl:format_error(Mech, Err).
+-spec format_tls_error(atom() | binary()) -> list().
+format_tls_error(Reason) when is_atom(Reason) ->
+ format_inet_error(Reason);
+format_tls_error(Reason) ->
+ Reason.
+
-spec format(io:format(), list()) -> binary().
format(Fmt, Args) ->
iolist_to_binary(io_lib:format(Fmt, Args)).