summaryrefslogtreecommitdiff
path: root/src/mod_mqtt_session.erl
diff options
context:
space:
mode:
authorAlexey Shchepin <alexey@process-one.net>2022-03-14 15:37:21 +0300
committerAlexey Shchepin <alexey@process-one.net>2022-03-14 15:41:01 +0300
commit5506b838c803c33c6fd5b1af00d62482c4a75c60 (patch)
treef9a611c32d29527519bdf9493569f16cdd4fcdf9 /src/mod_mqtt_session.erl
parentmix.exs: Don't allow admins to override boot flags (diff)
Add TLS certificate authentication for MQTT connections
Diffstat (limited to 'src/mod_mqtt_session.erl')
-rw-r--r--src/mod_mqtt_session.erl100
1 files changed, 87 insertions, 13 deletions
diff --git a/src/mod_mqtt_session.erl b/src/mod_mqtt_session.erl
index 8ce04066..6a551f00 100644
--- a/src/mod_mqtt_session.erl
+++ b/src/mod_mqtt_session.erl
@@ -29,6 +29,7 @@
-include("logger.hrl").
-include("mqtt.hrl").
-include_lib("xmpp/include/xmpp.hrl").
+-include_lib("public_key/include/public_key.hrl").
-record(state, {vsn = ?VSN :: integer(),
version :: undefined | mqtt_version(),
@@ -47,7 +48,8 @@
in_flight :: undefined | publish() | pubrel(),
codec :: mqtt_codec:state(),
queue :: undefined | p1_queue:queue(publish()),
- tls :: boolean()}).
+ tls :: boolean(),
+ tls_verify :: boolean()}).
-type acks() :: #{non_neg_integer() => pubrec()}.
-type subscriptions() :: #{binary() => {sub_opts(), non_neg_integer()}}.
@@ -162,6 +164,7 @@ init([SockMod, Socket, ListenOpts]) ->
State1 = #state{socket = {SockMod, Socket},
id = p1_rand:uniform(65535),
tls = proplists:get_bool(tls, ListenOpts),
+ tls_verify = proplists:get_bool(tls_verify, ListenOpts),
codec = mqtt_codec:new(MaxSize)},
Timeout = timer:seconds(30),
State2 = set_timeout(State1, Timeout),
@@ -553,7 +556,7 @@ unregister_session(_, _) ->
{error, state(), error_reason()}.
handle_connect(#connect{clean_start = CleanStart} = Pkt,
#state{jid = undefined, peername = IP} = State) ->
- case authenticate(Pkt, IP) of
+ case authenticate(Pkt, IP, State) of
{ok, JID} ->
case validate_will(Pkt, JID) of
ok ->
@@ -939,7 +942,12 @@ check_sock_result({_, Sock}, {error, Why}) ->
starttls(#state{socket = {gen_tcp, Socket}, tls = true}) ->
case ejabberd_pkix:get_certfile() of
{ok, Cert} ->
- case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of
+ CAFileOpt =
+ case ejabberd_option:c2s_cafile(ejabberd_config:get_myname()) of
+ undefined -> [];
+ CAFile -> [{cafile, CAFile}]
+ end,
+ case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}] ++ CAFileOpt) of
{ok, TLSSock} ->
{ok, {fast_tls, TLSSock}};
{error, Why} ->
@@ -1172,9 +1180,9 @@ parse_credentials(JID, ClientID) ->
end
end.
--spec authenticate(connect(), peername()) -> {ok, jid:jid()} | {error, reason_code()}.
-authenticate(Pkt, IP) ->
- case authenticate(Pkt) of
+-spec authenticate(connect(), peername(), state()) -> {ok, jid:jid()} | {error, reason_code()}.
+authenticate(Pkt, IP, State) ->
+ case authenticate(Pkt, State) of
{ok, JID, AuthModule} ->
?INFO_MSG("Accepted MQTT authentication for ~ts by ~s backend from ~s",
[jid:encode(JID),
@@ -1185,8 +1193,8 @@ authenticate(Pkt, IP) ->
Err
end.
--spec authenticate(connect()) -> {ok, jid:jid(), module()} | {error, reason_code()}.
-authenticate(#connect{password = Pass, properties = Props} = Pkt) ->
+-spec authenticate(connect(), state()) -> {ok, jid:jid(), module()} | {error, reason_code()}.
+authenticate(#connect{password = Pass, properties = Props} = Pkt, State) ->
case parse_credentials(Pkt) of
{ok, #jid{luser = LUser, lserver = LServer} = JID} ->
case maps:find(authentication_method, Props) of
@@ -1200,16 +1208,82 @@ authenticate(#connect{password = Pass, properties = Props} = Pkt) ->
{ok, _} ->
{error, 'bad-authentication-method'};
error ->
- case ejabberd_auth:check_password_with_authmodule(
- LUser, <<>>, LServer, Pass) of
- {true, AuthModule} -> {ok, JID, AuthModule};
- false -> {error, 'not-authorized'}
- end
+ case Pass of
+ <<>> ->
+ case tls_auth(JID, State) of
+ true ->
+ {ok, JID, pkix};
+ false ->
+ {error, 'not-authorized'}
+ end;
+ _ ->
+ case ejabberd_auth:check_password_with_authmodule(
+ LUser, <<>>, LServer, Pass) of
+ {true, AuthModule} -> {ok, JID, AuthModule};
+ false -> {error, 'not-authorized'}
+ end
+ end
end;
{error, _} = Err ->
Err
end.
+-spec tls_auth(jid:jid(), state()) -> boolean().
+tls_auth(_JID, #state{tls_verify = false}) ->
+ false;
+tls_auth(JID, State) ->
+ case State#state.socket of
+ {fast_tls, Sock} ->
+ case fast_tls:get_peer_certificate(Sock, otp) of
+ {ok, Cert} ->
+ case fast_tls:get_verify_result(Sock) of
+ 0 ->
+ case get_cert_jid(Cert) of
+ {ok, JID2} ->
+ jid:remove_resource(jid:tolower(JID)) ==
+ jid:remove_resource(jid:tolower(JID2));
+ error ->
+ false
+ end;
+ VerifyRes ->
+ Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert),
+ ?WARNING_MSG("TLS verify failed: ~s", [Reason]),
+ false
+ end;
+ error ->
+ false
+ end;
+ _ ->
+ false
+ end.
+
+get_cert_jid(Cert) ->
+ case Cert#'OTPCertificate'.tbsCertificate#'OTPTBSCertificate'.subject of
+ {rdnSequence, Attrs1} ->
+ Attrs = lists:flatten(Attrs1),
+ case lists:keyfind(?'id-at-commonName',
+ #'AttributeTypeAndValue'.type, Attrs) of
+ #'AttributeTypeAndValue'{value = {utf8String, Val}} ->
+ try jid:decode(Val) of
+ #jid{luser = <<>>} ->
+ case jid:make(Val, ejabberd_config:get_myname()) of
+ error ->
+ error;
+ JID ->
+ {ok, JID}
+ end;
+ JID ->
+ {ok, JID}
+ catch _:{bad_jid, _} ->
+ error
+ end;
+ _ ->
+ error
+ end;
+ _ ->
+ error
+ end.
+
%%%===================================================================
%%% Validators
%%%===================================================================