summaryrefslogtreecommitdiff
path: root/src/xmpp_stream_in.erl
diff options
context:
space:
mode:
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>2016-12-31 13:48:55 +0300
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>2016-12-31 13:48:55 +0300
commitcf87c5664f3abb9be035b92d762e6380985684cf (patch)
tree49ae41d020767a8158224fbb9d79b8929c4685c5 /src/xmpp_stream_in.erl
parentImprove return values in cyrsasl API (diff)
Reflect cyrsasl API changes in remaining code
Diffstat (limited to 'src/xmpp_stream_in.erl')
-rw-r--r--src/xmpp_stream_in.erl154
1 files changed, 94 insertions, 60 deletions
diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl
index a0387064..1ad78d45 100644
--- a/src/xmpp_stream_in.erl
+++ b/src/xmpp_stream_in.erl
@@ -42,7 +42,7 @@
-include("xmpp.hrl").
-type state() :: map().
--type stop_reason() :: {stream, reset | stream_error()} |
+-type stop_reason() :: {stream, reset | {in | out, stream_error()}} |
{tls, term()} |
{socket, inet:posix() | closed | timeout} |
internal_failure.
@@ -188,8 +188,10 @@ format_error({socket, Reason}) ->
format("Connection failed: ~s", [format_inet_error(Reason)]);
format_error({stream, reset}) ->
<<"Stream reset by peer">>;
-format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
- format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
+ format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
+ format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
format_error({tls, Reason}) ->
format("TLS failed: ~w", [Reason]);
format_error(internal_failure) ->
@@ -304,7 +306,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
send_element(State1, Err)
end);
handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
+ #{xmlns := NS, mod := Mod} = State) ->
noreply(
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
@@ -321,10 +323,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
end,
case is_disconnected(State1) of
true -> State1;
- false ->
- Txt = xmpp:io_format_error(Why),
- Lang = select_lang(MyLang, xmpp:get_lang(El)),
- send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
+ false -> process_invalid_xml(State1, El, Why)
end
end);
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
@@ -394,6 +393,33 @@ peername(SockMod, Socket) ->
_ -> SockMod:peername(Socket)
end.
+-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
+process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
+ case xmpp:is_stanza(El) of
+ true ->
+ Txt = xmpp:io_format_error(Reason),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ send_error(State, El, xmpp:err_bad_request(Txt, Lang));
+ false ->
+ case {xmpp:get_name(El), xmpp:get_ns(El)} of
+ {Tag, ?NS_SASL} when Tag == <<"auth">>;
+ Tag == <<"response">>;
+ Tag == <<"abort">> ->
+ Txt = xmpp:io_format_error(Reason),
+ Err = #sasl_failure{reason = 'malformed-request',
+ text = xmpp:mk_text(Txt, MyLang)},
+ send_element(State, Err);
+ {<<"starttls">>, ?NS_TLS} ->
+ send_element(State, #starttls_failure{});
+ {<<"compress">>, ?NS_COMPRESS} ->
+ Err = #compress_failure{reason = 'setup-failed'},
+ send_element(State, Err);
+ _ ->
+ %% Maybe add something more?
+ State
+ end
+ end.
+
-spec process_stream_end(stop_reason(), state()) -> state().
process_stream_end(_, #{stream_state := disconnected} = State) ->
State;
@@ -423,11 +449,6 @@ process_stream(#stream_start{lang = Lang},
process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) ->
Txt = <<"Missing 'to' attribute">>,
send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
-process_stream(#stream_start{from = undefined, version = {1,0}},
- #{lang := Lang, xmlns := ?NS_SERVER,
- stream_encrypted := true} = State) ->
- Txt = <<"Missing 'from' attribute">>,
- send_element(State, xmpp:serr_invalid_from(Txt, Lang));
process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
#{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
Txt = <<"Improper 'to' attribute">>,
@@ -450,9 +471,10 @@ process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
true ->
State
end,
- State2 = if NS == ?NS_SERVER andalso Encrypted ->
- State1#{remote_server => From#jid.lserver};
- true ->
+ State2 = case From of
+ #jid{lserver = RemoteServer} when NS == ?NS_SERVER ->
+ State1#{remote_server => RemoteServer};
+ _ ->
State1
end,
State3 = try Mod:handle_stream_start(StreamStart, State2)
@@ -517,7 +539,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
#handshake{} ->
State;
#stream_error{} ->
- process_stream_end({stream, Pkt}, State);
+ process_stream_end({stream, {in, Pkt}}, State);
_ when StateName == wait_for_sasl_request;
StateName == wait_for_handshake;
StateName == wait_for_sasl_response ->
@@ -707,35 +729,34 @@ process_starttls_failure(Why, State) ->
-spec process_sasl_request(sasl_auth(), state()) -> state().
process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
#{mod := Mod, lserver := LServer} = State) ->
- GetPW = try Mod:get_password_fun(State)
- catch _:undef -> fun(_) -> false end
- end,
- CheckPW = try Mod:check_password_fun(State)
- catch _:undef -> fun(_, _, _) -> false end
- end,
- CheckPWDigest = try Mod:check_password_digest_fun(State)
- catch _:undef -> fun(_, _, _, _, _) -> false end
- end,
- SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
- GetPW, CheckPW, CheckPWDigest),
- State1 = State#{sasl_state => SASLState, sasl_mech => Mech},
+ State1 = State#{sasl_mech => Mech},
Mechs = get_sasl_mechanisms(State1),
- SASLResult = case lists:member(Mech, Mechs) of
- true when Mech == <<"EXTERNAL">> ->
- case xmpp_stream_pkix:authenticate(State1, ClientIn) of
- {ok, Peer} ->
- {ok, [{auth_module, pkix},
- {username, Peer}]};
- {error, _Reason, Peer} ->
- %% TODO: return meaningful error
- {error, 'not-authorized', Peer}
- end;
- true ->
- cyrsasl:server_start(SASLState, Mech, ClientIn);
- false ->
- {error, 'invalid-mechanism'}
- end,
- process_sasl_result(SASLResult, State1).
+ case lists:member(Mech, Mechs) of
+ true when Mech == <<"EXTERNAL">> ->
+ Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of
+ {ok, Peer} ->
+ {ok, [{auth_module, pkix}, {username, Peer}]};
+ {error, Reason, Peer} ->
+ {error, Reason, Peer}
+ end,
+ process_sasl_result(Res, State1);
+ true ->
+ GetPW = try Mod:get_password_fun(State1)
+ catch _:undef -> fun(_) -> false end
+ end,
+ CheckPW = try Mod:check_password_fun(State1)
+ catch _:undef -> fun(_, _, _) -> false end
+ end,
+ CheckPWDigest = try Mod:check_password_digest_fun(State1)
+ catch _:undef -> fun(_, _, _, _, _) -> false end
+ end,
+ SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
+ GetPW, CheckPW, CheckPWDigest),
+ Res = cyrsasl:server_start(SASLState, Mech, ClientIn),
+ process_sasl_result(Res, State1#{sasl_state => SASLState});
+ false ->
+ process_sasl_result({error, unsupported_mechanism, <<"">>}, State1)
+ end.
-spec process_sasl_response(sasl_response(), state()) -> state().
process_sasl_response(#sasl_response{text = ClientIn},
@@ -751,9 +772,7 @@ process_sasl_result({ok, Props, ServerOut}, State) ->
process_sasl_result({continue, ServerOut, NewSASLState}, State) ->
process_sasl_continue(ServerOut, NewSASLState, State);
process_sasl_result({error, Reason, User}, State) ->
- process_sasl_failure(Reason, User, State);
-process_sasl_result({error, Reason}, State) ->
- process_sasl_failure(Reason, <<"">>, State).
+ process_sasl_failure(Reason, User, State).
-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
process_sasl_success(Props, ServerOut,
@@ -790,18 +809,20 @@ process_sasl_continue(ServerOut, NewSASLState, State) ->
send_element(State1, #sasl_challenge{text = ServerOut}).
-spec process_sasl_failure(atom(), binary(), state()) -> state().
-process_sasl_failure(Reason, User,
- #{mod := Mod, sasl_mech := Mech} = State) ->
- State1 = try Mod:handle_auth_failure(User, Mech, Reason, State)
+process_sasl_failure(Err, User,
+ #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) ->
+ {Reason, Text} = format_sasl_error(Mech, Err),
+ State1 = try Mod:handle_auth_failure(User, Mech, Text, State)
catch _:undef -> State
end,
State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)),
State3 = State2#{stream_state => wait_for_sasl_request},
- send_element(State3, #sasl_failure{reason = Reason}).
+ send_element(State3, #sasl_failure{reason = Reason,
+ text = xmpp:mk_text(Text, Lang)}).
-spec process_sasl_abort(state()) -> state().
process_sasl_abort(State) ->
- process_sasl_failure('aborted', <<"">>, State).
+ process_sasl_failure(aborted, <<"">>, State).
-spec send_features(state()) -> state().
send_features(#{stream_version := {1,0},
@@ -985,13 +1006,17 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
State1 = try Mod:handle_send(Pkt, Result, State)
catch _:undef -> State
end,
- case Result of
- _ when is_record(Pkt, stream_error) ->
- process_stream_end({stream, Pkt}, State1);
- ok ->
- State1;
- {error, Why} ->
- process_stream_end({socket, Why}, State1)
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ case Result of
+ _ when is_record(Pkt, stream_error) ->
+ process_stream_end({stream, {out, Pkt}}, State1);
+ ok ->
+ State1;
+ {error, Why} ->
+ process_stream_end({socket, Why}, State1)
+ end
end.
-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
@@ -1025,6 +1050,8 @@ send_text(_, _) ->
{error, closed}.
-spec close_socket(state()) -> state().
+close_socket(#{stream_state := disconnected} = State) ->
+ State;
close_socket(#{sockmod := SockMod, socket := Socket} = State) ->
SockMod:close(Socket),
State#{stream_timeout => infinity,
@@ -1052,6 +1079,7 @@ format_inet_error(Reason) ->
-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
format_stream_error(Reason, Txt) ->
Slogan = case Reason of
+ undefined -> "no reason";
#'see-other-host'{} -> "see-other-host";
_ -> atom_to_list(Reason)
end,
@@ -1062,6 +1090,12 @@ format_stream_error(Reason, Txt) ->
binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
end.
+-spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}.
+format_sasl_error(<<"EXTERNAL">>, Err) ->
+ xmpp_stream_pkix:format_error(Err);
+format_sasl_error(Mech, Err) ->
+ cyrsasl:format_error(Mech, Err).
+
-spec format(io:format(), list()) -> binary().
format(Fmt, Args) ->
iolist_to_binary(io_lib:format(Fmt, Args)).