summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>2016-12-31 13:48:55 +0300
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>2016-12-31 13:48:55 +0300
commitcf87c5664f3abb9be035b92d762e6380985684cf (patch)
tree49ae41d020767a8158224fbb9d79b8929c4685c5 /src
parentImprove return values in cyrsasl API (diff)
Reflect cyrsasl API changes in remaining code
Diffstat (limited to 'src')
-rw-r--r--src/ejabberd_c2s.erl6
-rw-r--r--src/mod_s2s_dialback.erl27
-rw-r--r--src/mod_sm.erl53
-rw-r--r--src/xmpp_stream_in.erl154
-rw-r--r--src/xmpp_stream_out.erl58
-rw-r--r--src/xmpp_stream_pkix.erl39
6 files changed, 220 insertions, 117 deletions
diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl
index f22960c5..a10ee59a 100644
--- a/src/ejabberd_c2s.erl
+++ b/src/ejabberd_c2s.erl
@@ -221,7 +221,7 @@ process_closed(State, Reason) ->
process_terminated(#{socket := Socket, jid := JID} = State,
Reason) ->
Status = format_reason(State, Reason),
- ?INFO_MSG("(~s) Closing c2s connection for ~s: ~s",
+ ?INFO_MSG("(~s) Closing c2s session for ~s: ~s",
[ejabberd_socket:pp(Socket), jid:to_string(JID), Status]),
Pres = #presence{type = unavailable,
status = xmpp:mk_text(Status),
@@ -292,12 +292,12 @@ bind(R, #{user := U, server := S, access := Access, lang := Lang,
State1 = open_session(State#{resource => Resource}),
State2 = ejabberd_hooks:run_fold(
c2s_session_opened, LServer, State1, []),
- ?INFO_MSG("(~s) Opened session for ~s",
+ ?INFO_MSG("(~s) Opened c2s session for ~s",
[ejabberd_socket:pp(Socket), jid:to_string(JID)]),
{ok, State2};
deny ->
ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]),
- ?INFO_MSG("(~s) Forbidden session for ~s",
+ ?INFO_MSG("(~s) Forbidden c2s session for ~s",
[ejabberd_socket:pp(Socket), jid:to_string(JID)]),
Txt = <<"Denied by ACL">>,
{error, xmpp:err_not_allowed(Txt, Lang), State}
diff --git a/src/mod_s2s_dialback.erl b/src/mod_s2s_dialback.erl
index 4bdda2ca..d0d78a30 100644
--- a/src/mod_s2s_dialback.erl
+++ b/src/mod_s2s_dialback.erl
@@ -29,7 +29,7 @@
-export([start/2, stop/1, depends/2, mod_opt_type/1]).
%% Hooks
-export([s2s_out_auth_result/2, s2s_out_downgraded/2,
- s2s_in_packet/2, s2s_out_packet/2,
+ s2s_in_packet/2, s2s_out_packet/2, s2s_in_recv/3,
s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]).
-include("ejabberd.hrl").
@@ -52,6 +52,8 @@ start(Host, _Opts) ->
s2s_in_features, 50),
ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE,
s2s_in_features, 50),
+ ejabberd_hooks:add(s2s_in_handle_recv, Host, ?MODULE,
+ s2s_in_recv, 50),
ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE,
s2s_in_packet, 50),
ejabberd_hooks:add(s2s_in_authenticated_packet, Host, ?MODULE,
@@ -71,6 +73,8 @@ stop(Host) ->
s2s_in_features, 50),
ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE,
s2s_in_features, 50),
+ ejabberd_hooks:delete(s2s_in_handle_recv, Host, ?MODULE,
+ s2s_in_recv, 50),
ejabberd_hooks:delete(s2s_in_unauthenticated_packet, Host, ?MODULE,
s2s_in_packet, 50),
ejabberd_hooks:delete(s2s_in_authenticated_packet, Host, ?MODULE,
@@ -191,6 +195,25 @@ s2s_in_packet(State, Pkt) when is_record(Pkt, db_result);
s2s_in_packet(State, _) ->
State.
+s2s_in_recv(State, El, {error, Why}) ->
+ case xmpp:get_name(El) of
+ Tag when Tag == <<"db:result">>;
+ Tag == <<"db:verify">> ->
+ case xmpp:get_type(El) of
+ T when T /= <<"valid">>,
+ T /= <<"invalid">>,
+ T /= <<"error">> ->
+ Err = xmpp:make_error(El, mk_error({codec_error, Why})),
+ {stop, ejabberd_s2s_in:send(State, Err)};
+ _ ->
+ State
+ end;
+ _ ->
+ State
+ end;
+s2s_in_recv(State, _El, _Pkt) ->
+ State.
+
s2s_out_packet(#{server := LServer,
remote_server := RServer,
db_verify := {StreamID, _Key, Pid}} = State,
@@ -286,6 +309,8 @@ mk_error(forbidden) ->
xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG);
mk_error(host_unknown) ->
xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG);
+mk_error({codec_error, Why}) ->
+ xmpp:err_bad_request(xmpp:io_format_error(Why), ?MYLANG);
mk_error({_Class, _Reason} = Why) ->
Txt = xmpp_stream_out:format_error(Why),
xmpp:err_remote_server_not_found(Txt, ?MYLANG);
diff --git a/src/mod_sm.erl b/src/mod_sm.erl
index 70323441..7e64e6a0 100644
--- a/src/mod_sm.erl
+++ b/src/mod_sm.erl
@@ -179,16 +179,14 @@ c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) ->
c2s_handle_recv(State, _, _) ->
State.
-c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
+c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, _Result)
when MgmtState == pending; MgmtState == active ->
State1 = mgmt_queue_add(State, Pkt),
- case Result of
- ok when ?is_stanza(Pkt) ->
+ case xmpp:is_stanza(Pkt) of
+ true ->
send_rack(State1);
- ok ->
- State1;
- {error, _} ->
- transition_to_pending(State1)
+ false ->
+ State1
end;
c2s_handle_send(State, _Pkt, _Result) ->
State.
@@ -210,8 +208,9 @@ c2s_handle_info(#{mgmt_ack_timer := TRef, jid := JID} = State,
{timeout, TRef, ack_timeout}) ->
?DEBUG("Timed out waiting for stream management acknowledgement of ~s",
[jid:to_string(JID)]),
- State1 = ejabberd_c2s:close(State, _SendTrailer = false),
- {stop, transition_to_pending(State1)};
+ State1 = State#{stop_reason => {socket, timeout}},
+ State2 = ejabberd_c2s:close(State1, _SendTrailer = false),
+ {stop, transition_to_pending(State2)};
c2s_handle_info(#{mgmt_state := pending, jid := JID} = State,
{timeout, _, pending_timeout}) ->
?DEBUG("Timed out waiting for resumption of stream for ~s",
@@ -222,8 +221,8 @@ c2s_handle_info(State, _) ->
c2s_closed(State, {stream, _}) ->
State;
-c2s_closed(#{mgmt_state := active} = State, Reason) ->
- {stop, transition_to_pending(State#{stop_reason => Reason})};
+c2s_closed(#{mgmt_state := active} = State, _Reason) ->
+ {stop, transition_to_pending(State)};
c2s_closed(State, _Reason) ->
State.
@@ -368,10 +367,9 @@ transition_to_pending(#{mgmt_state := active, jid := JID,
lserver := LServer, mgmt_timeout := Timeout} = State) ->
State1 = cancel_ack_timer(State),
?INFO_MSG("Waiting for resumption of stream for ~s", [jid:to_string(JID)]),
- State2 = ejabberd_hooks:run_fold(c2s_session_pending, LServer, State1, []),
- State3 = ejabberd_c2s:close(State2, _SendTrailer = false),
erlang:start_timer(timer:seconds(Timeout), self(), pending_timeout),
- State3#{mgmt_state => pending};
+ State2 = State1#{mgmt_state => pending},
+ ejabberd_hooks:run_fold(c2s_session_pending, LServer, State2, []);
transition_to_pending(State) ->
State.
@@ -405,8 +403,8 @@ update_num_stanzas_in(State, _El) ->
send_rack(#{mgmt_ack_timer := _} = State) ->
State;
send_rack(#{mgmt_xmlns := Xmlns,
- mgmt_stanzas_out := NumStanzasOut,
- mgmt_ack_timeout := AckTimeout} = State) ->
+ mgmt_stanzas_out := NumStanzasOut,
+ mgmt_ack_timeout := AckTimeout} = State) ->
State1 = send(State, #sm_r{xmlns = Xmlns}),
TRef = erlang:start_timer(AckTimeout, self(), ack_timeout),
State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}.
@@ -425,16 +423,19 @@ resend_rack(State) ->
-spec mgmt_queue_add(state(), xmpp_element()) -> state().
mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut,
- mgmt_queue := Queue} = State, Stanza) when ?is_stanza(Stanza) ->
- NewNum = case NumStanzasOut of
- 4294967295 -> 0;
- Num -> Num + 1
- end,
- Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Stanza}, Queue),
- State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum},
- check_queue_length(State1);
-mgmt_queue_add(State, _Nonza) ->
- State.
+ mgmt_queue := Queue} = State, Pkt) ->
+ case xmpp:is_stanza(Pkt) of
+ true ->
+ NewNum = case NumStanzasOut of
+ 4294967295 -> 0;
+ Num -> Num + 1
+ end,
+ Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Pkt}, Queue),
+ State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum},
+ check_queue_length(State1);
+ false ->
+ State
+ end.
-spec mgmt_queue_drop(state(), non_neg_integer()) -> state().
mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) ->
diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl
index a0387064..1ad78d45 100644
--- a/src/xmpp_stream_in.erl
+++ b/src/xmpp_stream_in.erl
@@ -42,7 +42,7 @@
-include("xmpp.hrl").
-type state() :: map().
--type stop_reason() :: {stream, reset | stream_error()} |
+-type stop_reason() :: {stream, reset | {in | out, stream_error()}} |
{tls, term()} |
{socket, inet:posix() | closed | timeout} |
internal_failure.
@@ -188,8 +188,10 @@ format_error({socket, Reason}) ->
format("Connection failed: ~s", [format_inet_error(Reason)]);
format_error({stream, reset}) ->
<<"Stream reset by peer">>;
-format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
- format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
+ format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
+ format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
format_error({tls, Reason}) ->
format("TLS failed: ~w", [Reason]);
format_error(internal_failure) ->
@@ -304,7 +306,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
send_element(State1, Err)
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
+ #{xmlns := NS, mod := Mod} = State) ->
noreply(
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
@@ -321,10 +323,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end,
case is_disconnected(State1) of
true -> State1;
- false ->
- Txt = xmpp:io_format_error(Why),
- Lang = select_lang(MyLang, xmpp:get_lang(El)),
- send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
+ false -> process_invalid_xml(State1, El, Why)
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
@@ -394,6 +393,33 @@ peername(SockMod, Socket) ->
_ -> SockMod:peername(Socket)
end.
+-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
+process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
+ case xmpp:is_stanza(El) of
+ true ->
+ Txt = xmpp:io_format_error(Reason),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ send_error(State, El, xmpp:err_bad_request(Txt, Lang));
+ false ->
+ case {xmpp:get_name(El), xmpp:get_ns(El)} of
+ {Tag, ?NS_SASL} when Tag == <<"auth">>;
+ Tag == <<"response">>;
+ Tag == <<"abort">> ->
+ Txt = xmpp:io_format_error(Reason),
+ Err = #sasl_failure{reason = 'malformed-request',
+ text = xmpp:mk_text(Txt, MyLang)},
+ send_element(State, Err);
+ {<<"starttls">>, ?NS_TLS} ->
+ send_element(State, #starttls_failure{});
+ {<<"compress">>, ?NS_COMPRESS} ->
+ Err = #compress_failure{reason = 'setup-failed'},
+ send_element(State, Err);
+ _ ->
+ %% Maybe add something more?
+ State
+ end
+ end.
+
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
@@ -423,11 +449,6 @@ process_stream(#stream_start{lang = Lang},
process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) ->
Txt = <<"Missing 'to' attribute">>,
send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
-process_stream(#stream_start{from = undefined, version = {1,0}},
- #{lang := Lang, xmlns := ?NS_SERVER,
- stream_encrypted := true} = State) ->
- Txt = <<"Missing 'from' attribute">>,
- send_element(State, xmpp:serr_invalid_from(Txt, Lang));
process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
#{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
Txt = <<"Improper 'to' attribute">>,
@@ -450,9 +471,10 @@ process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
true ->
State
end,
- State2 = if NS == ?NS_SERVER andalso Encrypted ->
- State1#{remote_server => From#jid.lserver};
- true ->
+ State2 = case From of
+ #jid{lserver = RemoteServer} when NS == ?NS_SERVER ->
+ State1#{remote_server => RemoteServer};
+ _ ->
State1
end,
State3 = try Mod:handle_stream_start(StreamStart, State2)
@@ -517,7 +539,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
#handshake{} ->
State;
#stream_error{} ->
- process_stream_end({stream, Pkt}, State);
+ process_stream_end({stream, {in, Pkt}}, State);
_ when StateName == wait_for_sasl_request;
StateName == wait_for_handshake;
StateName == wait_for_sasl_response ->
@@ -707,35 +729,34 @@ process_starttls_failure(Why, State) ->
-spec process_sasl_request(sasl_auth(), state()) -> state().
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
#{mod := Mod, lserver := LServer} = State) ->
- GetPW = try Mod:get_password_fun(State)
- catch _:undef -> fun(_) -> false end
- end,
- CheckPW = try Mod:check_password_fun(State)
- catch _:undef -> fun(_, _, _) -> false end
- end,
- CheckPWDigest = try Mod:check_password_digest_fun(State)
- catch _:undef -> fun(_, _, _, _, _) -> false end
- end,
- SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
- GetPW, CheckPW, CheckPWDigest),
- State1 = State#{sasl_state => SASLState, sasl_mech => Mech},
+ State1 = State#{sasl_mech => Mech},
Mechs = get_sasl_mechanisms(State1),
- SASLResult = case lists:member(Mech, Mechs) of
- true when Mech == <<"EXTERNAL">> ->
- case xmpp_stream_pkix:authenticate(State1, ClientIn) of
- {ok, Peer} ->
- {ok, [{auth_module, pkix},
- {username, Peer}]};
- {error, _Reason, Peer} ->
- %% TODO: return meaningful error
- {error, 'not-authorized', Peer}
- end;
- true ->
- cyrsasl:server_start(SASLState, Mech, ClientIn);
- false ->
- {error, 'invalid-mechanism'}
- end,
- process_sasl_result(SASLResult, State1).
+ case lists:member(Mech, Mechs) of
+ true when Mech == <<"EXTERNAL">> ->
+ Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of
+ {ok, Peer} ->
+ {ok, [{auth_module, pkix}, {username, Peer}]};
+ {error, Reason, Peer} ->
+ {error, Reason, Peer}
+ end,
+ process_sasl_result(Res, State1);
+ true ->
+ GetPW = try Mod:get_password_fun(State1)
+ catch _:undef -> fun(_) -> false end
+ end,
+ CheckPW = try Mod:check_password_fun(State1)
+ catch _:undef -> fun(_, _, _) -> false end
+ end,
+ CheckPWDigest = try Mod:check_password_digest_fun(State1)
+ catch _:undef -> fun(_, _, _, _, _) -> false end
+ end,
+ SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
+ GetPW, CheckPW, CheckPWDigest),
+ Res = cyrsasl:server_start(SASLState, Mech, ClientIn),
+ process_sasl_result(Res, State1#{sasl_state => SASLState});
+ false ->
+ process_sasl_result({error, unsupported_mechanism, <<"">>}, State1)
+ end.
-spec process_sasl_response(sasl_response(), state()) -> state().
process_sasl_response(#sasl_response{text = ClientIn},
@@ -751,9 +772,7 @@ process_sasl_result({ok, Props, ServerOut}, State) ->
process_sasl_result({continue, ServerOut, NewSASLState}, State) ->
process_sasl_continue(ServerOut, NewSASLState, State);
process_sasl_result({error, Reason, User}, State) ->
- process_sasl_failure(Reason, User, State);
-process_sasl_result({error, Reason}, State) ->
- process_sasl_failure(Reason, <<"">>, State).
+ process_sasl_failure(Reason, User, State).
-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
process_sasl_success(Props, ServerOut,
@@ -790,18 +809,20 @@ process_sasl_continue(ServerOut, NewSASLState, State) ->
send_element(State1, #sasl_challenge{text = ServerOut}).
-spec process_sasl_failure(atom(), binary(), state()) -> state().
-process_sasl_failure(Reason, User,
- #{mod := Mod, sasl_mech := Mech} = State) ->
- State1 = try Mod:handle_auth_failure(User, Mech, Reason, State)
+process_sasl_failure(Err, User,
+ #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) ->
+ {Reason, Text} = format_sasl_error(Mech, Err),
+ State1 = try Mod:handle_auth_failure(User, Mech, Text, State)
catch _:undef -> State
end,
State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)),
State3 = State2#{stream_state => wait_for_sasl_request},
- send_element(State3, #sasl_failure{reason = Reason}).
+ send_element(State3, #sasl_failure{reason = Reason,
+ text = xmpp:mk_text(Text, Lang)}).
-spec process_sasl_abort(state()) -> state().
process_sasl_abort(State) ->
- process_sasl_failure('aborted', <<"">>, State).
+ process_sasl_failure(aborted, <<"">>, State).
-spec send_features(state()) -> state().
send_features(#{stream_version := {1,0},
@@ -985,13 +1006,17 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
State1 = try Mod:handle_send(Pkt, Result, State)
catch _:undef -> State
end,
- case Result of
- _ when is_record(Pkt, stream_error) ->
- process_stream_end({stream, Pkt}, State1);
- ok ->
- State1;
- {error, Why} ->
- process_stream_end({socket, Why}, State1)
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ case Result of
+ _ when is_record(Pkt, stream_error) ->
+ process_stream_end({stream, {out, Pkt}}, State1);
+ ok ->
+ State1;
+ {error, Why} ->
+ process_stream_end({socket, Why}, State1)
+ end
end.
-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
@@ -1025,6 +1050,8 @@ send_text(_, _) ->
{error, closed}.
-spec close_socket(state()) -> state().
+close_socket(#{stream_state := disconnected} = State) ->
+ State;
close_socket(#{sockmod := SockMod, socket := Socket} = State) ->
SockMod:close(Socket),
State#{stream_timeout => infinity,
@@ -1052,6 +1079,7 @@ format_inet_error(Reason) ->
-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
format_stream_error(Reason, Txt) ->
Slogan = case Reason of
+ undefined -> "no reason";
#'see-other-host'{} -> "see-other-host";
_ -> atom_to_list(Reason)
end,
@@ -1062,6 +1090,12 @@ format_stream_error(Reason, Txt) ->
binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
end.
+-spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}.
+format_sasl_error(<<"EXTERNAL">>, Err) ->
+ xmpp_stream_pkix:format_error(Err);
+format_sasl_error(Mech, Err) ->
+ cyrsasl:format_error(Mech, Err).
+
-spec format(io:format(), list()) -> binary().
format(Fmt, Args) ->
iolist_to_binary(io_lib:format(Fmt, Args)).
diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl
index 08804e43..290a92a4 100644
--- a/src/xmpp_stream_out.erl
+++ b/src/xmpp_stream_out.erl
@@ -1,10 +1,23 @@
%%%-------------------------------------------------------------------
-%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
-%%% @copyright (C) 2016, Evgeny Khramtsov
-%%% @doc
-%%%
-%%% @end
%%% Created : 14 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
%%%-------------------------------------------------------------------
-module(xmpp_stream_out).
-behaviour(gen_server).
@@ -39,7 +52,7 @@
-type network_error() :: {error, inet:posix() | inet_res:res_error()}.
-type stop_reason() :: {idna, bad_string} |
{dns, inet:posix() | inet_res:res_error()} |
- {stream, reset | stream_error()} |
+ {stream, reset | {in | out, stream_error()}} |
{tls, term()} |
{pkix, binary()} |
{auth, atom() | binary() | string()} |
@@ -135,7 +148,7 @@ change_shaper(_, _) ->
-spec format_error(stop_reason()) -> binary().
format_error({idna, _}) ->
- <<"Not an IDN hostname">>;
+ <<"Remote domain is not an IDN hostname">>;
format_error({dns, Reason}) ->
format("DNS lookup failed: ~s", [format_inet_error(Reason)]);
format_error({socket, Reason}) ->
@@ -144,8 +157,10 @@ format_error({pkix, Reason}) ->
format("Peer certificate rejected: ~s", [Reason]);
format_error({stream, reset}) ->
<<"Stream reset by peer">>;
-format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
- format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
+ format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
+ format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
format_error({tls, Reason}) ->
format("TLS failed: ~w", [Reason]);
format_error({auth, Reason}) ->
@@ -264,7 +279,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
send_element(State1, Err)
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
+ #{xmlns := NS, mod := Mod} = State) ->
noreply(
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
@@ -281,10 +296,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end,
case is_disconnected(State1) of
true -> State1;
- false ->
- Txt = xmpp:io_format_error(Why),
- Lang = select_lang(MyLang, xmpp:get_lang(El)),
- send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
+ false -> process_invalid_xml(State1, El, Why)
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
@@ -347,6 +359,17 @@ new_id() ->
is_disconnected(#{stream_state := StreamState}) ->
StreamState == disconnected.
+-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
+process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
+ case xmpp:is_stanza(El) of
+ true ->
+ Txt = xmpp:io_format_error(Reason),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ send_error(State, El, xmpp:err_bad_request(Txt, Lang));
+ false ->
+ State
+ end.
+
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
@@ -394,7 +417,7 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
#sasl_failure{} when StateName == wait_for_sasl_response ->
process_sasl_failure(Pkt, State);
#stream_error{} ->
- process_stream_end({stream, Pkt}, State);
+ process_stream_end({stream, {in, Pkt}}, State);
_ when is_record(Pkt, stream_features);
is_record(Pkt, starttls_proceed);
is_record(Pkt, starttls);
@@ -612,7 +635,7 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
false ->
case send_text(State1, Data) of
_ when is_record(Pkt, stream_error) ->
- process_stream_end({stream, Pkt}, State1);
+ process_stream_end({stream, {out, Pkt}}, State1);
ok ->
State1;
{error, Why} ->
@@ -650,6 +673,8 @@ send_trailer(State) ->
close_socket(State).
-spec close_socket(state()) -> state().
+close_socket(#{stream_state := disconnected} = State) ->
+ State;
close_socket(State) ->
case State of
#{sockmod := SockMod, socket := Socket} ->
@@ -674,6 +699,7 @@ format_inet_error(Reason) ->
-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
format_stream_error(Reason, Txt) ->
Slogan = case Reason of
+ undefined -> "no reason";
#'see-other-host'{} -> "see-other-host";
_ -> atom_to_list(Reason)
end,
diff --git a/src/xmpp_stream_pkix.erl b/src/xmpp_stream_pkix.erl
index 59f5d820..5d64c5eb 100644
--- a/src/xmpp_stream_pkix.erl
+++ b/src/xmpp_stream_pkix.erl
@@ -9,7 +9,7 @@
-module(xmpp_stream_pkix).
%% API
--export([authenticate/1, authenticate/2]).
+-export([authenticate/1, authenticate/2, format_error/1]).
-include("xmpp.hrl").
-include_lib("public_key/include/public_key.hrl").
@@ -19,21 +19,24 @@
%%% API
%%%===================================================================
-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state())
- -> {ok, binary()} | {error, binary(), binary()}.
+ -> {ok, binary()} | {error, atom(), binary()}.
authenticate(State) ->
authenticate(State, <<"">>).
-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary())
- -> {ok, binary()} | {error, binary(), binary()}.
-authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer,
- sockmod := SockMod, socket := Socket}, _Authzid) ->
+ -> {ok, binary()} | {error, atom(), binary()}.
+authenticate(#{xmlns := ?NS_SERVER, sockmod := SockMod,
+ socket := Socket} = State, Authzid) ->
+ Peer = try maps:get(remote_server, State)
+ catch _:{badkey, _} -> Authzid
+ end,
case SockMod:get_peer_certificate(Socket) of
{ok, Cert} ->
case SockMod:get_verify_result(Socket) of
0 ->
case ejabberd_idna:domain_utf8_to_ascii(Peer) of
false ->
- {error, <<"Cannot decode remote server name">>, Peer};
+ {error, idna_failed, Peer};
AsciiPeer ->
case lists:any(
fun(D) -> match_domain(AsciiPeer, D) end,
@@ -41,20 +44,34 @@ authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer,
true ->
{ok, Peer};
false ->
- {error, <<"Certificate host name mismatch">>, Peer}
+ {error, hostname_mismatch, Peer}
end
end;
VerifyRes ->
- {error, fast_tls:get_cert_verify_string(VerifyRes, Cert), Peer}
+ %% TODO: return atomic errors
+ %% This should be improved in fast_tls
+ Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert),
+ {error, erlang:binary_to_atom(Reason, utf8), Peer}
end;
{error, _Reason} ->
- {error, <<"Cannot get peer certificate">>, Peer};
+ {error, get_cert_failed, Peer};
error ->
- {error, <<"Cannot get peer certificate">>, Peer}
+ {error, get_cert_failed, Peer}
end;
authenticate(_State, _Authzid) ->
%% TODO: client PKIX authentication
- {error, <<"Client certificate verification not implemented">>, <<"">>}.
+ {error, client_not_supported, <<"">>}.
+
+format_error(idna_failed) ->
+ {'bad-protocol', <<"Remote domain is not an IDN hostname">>};
+format_error(hostname_mismatch) ->
+ {'not-authorized', <<"Certificate host name mismatch">>};
+format_error(get_cert_failed) ->
+ {'bad-protocol', <<"Failed to get peer certificate">>};
+format_error(client_not_supported) ->
+ {'invalid-mechanism', <<"Client certificate verification is not supported">>};
+format_error(Other) ->
+ {'not-authorized', erlang:atom_to_binary(Other, utf8)}.
%%%===================================================================
%%% Internal functions