aboutsummaryrefslogtreecommitdiff
path: root/src/mod_mqtt_session.erl
diff options
context:
space:
mode:
authorEvgeny Khramtsov <ekhramtsov@process-one.net>2019-02-25 11:42:09 +0300
committerEvgeny Khramtsov <ekhramtsov@process-one.net>2019-02-25 11:42:09 +0300
commita3df791373c30ccc79a6082f4c910a378d726cdc (patch)
treee7efece6aaaec749f0291f0845abdd4a75b7a059 /src/mod_mqtt_session.erl
parentmod_muc_admin: Fix indentation (diff)
Add MQTT support
Diffstat (limited to 'src/mod_mqtt_session.erl')
-rw-r--r--src/mod_mqtt_session.erl1318
1 files changed, 1318 insertions, 0 deletions
diff --git a/src/mod_mqtt_session.erl b/src/mod_mqtt_session.erl
new file mode 100644
index 000000000..3df36b8fb
--- /dev/null
+++ b/src/mod_mqtt_session.erl
@@ -0,0 +1,1318 @@
+%%%-------------------------------------------------------------------
+%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%% @copyright (C) 2002-2019 ProcessOne, SARL. All Rights Reserved.
+%%%
+%%% Licensed under the Apache License, Version 2.0 (the "License");
+%%% you may not use this file except in compliance with the License.
+%%% You may obtain a copy of the License at
+%%%
+%%% http://www.apache.org/licenses/LICENSE-2.0
+%%%
+%%% Unless required by applicable law or agreed to in writing, software
+%%% distributed under the License is distributed on an "AS IS" BASIS,
+%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%%% See the License for the specific language governing permissions and
+%%% limitations under the License.
+%%%
+%%%-------------------------------------------------------------------
+-module(mod_mqtt_session).
+-behaviour(p1_server).
+-define(VSN, 1).
+-vsn(?VSN).
+
+%% API
+-export([start/3, start_link/3, accept/1, route/2]).
+%% gen_server callbacks
+-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
+ terminate/2, code_change/3]).
+
+-include("logger.hrl").
+-include("mqtt.hrl").
+-include("xmpp.hrl").
+
+-record(state, {vsn = ?VSN :: integer(),
+ version :: undefined | mqtt_version(),
+ socket :: undefined | socket(),
+ peername :: peername(),
+ timeout = infinity :: timer(),
+ jid :: undefined | jid:jid(),
+ session_expiry = 0 :: seconds(),
+ will :: undefined | publish(),
+ will_delay = 0 :: seconds(),
+ stop_reason :: undefined | error_reason(),
+ acks = #{} :: map(),
+ subscriptions = #{} :: map(),
+ topic_aliases = #{} :: map(),
+ id = 0 :: non_neg_integer(),
+ in_flight :: undefined | publish() | pubrel(),
+ codec :: mqtt_codec:state(),
+ queue :: undefined | p1_queue:queue()}).
+
+-type error_reason() :: {auth, reason_code()} |
+ {code, reason_code()} |
+ {peer_disconnected, reason_code(), binary()} |
+ {socket, socket_error_reason()} |
+ {codec, mqtt_codec:error_reason()} |
+ {unexpected_packet, atom()} |
+ {tls, inet:posix() | atom() | binary()} |
+ {replaced, pid()} | {resumed, pid()} |
+ subscribe_forbidden | publish_forbidden |
+ will_topic_forbidden | internal_server_error |
+ session_expired | idle_connection |
+ queue_full | shutdown | db_failure |
+ {payload_format_invalid, will | publish} |
+ session_expiry_non_zero | unknown_topic_alias.
+
+-type state() :: #state{}.
+-type sockmod() :: gen_tcp | fast_tls.
+-type socket() :: {sockmod(), inet:socket() | fast_tls:tls_socket()}.
+-type peername() :: {inet:ip_address(), inet:port_number()}.
+-type seconds() :: non_neg_integer().
+-type milli_seconds() :: non_neg_integer().
+-type timer() :: infinity | {milli_seconds(), integer()}.
+-type socket_error_reason() :: closed | timeout | inet:posix().
+
+-define(CALL_TIMEOUT, timer:minutes(1)).
+-define(RELAY_TIMEOUT, timer:minutes(1)).
+-define(MAX_UINT32, 4294967295).
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+start(SockMod, Socket, ListenOpts) ->
+ p1_server:start(?MODULE, [SockMod, Socket, ListenOpts],
+ ejabberd_config:fsm_limit_opts(ListenOpts)).
+
+start_link(SockMod, Socket, ListenOpts) ->
+ p1_server:start_link(?MODULE, [SockMod, Socket, ListenOpts],
+ ejabberd_config:fsm_limit_opts(ListenOpts)).
+
+-spec accept(pid()) -> ok.
+accept(Pid) ->
+ p1_server:cast(Pid, accept).
+
+-spec route(pid(), term()) -> boolean().
+route(Pid, Term) ->
+ ejabberd_cluster:send(Pid, Term).
+
+-spec format_error(error_reason()) -> string().
+format_error(session_expired) ->
+ "Disconnected session is expired";
+format_error(idle_connection) ->
+ "Idle connection";
+format_error(queue_full) ->
+ "Message queue is overloaded";
+format_error(internal_server_error) ->
+ "Internal server error";
+format_error(db_failure) ->
+ "Database failure";
+format_error(shutdown) ->
+ "System shutting down";
+format_error(subscribe_forbidden) ->
+ "Subscribing to this topic is forbidden by service policy";
+format_error(publish_forbidden) ->
+ "Publishing to this topic is forbidden by service policy";
+format_error(will_topic_forbidden) ->
+ "Publishing to this will topic is forbidden by service policy";
+format_error(session_expiry_non_zero) ->
+ "Session Expiry Interval in DISCONNECT packet should have been zero";
+format_error(unknown_topic_alias) ->
+ "No mapping found for this Topic Alias";
+format_error({payload_format_invalid, will}) ->
+ "Will payload format doesn't match its indicator";
+format_error({payload_format_invalid, publish}) ->
+ "PUBLISH payload format doesn't match its indicator";
+format_error({peer_disconnected, Code, <<>>}) ->
+ format("Peer disconnected with reason: ~s",
+ [mqtt_codec:format_reason_code(Code)]);
+format_error({peer_disconnected, Code, Reason}) ->
+ format("Peer disconnected with reason: ~s (~s)", [Reason, Code]);
+format_error({replaced, Pid}) ->
+ format("Replaced by ~p at ~s", [Pid, node(Pid)]);
+format_error({resumed, Pid}) ->
+ format("Resumed by ~p at ~s", [Pid, node(Pid)]);
+format_error({unexpected_packet, Name}) ->
+ format("Unexpected ~s packet", [string:to_upper(atom_to_list(Name))]);
+format_error({tls, Reason}) ->
+ format("TLS failed: ~s", [format_tls_error(Reason)]);
+format_error({socket, A}) ->
+ format("Connection failed: ~s", [format_inet_error(A)]);
+format_error({code, Code}) ->
+ format("Protocol error: ~s", [mqtt_codec:format_reason_code(Code)]);
+format_error({auth, Code}) ->
+ format("Authentication failed: ~s", [mqtt_codec:format_reason_code(Code)]);
+format_error({codec, CodecError}) ->
+ format("Protocol error: ~s", [mqtt_codec:format_error(CodecError)]);
+format_error(A) when is_atom(A) ->
+ atom_to_list(A);
+format_error(Reason) ->
+ format("Unrecognized error: ~w", [Reason]).
+
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+init([SockMod, Socket, ListenOpts]) ->
+ MaxSize = proplists:get_value(max_payload_size, ListenOpts, infinity),
+ SockMod1 = case proplists:get_bool(tls, ListenOpts) of
+ true -> fast_tls;
+ false -> SockMod
+ end,
+ State1 = #state{socket = {SockMod1, Socket},
+ id = p1_rand:uniform(65535),
+ codec = mqtt_codec:new(MaxSize)},
+ Timeout = timer:seconds(30),
+ State2 = set_timeout(State1, Timeout),
+ {ok, State2, Timeout}.
+
+handle_call({get_state, _}, From, #state{stop_reason = {resumed, Pid}} = State) ->
+ p1_server:reply(From, {error, {resumed, Pid}}),
+ noreply(State);
+handle_call({get_state, Pid}, From, State) ->
+ case stop(State, {resumed, Pid}) of
+ {stop, Status, State1} ->
+ {stop, Status, State1#state{stop_reason = {replaced, Pid}}};
+ {noreply, State1, _} ->
+ ?DEBUG("Transfering MQTT session state to ~p at ~s", [Pid, node(Pid)]),
+ Q1 = p1_queue:file_to_ram(State1#state.queue),
+ p1_server:reply(From, {ok, State1#state{queue = Q1}}),
+ SessionExpiry = timer:seconds(State1#state.session_expiry),
+ State2 = set_timeout(State1, min(SessionExpiry, ?RELAY_TIMEOUT)),
+ State3 = State2#state{queue = undefined,
+ stop_reason = {resumed, Pid},
+ acks = #{},
+ will = undefined,
+ session_expiry = 0,
+ topic_aliases = #{},
+ subscriptions = #{}},
+ noreply(State3)
+ end;
+handle_call(Request, From, State) ->
+ ?WARNING_MSG("Got unexpected call from ~p: ~p", [From, Request]),
+ noreply(State).
+
+handle_cast(accept, #state{socket = {_, TCPSock} = Socket} = State) ->
+ case inet:peername(TCPSock) of
+ {ok, IPPort} ->
+ State1 = State#state{peername = IPPort},
+ case starttls(Socket) of
+ {ok, Socket1} ->
+ State2 = State1#state{socket = Socket1},
+ handle_info({tcp, TCPSock, <<>>}, State2);
+ {error, Why} ->
+ stop(State1, Why)
+ end;
+ {error, Why} ->
+ stop(State, {socket, Why})
+ end;
+handle_cast(Msg, State) ->
+ ?WARNING_MSG("Got unexpected cast: ~p", [Msg]),
+ noreply(State).
+
+handle_info(Msg, #state{stop_reason = {resumed, Pid} = Reason} = State) ->
+ case Msg of
+ {#publish{}, _} ->
+ ?DEBUG("Relaying delayed publish to ~p at ~s", [Pid, node(Pid)]),
+ ejabberd_cluster:send(Pid, Msg),
+ noreply(State);
+ timeout ->
+ stop(State, Reason);
+ _ ->
+ noreply(State)
+ end;
+handle_info({#publish{meta = Meta} = Pkt, ExpiryTime}, State) ->
+ ID = next_id(State#state.id),
+ Meta1 = Meta#{expiry_time => ExpiryTime},
+ Pkt1 = Pkt#publish{id = ID, meta = Meta1},
+ State1 = State#state{id = ID},
+ case send(State1, Pkt1) of
+ {ok, State2} -> noreply(State2);
+ {error, State2, Reason} -> stop(State2, Reason)
+ end;
+handle_info({tcp, TCPSock, TCPData},
+ #state{codec = Codec, socket = Socket} = State) ->
+ case recv_data(Socket, TCPData) of
+ {ok, Data} ->
+ case mqtt_codec:decode(Codec, Data) of
+ {ok, Pkt, Codec1} ->
+ ?DEBUG("Got MQTT packet:~n~s", [pp(Pkt)]),
+ State1 = State#state{codec = Codec1},
+ case handle_packet(Pkt, State1) of
+ {ok, State2} ->
+ handle_info({tcp, TCPSock, <<>>}, State2);
+ {error, State2, Reason} ->
+ stop(State2, Reason)
+ end;
+ {more, Codec1} ->
+ State1 = State#state{codec = Codec1},
+ State2 = reset_keep_alive(State1),
+ activate(Socket),
+ noreply(State2);
+ {error, Why} ->
+ stop(State, {codec, Why})
+ end;
+ {error, Why} ->
+ stop(State, Why)
+ end;
+handle_info({tcp_closed, _Sock}, State) ->
+ ?DEBUG("MQTT connection reset by peer", []),
+ stop(State, {socket, closed});
+handle_info({tcp_error, _Sock, Reason}, State) ->
+ ?DEBUG("MQTT connection error: ~s", [format_inet_error(Reason)]),
+ stop(State, {socket, Reason});
+handle_info(timeout, #state{socket = Socket} = State) ->
+ case Socket of
+ undefined ->
+ ?DEBUG("MQTT session expired", []),
+ stop(State#state{session_expiry = 0}, session_expired);
+ _ ->
+ ?DEBUG("MQTT connection timed out", []),
+ stop(State, idle_connection)
+ end;
+handle_info({replaced, Pid}, State) ->
+ stop(State#state{session_expiry = 0}, {replaced, Pid});
+handle_info({timeout, _TRef, publish_will}, State) ->
+ noreply(publish_will(State));
+handle_info({Ref, badarg}, State) when is_reference(Ref) ->
+ %% TODO: figure out from where this messages comes from
+ noreply(State);
+handle_info(Info, State) ->
+ ?WARNING_MSG("Got unexpected info: ~p", [Info]),
+ noreply(State).
+
+-spec handle_packet(mqtt_packet(), state()) -> {ok, state()} |
+ {error, state(), error_reason()}.
+handle_packet(#connect{proto_level = Version} = Pkt, State) ->
+ handle_connect(Pkt, State#state{version = Version});
+handle_packet(#publish{} = Pkt, State) ->
+ handle_publish(Pkt, State);
+handle_packet(#puback{id = ID}, #state{in_flight = #publish{qos = 1, id = ID}} = State) ->
+ resend(State#state{in_flight = undefined});
+handle_packet(#puback{id = ID, code = Code}, State) ->
+ ?DEBUG("Ignoring unexpected PUBACK with id=~B and code '~s'", [ID, Code]),
+ {ok, State};
+handle_packet(#pubrec{id = ID, code = Code},
+ #state{in_flight = #publish{qos = 2, id = ID}} = State) ->
+ case mqtt_codec:is_error_code(Code) of
+ true ->
+ ?DEBUG("Got PUBREC with error code '~s', "
+ "aborting acknowledgement", [Code]),
+ resend(State#state{in_flight = undefined});
+ false ->
+ Pubrel = #pubrel{id = ID},
+ send(State#state{in_flight = Pubrel}, Pubrel)
+ end;
+handle_packet(#pubrec{id = ID, code = Code}, State) ->
+ case mqtt_codec:is_error_code(Code) of
+ true ->
+ ?DEBUG("Ignoring unexpected PUBREC with id=~B and code '~s'",
+ [ID, Code]),
+ {ok, State};
+ false ->
+ Code1 = 'packet-identifier-not-found',
+ ?DEBUG("Got unexpected PUBREC with id=~B, "
+ "sending PUBREL with error code '~s'", [ID, Code1]),
+ send(State, #pubrel{id = ID, code = Code1})
+ end;
+handle_packet(#pubcomp{id = ID}, #state{in_flight = #pubrel{id = ID}} = State) ->
+ resend(State#state{in_flight = undefined});
+handle_packet(#pubcomp{id = ID}, State) ->
+ ?DEBUG("Ignoring unexpected PUBCOMP with id=~B: most likely "
+ "it's a repeated response to duplicated PUBREL", [ID]),
+ {ok, State};
+handle_packet(#pubrel{id = ID}, State) ->
+ case maps:take(ID, State#state.acks) of
+ {_, Acks} ->
+ send(State#state{acks = Acks}, #pubcomp{id = ID});
+ error ->
+ Code = 'packet-identifier-not-found',
+ ?DEBUG("Got unexpected PUBREL with id=~B, "
+ "sending PUBCOMP with error code '~s'", [ID, Code]),
+ Pubcomp = #pubcomp{id = ID, code = Code},
+ send(State, Pubcomp)
+ end;
+handle_packet(#subscribe{} = Pkt, State) ->
+ handle_subscribe(Pkt, State);
+handle_packet(#unsubscribe{} = Pkt, State) ->
+ handle_unsubscribe(Pkt, State);
+handle_packet(#pingreq{}, State) ->
+ send(State, #pingresp{});
+handle_packet(#disconnect{properties = #{session_expiry_interval := SE}},
+ #state{session_expiry = 0} = State) when SE>0 ->
+ %% Protocol violation
+ {error, State, session_expiry_non_zero};
+handle_packet(#disconnect{code = Code, properties = Props},
+ #state{jid = #jid{lserver = Server}} = State) ->
+ Reason = maps:get(reason_string, Props, <<>>),
+ Expiry = case maps:get(session_expiry_interval, Props, undefined) of
+ undefined -> State#state.session_expiry;
+ SE -> min(SE, session_expiry(Server))
+ end,
+ State1 = State#state{session_expiry = Expiry},
+ State2 = case Code of
+ 'normal-disconnection' -> State1#state{will = undefined};
+ _ -> State1
+ end,
+ {error, State2, {peer_disconnected, Code, Reason}};
+handle_packet(Pkt, State) ->
+ ?WARNING_MSG("Unexpected packet:~n~s~n** when state:~n~s",
+ [pp(Pkt), pp(State)]),
+ {error, State, {unexpected_packet, element(1, Pkt)}}.
+
+terminate(_, #state{peername = undefined}) ->
+ ok;
+terminate(Reason, State) ->
+ Reason1 = case Reason of
+ shutdown -> shutdown;
+ {shutdown, _} -> shutdown;
+ normal -> State#state.stop_reason;
+ {process_limit, _} -> queue_full;
+ _ -> internal_server_error
+ end,
+ case State#state.jid of
+ #jid{} -> unregister_session(State, Reason1);
+ undefined -> log_disconnection(State, Reason1)
+ end,
+ State1 = disconnect(State, Reason1),
+ publish_will(State1).
+
+code_change(_OldVsn, State, _Extra) ->
+ {ok, upgrade_state(State)}.
+
+%%%===================================================================
+%%% State transitions
+%%%===================================================================
+-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
+noreply(#state{timeout = infinity} = State) ->
+ {noreply, State, infinity};
+noreply(#state{timeout = {MSecs, StartTime}} = State) ->
+ CurrentTime = current_time(),
+ Timeout = max(0, MSecs - CurrentTime + StartTime),
+ {noreply, State, Timeout}.
+
+-spec stop(state(), error_reason()) -> {noreply, state(), infinity} |
+ {stop, normal, state()}.
+stop(#state{session_expiry = 0} = State, Reason) ->
+ {stop, normal, State#state{stop_reason = Reason}};
+stop(#state{session_expiry = SessExp} = State, Reason) ->
+ case State#state.socket of
+ undefined ->
+ noreply(State);
+ _ ->
+ WillDelay = State#state.will_delay,
+ log_disconnection(State, Reason),
+ State1 = disconnect(State, Reason),
+ State2 = if WillDelay == 0 ->
+ publish_will(State1);
+ WillDelay < SessExp ->
+ erlang:start_timer(
+ timer:seconds(WillDelay), self(), publish_will),
+ State1;
+ true ->
+ State1
+ end,
+ State3 = set_timeout(State2, timer:seconds(SessExp)),
+ State4 = State3#state{stop_reason = Reason},
+ noreply(State4)
+ end.
+
+-spec upgrade_state(term()) -> state().
+upgrade_state(State) ->
+ %% Here will be the code upgrading state between different
+ %% code versions. This is needed when doing session resumption from
+ %% remote node running the version of the code with incompatible #state{}
+ %% record fields. Also used by code_change/3 callback.
+ %% Use element(2, State) for vsn comparison.
+ State.
+
+%%%===================================================================
+%%% Session management
+%%%===================================================================
+-spec open_session(state(), jid(), boolean()) -> {ok, boolean(), state()} |
+ {error, state(), error_reason()}.
+open_session(State, JID, _CleanStart = false) ->
+ USR = {_, S, _} = jid:tolower(JID),
+ case mod_mqtt:lookup_session(USR) of
+ {ok, Pid} ->
+ try p1_server:call(Pid, {get_state, self()}, ?CALL_TIMEOUT) of
+ {ok, State1} ->
+ State2 = upgrade_state(State1),
+ Q1 = case queue_type(S) of
+ ram -> State2#state.queue;
+ _ -> p1_queue:ram_to_file(State2#state.queue)
+ end,
+ Q2 = p1_queue:set_limit(Q1, queue_limit(S)),
+ State3 = State#state{queue = Q2,
+ acks = State2#state.acks,
+ subscriptions = State2#state.subscriptions,
+ id = State2#state.id,
+ in_flight = State2#state.in_flight},
+ ?DEBUG("Resumed state from ~p at ~s:~n~s",
+ [Pid, node(Pid), pp(State3)]),
+ register_session(State3, JID, Pid);
+ {error, Why} ->
+ {error, State, Why}
+ catch exit:{Why, {p1_server, _, _}} ->
+ ?WARNING_MSG("Failed to copy session state from ~p at ~s: ~s",
+ [Pid, node(Pid), format_exit_reason(Why)]),
+ register_session(State, JID, undefined)
+ end;
+ {error, notfound} ->
+ register_session(State, JID, undefined);
+ {error, Why} ->
+ {error, State, Why}
+ end;
+open_session(State, JID, _CleanStart = true) ->
+ register_session(State, JID, undefined).
+
+-spec register_session(state(), jid(), undefined | pid()) ->
+ {ok, boolean(), state()} | {error, state(), error_reason()}.
+register_session(#state{peername = IP} = State, JID, Parent) ->
+ USR = {_, S, _} = jid:tolower(JID),
+ case mod_mqtt:open_session(USR) of
+ ok ->
+ case resubscribe(USR, State#state.subscriptions) of
+ ok ->
+ ?INFO_MSG("~s for ~s from ~s",
+ [if is_pid(Parent) ->
+ io_lib:format(
+ "Reopened MQTT session via ~p",
+ [Parent]);
+ true ->
+ "Opened MQTT session"
+ end,
+ jid:encode(JID),
+ ejabberd_config:may_hide_data(
+ misc:ip_to_list(IP))]),
+ Q = case State#state.queue of
+ undefined ->
+ p1_queue:new(queue_type(S), queue_limit(S));
+ Q1 ->
+ Q1
+ end,
+ {ok, is_pid(Parent), State#state{jid = JID, queue = Q}};
+ {error, Why} ->
+ mod_mqtt:close_session(USR),
+ {error, State#state{session_expiry = 0}, Why}
+ end;
+ {error, Reason} ->
+ ?ERROR_MSG("Failed to register MQTT session for ~s from ~s: ~s",
+ err_args(JID, IP, Reason)),
+ {error, State, Reason}
+ end.
+
+-spec unregister_session(state(), error_reason()) -> ok.
+unregister_session(#state{jid = #jid{} = JID, peername = IP} = State, Reason) ->
+ Msg = "Closing MQTT session for ~s from ~s: ~s",
+ case Reason of
+ {Tag, _} when Tag == replaced; Tag == resumed ->
+ ?DEBUG(Msg, err_args(JID, IP, Reason));
+ {socket, _} ->
+ ?INFO_MSG(Msg, err_args(JID, IP, Reason));
+ Tag when Tag == idle_connection; Tag == session_expired; Tag == shutdown ->
+ ?INFO_MSG(Msg, err_args(JID, IP, Reason));
+ {peer_disconnected, Code, _} ->
+ case mqtt_codec:is_error_code(Code) of
+ true -> ?WARNING_MSG(Msg, err_args(JID, IP, Reason));
+ false -> ?INFO_MSG(Msg, err_args(JID, IP, Reason))
+ end;
+ _ ->
+ ?WARNING_MSG(Msg, err_args(JID, IP, Reason))
+ end,
+ USR = jid:tolower(JID),
+ unsubscribe(maps:keys(State#state.subscriptions), USR, #{}),
+ case mod_mqtt:close_session(USR) of
+ ok -> ok;
+ {error, Why} ->
+ ?ERROR_MSG(
+ "Failed to close MQTT session for ~s from ~s: ~s",
+ err_args(JID, IP, Why))
+ end;
+unregister_session(_, _) ->
+ ok.
+
+%%%===================================================================
+%%% CONNECT/PUBLISH/SUBSCRIBE/UNSUBSCRIBE handlers
+%%%===================================================================
+-spec handle_connect(connect(), state()) -> {ok, state()} |
+ {error, state(), error_reason()}.
+handle_connect(#connect{clean_start = CleanStart} = Pkt,
+ #state{jid = undefined, peername = IP} = State) ->
+ case authenticate(Pkt, IP) of
+ {ok, JID} ->
+ case validate_will(Pkt, JID) of
+ ok ->
+ case open_session(State, JID, CleanStart) of
+ {ok, SessionPresent, State1} ->
+ State2 = set_session_properties(State1, Pkt),
+ ConnackProps = get_connack_properties(State2, Pkt),
+ Connack = #connack{session_present = SessionPresent,
+ properties = ConnackProps},
+ case send(State2, Connack) of
+ {ok, State3} -> resend(State3);
+ {error, _, _} = Err -> Err
+ end;
+ {error, _, _} = Err ->
+ Err
+ end;
+ {error, Reason} ->
+ {error, State, Reason}
+ end;
+ {error, Code} ->
+ {error, State, {auth, Code}}
+ end.
+
+-spec handle_publish(publish(), state()) -> {ok, state()} |
+ {error, state(), error_reason()}.
+handle_publish(#publish{qos = QoS, id = ID} = Publish, State) ->
+ case QoS == 2 andalso maps:is_key(ID, State#state.acks) of
+ true ->
+ send(State, maps:get(ID, State#state.acks));
+ false ->
+ case validate_publish(Publish, State) of
+ ok ->
+ State1 = store_topic_alias(State, Publish),
+ Ret = publish(State1, Publish),
+ {Code, Props} = get_publish_code_props(Ret),
+ case Ret of
+ {ok, _} when QoS == 2 ->
+ Pkt = #pubrec{id = ID, code = Code,
+ properties = Props},
+ Acks = maps:put(ID, Pkt, State1#state.acks),
+ State2 = State1#state{acks = Acks},
+ send(State2, Pkt);
+ {error, _} when QoS == 2 ->
+ Pkt = #pubrec{id = ID, code = Code,
+ properties = Props},
+ send(State1, Pkt);
+ _ when QoS == 1 ->
+ Pkt = #puback{id = ID, code = Code,
+ properties = Props},
+ send(State1, Pkt);
+ _ ->
+ {ok, State1}
+ end;
+ {error, Why} ->
+ {error, State, Why}
+ end
+ end.
+
+-spec handle_subscribe(subscribe(), state()) ->
+ {ok, state()} | {error, state(), error_reason()}.
+handle_subscribe(#subscribe{id = ID, filters = TopicFilters} = Pkt, State) ->
+ case validate_subscribe(Pkt) of
+ ok ->
+ USR = jid:tolower(State#state.jid),
+ SubID = maps:get(subscription_identifier, Pkt#subscribe.properties, 0),
+ OldSubs = State#state.subscriptions,
+ {Codes, NewSubs, Props} = subscribe(TopicFilters, USR, SubID),
+ Subs = maps:merge(OldSubs, NewSubs),
+ State1 = State#state{subscriptions = Subs},
+ Suback = #suback{id = ID, codes = Codes, properties = Props},
+ case send(State1, Suback) of
+ {ok, State2} ->
+ Pubs = select_retained(USR, NewSubs, OldSubs),
+ send_retained(State2, Pubs);
+ {error, _, _} = Err ->
+ Err
+ end;
+ {error, Why} ->
+ {error, State, Why}
+ end.
+
+-spec handle_unsubscribe(unsubscribe(), state()) ->
+ {ok, state()} | {error, state(), error_reason()}.
+handle_unsubscribe(#unsubscribe{id = ID, filters = TopicFilters}, State) ->
+ USR = jid:tolower(State#state.jid),
+ {Codes, Subs, Props} = unsubscribe(TopicFilters, USR, State#state.subscriptions),
+ State1 = State#state{subscriptions = Subs},
+ Unsuback = #unsuback{id = ID, codes = Codes, properties = Props},
+ send(State1, Unsuback).
+
+%%%===================================================================
+%%% Aux functions for CONNECT/PUBLISH/SUBSCRIBE/UNSUBSCRIBE handlers
+%%%===================================================================
+-spec set_session_properties(state(), connect()) -> state().
+set_session_properties(#state{version = Version,
+ jid = #jid{lserver = Server}} = State,
+ #connect{clean_start = CleanStart,
+ keep_alive = KeepAlive,
+ properties = Props} = Pkt) ->
+ SEMin = case CleanStart of
+ false when Version == ?MQTT_VERSION_4 -> infinity;
+ _ -> maps:get(session_expiry_interval, Props, 0)
+ end,
+ SEConfig = session_expiry(Server),
+ State1 = State#state{session_expiry = min(SEMin, SEConfig)},
+ State2 = set_will_properties(State1, Pkt),
+ set_keep_alive(State2, KeepAlive).
+
+-spec set_will_properties(state(), connect()) -> state().
+set_will_properties(State, #connect{will = #publish{} = Will,
+ will_properties = Props}) ->
+ {WillDelay, Props1} = case maps:take(will_delay_interval, Props) of
+ error -> {0, Props};
+ Ret -> Ret
+ end,
+ State#state{will = Will#publish{properties = Props1},
+ will_delay = WillDelay};
+set_will_properties(State, _) ->
+ State.
+
+-spec get_connack_properties(state(), connect()) -> properties().
+get_connack_properties(#state{session_expiry = SessExp, jid = JID},
+ #connect{client_id = ClientID,
+ keep_alive = KeepAlive}) ->
+ Props1 = case ClientID of
+ <<>> -> #{assigned_client_identifier => JID#jid.lresource};
+ _ -> #{}
+ end,
+ Props1#{session_expiry_interval => SessExp,
+ shared_subscription_available => false,
+ topic_alias_maximum => topic_alias_maximum(JID#jid.lserver),
+ server_keep_alive => KeepAlive}.
+
+-spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer()) ->
+ {[reason_code()], map(), properties()}.
+subscribe(TopicFilters, USR, SubID) ->
+ subscribe(TopicFilters, USR, SubID, [], #{}, ok).
+
+-spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer(),
+ [reason_code()], map(), ok | {error, error_reason()}) ->
+ {[reason_code()], map(), properties()}.
+subscribe([{TopicFilter, SubOpts}|TopicFilters], USR, SubID, Codes, Subs, Err) ->
+ case mod_mqtt:subscribe(USR, TopicFilter, SubOpts, SubID) of
+ ok ->
+ Code = subscribe_reason_code(SubOpts#sub_opts.qos),
+ subscribe(TopicFilters, USR, SubID, [Code|Codes],
+ maps:put(TopicFilter, {SubOpts, SubID}, Subs), Err);
+ {error, Why} = Err1 ->
+ Code = subscribe_reason_code(Why),
+ subscribe(TopicFilters, USR, SubID, [Code|Codes], Subs, Err1)
+ end;
+subscribe([], _USR, _SubID, Codes, Subs, Err) ->
+ Props = case Err of
+ ok -> #{};
+ {error, Why} ->
+ #{reason_string => format_reason_string(Why)}
+ end,
+ {lists:reverse(Codes), Subs, Props}.
+
+-spec unsubscribe([binary()], jid:ljid(), map()) ->
+ {[reason_code()], map(), properties()}.
+unsubscribe(TopicFilters, USR, Subs) ->
+ unsubscribe(TopicFilters, USR, [], Subs, ok).
+
+-spec unsubscribe([binary()], jid:ljid(),
+ [reason_code()], map(),
+ ok | {error, error_reason()}) ->
+ {[reason_code()], map(), properties()}.
+unsubscribe([TopicFilter|TopicFilters], USR, Codes, Subs, Err) ->
+ case mod_mqtt:unsubscribe(USR, TopicFilter) of
+ ok ->
+ unsubscribe(TopicFilters, USR, [success|Codes],
+ maps:remove(TopicFilter, Subs), Err);
+ {error, notfound} ->
+ unsubscribe(TopicFilters, USR,
+ ['no-subscription-existed'|Codes],
+ maps:remove(TopicFilter, Subs), Err);
+ {error, Why} = Err1 ->
+ Code = unsubscribe_reason_code(Why),
+ unsubscribe(TopicFilters, USR, [Code|Codes], Subs, Err1)
+ end;
+unsubscribe([], _USR, Codes, Subs, Err) ->
+ Props = case Err of
+ ok -> #{};
+ {error, Why} ->
+ #{reason_string => format_reason_string(Why)}
+ end,
+ {lists:reverse(Codes), Subs, Props}.
+
+-spec select_retained(jid:ljid(), map(), map()) -> [{publish(), seconds()}].
+select_retained(USR, NewSubs, OldSubs) ->
+ lists:flatten(
+ maps:fold(
+ fun(_Filter, {#sub_opts{retain_handling = 2}, _SubID}, Acc) ->
+ Acc;
+ (Filter, {#sub_opts{retain_handling = 1, qos = QoS}, SubID}, Acc) ->
+ case maps:is_key(Filter, OldSubs) of
+ true -> Acc;
+ false -> [mod_mqtt:select_retained(USR, Filter, QoS, SubID)|Acc]
+ end;
+ (Filter, {#sub_opts{qos = QoS}, SubID}, Acc) ->
+ [mod_mqtt:select_retained(USR, Filter, QoS, SubID)|Acc]
+ end, [], NewSubs)).
+
+-spec send_retained(state(), [{publish(), seconds()}]) ->
+ {ok, state()} | {error, state(), error_reason()}.
+send_retained(State, [{#publish{meta = Meta} = Pub, Expiry}|Pubs]) ->
+ I = next_id(State#state.id),
+ Meta1 = Meta#{expiry_time => Expiry},
+ Pub1 = Pub#publish{id = I, retain = true, meta = Meta1},
+ case send(State#state{id = I}, Pub1) of
+ {ok, State1} ->
+ send_retained(State1, Pubs);
+ Err ->
+ Err
+ end;
+send_retained(State, []) ->
+ {ok, State}.
+
+-spec publish(state(), publish()) -> {ok, non_neg_integer()} |
+ {error, error_reason()}.
+publish(State, #publish{topic = Topic, properties = Props} = Pkt) ->
+ MessageExpiry = maps:get(message_expiry_interval, Props, ?MAX_UINT32),
+ ExpiryTime = min(unix_time() + MessageExpiry, ?MAX_UINT32),
+ USR = jid:tolower(State#state.jid),
+ Props1 = maps:filter(
+ fun(payload_format_indicator, _) -> true;
+ (content_type, _) -> true;
+ (response_topic, _) -> true;
+ (correlation_data, _) -> true;
+ (user_property, _) -> true;
+ (_, _) -> false
+ end, Props),
+ Topic1 = case Topic of
+ <<>> ->
+ Alias = maps:get(topic_alias, Props),
+ maps:get(Alias, State#state.topic_aliases);
+ _ ->
+ Topic
+ end,
+ Pkt1 = Pkt#publish{topic = Topic1, properties = Props1},
+ mod_mqtt:publish(USR, Pkt1, ExpiryTime).
+
+-spec store_topic_alias(state(), publish()) -> state().
+store_topic_alias(State, #publish{topic = <<_, _/binary>> = Topic,
+ properties = #{topic_alias := Alias}}) ->
+ Aliases = maps:put(Alias, Topic, State#state.topic_aliases),
+ State#state{topic_aliases = Aliases};
+store_topic_alias(State, _) ->
+ State.
+
+%%%===================================================================
+%%% Socket management
+%%%===================================================================
+-spec send(state(), mqtt_packet()) -> {ok, state()} |
+ {error, state(), error_reason()}.
+send(State, #publish{} = Pkt) ->
+ case is_expired(Pkt) of
+ {false, Pkt1} ->
+ case State#state.in_flight == undefined andalso
+ p1_queue:is_empty(State#state.queue) of
+ true ->
+ Dup = case Pkt1#publish.qos of
+ 0 -> undefined;
+ _ -> Pkt1
+ end,
+ State1 = State#state{in_flight = Dup},
+ {ok, do_send(State1, Pkt1)};
+ false ->
+ ?DEBUG("Queueing packet:~n~s~n** when state:~n~s",
+ [pp(Pkt), pp(State)]),
+ try p1_queue:in(Pkt, State#state.queue) of
+ Q ->
+ State1 = State#state{queue = Q},
+ {ok, State1}
+ catch error:full ->
+ Q = p1_queue:clear(State#state.queue),
+ State1 = State#state{queue = Q, session_expiry = 0},
+ {error, State1, queue_full}
+ end
+ end;
+ true ->
+ {ok, State}
+ end;
+send(State, Pkt) ->
+ {ok, do_send(State, Pkt)}.
+
+-spec resend(state()) -> {ok, state()} | {error, state(), error_reason()}.
+resend(#state{in_flight = undefined} = State) ->
+ case p1_queue:out(State#state.queue) of
+ {{value, #publish{qos = QoS} = Pkt}, Q} ->
+ case is_expired(Pkt) of
+ true ->
+ resend(State#state{queue = Q});
+ {false, Pkt1} when QoS > 0 ->
+ State1 = State#state{in_flight = Pkt1, queue = Q},
+ {ok, do_send(State1, Pkt1)};
+ {false, Pkt1} ->
+ State1 = do_send(State#state{queue = Q}, Pkt1),
+ resend(State1)
+ end;
+ {empty, _} ->
+ {ok, State}
+ end;
+resend(#state{in_flight = Pkt} = State) ->
+ {ok, do_send(State, set_dup_flag(Pkt))}.
+
+-spec do_send(state(), mqtt_packet()) -> state().
+do_send(#state{socket = {SockMod, Sock} = Socket} = State, Pkt) ->
+ ?DEBUG("Send MQTT packet:~n~s", [pp(Pkt)]),
+ Data = mqtt_codec:encode(State#state.version, Pkt),
+ Res = SockMod:send(Sock, Data),
+ check_sock_result(Socket, Res),
+ State;
+do_send(State, _Pkt) ->
+ State.
+
+-spec activate(socket()) -> ok.
+activate({SockMod, Sock} = Socket) ->
+ Res = case SockMod of
+ gen_tcp -> inet:setopts(Sock, [{active, once}]);
+ _ -> SockMod:setopts(Sock, [{active, once}])
+ end,
+ check_sock_result(Socket, Res).
+
+-spec disconnect(state(), error_reason()) -> state().
+disconnect(#state{socket = {SockMod, Sock}} = State, Err) ->
+ State1 = case Err of
+ {auth, Code} ->
+ do_send(State, #connack{code = Code});
+ {codec, {Tag, _, _}} when Tag == unsupported_protocol_version;
+ Tag == unsupported_protocol_name ->
+ do_send(State#state{version = ?MQTT_VERSION_4},
+ #connack{code = connack_reason_code(Err)});
+ _ when State#state.version == undefined ->
+ State;
+ {Tag, _} when Tag == socket; Tag == tls ->
+ State;
+ {peer_disconnected, _, _} ->
+ State;
+ _ ->
+ Props = #{reason_string => format_reason_string(Err)},
+ case State#state.jid of
+ undefined ->
+ Code = connack_reason_code(Err),
+ Pkt = #connack{code = Code, properties = Props},
+ do_send(State, Pkt);
+ _ when State#state.version == ?MQTT_VERSION_5 ->
+ Code = disconnect_reason_code(Err),
+ Pkt = #disconnect{code = Code, properties = Props},
+ do_send(State, Pkt);
+ _ ->
+ State
+ end
+ end,
+ SockMod:close(Sock),
+ State1#state{socket = undefined,
+ version = undefined,
+ codec = mqtt_codec:renew(State#state.codec)};
+disconnect(State, _) ->
+ State.
+
+-spec check_sock_result(socket(), ok | {error, inet:posix()}) -> ok.
+check_sock_result(_, ok) ->
+ ok;
+check_sock_result({_, Sock}, {error, Why}) ->
+ self() ! {tcp_closed, Sock},
+ ?DEBUG("MQTT socket error: ~p", [format_inet_error(Why)]).
+
+-spec starttls(socket()) -> {ok, socket()} | {error, error_reason()}.
+starttls({fast_tls, Socket}) ->
+ case ejabberd_pkix:get_certfile() of
+ {ok, Cert} ->
+ case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of
+ {ok, TLSSock} ->
+ {ok, {fast_tls, TLSSock}};
+ {error, Why} ->
+ {error, {tls, Why}}
+ end;
+ error ->
+ {error, {tls, no_certfile}}
+ end;
+starttls(Socket) ->
+ {ok, Socket}.
+
+-spec recv_data(socket(), binary()) -> {ok, binary()} | {error, error_reason()}.
+recv_data({fast_tls, Sock}, Data) ->
+ case fast_tls:recv_data(Sock, Data) of
+ {ok, _} = OK -> OK;
+ {error, E} when is_atom(E) -> {error, {socket, E}};
+ {error, E} when is_binary(E) -> {error, {tls, E}};
+ {error, _} = Err -> Err
+ end;
+recv_data(_, Data) ->
+ {ok, Data}.
+
+%%%===================================================================
+%%% Formatters
+%%%===================================================================
+-spec pp(any()) -> iolist().
+pp(Term) ->
+ io_lib_pretty:print(Term, fun pp/2).
+
+-spec format_inet_error(socket_error_reason()) -> string().
+format_inet_error(closed) ->
+ "connection closed";
+format_inet_error(timeout) ->
+ format_inet_error(etimedout);
+format_inet_error(Reason) ->
+ case inet:format_error(Reason) of
+ "unknown POSIX error" -> atom_to_list(Reason);
+ Txt -> Txt
+ end.
+
+-spec format_tls_error(atom() | binary()) -> string() | binary().
+format_tls_error(no_cerfile) ->
+ "certificate not found";
+format_tls_error(Reason) when is_atom(Reason) ->
+ format_inet_error(Reason);
+format_tls_error(Reason) ->
+ Reason.
+
+-spec format_exit_reason(term()) -> string().
+format_exit_reason(noproc) ->
+ "process is dead";
+format_exit_reason(normal) ->
+ "process has exited";
+format_exit_reason(killed) ->
+ "process has been killed";
+format_exit_reason(timeout) ->
+ "remote call to process timed out";
+format_exit_reason(Why) ->
+ format("unexpected error: ~p", [Why]).
+
+%% Same as format_error/1, but hides sensitive data
+%% and returns result as binary
+-spec format_reason_string(error_reason()) -> binary().
+format_reason_string({resumed, _}) ->
+ <<"Resumed by another connection">>;
+format_reason_string({replaced, _}) ->
+ <<"Replaced by another connection">>;
+format_reason_string(Err) ->
+ list_to_binary(format_error(Err)).
+
+-spec format(io:format(), list()) -> string().
+format(Fmt, Args) ->
+ lists:flatten(io_lib:format(Fmt, Args)).
+
+-spec pp(atom(), non_neg_integer()) -> [atom()] | no.
+pp(state, 17) -> record_info(fields, state);
+pp(Rec, Size) -> mqtt_codec:pp(Rec, Size).
+
+-spec publish_reason_code(error_reason()) -> reason_code().
+publish_reason_code(publish_forbidden) -> 'topic-name-invalid';
+publish_reason_code(_) -> 'implementation-specific-error'.
+
+-spec subscribe_reason_code(qos() | error_reason()) -> reason_code().
+subscribe_reason_code(0) -> 'granted-qos-0';
+subscribe_reason_code(1) -> 'granted-qos-1';
+subscribe_reason_code(2) -> 'granted-qos-2';
+subscribe_reason_code(subscribe_forbidden) -> 'topic-filter-invalid';
+subscribe_reason_code(_) -> 'implementation-specific-error'.
+
+-spec unsubscribe_reason_code(error_reason()) -> reason_code().
+unsubscribe_reason_code(_) -> 'implementation-specific-error'.
+
+-spec disconnect_reason_code(error_reason()) -> reason_code().
+disconnect_reason_code({code, Code}) -> Code;
+disconnect_reason_code({codec, Err}) -> mqtt_codec:error_reason_code(Err);
+disconnect_reason_code({unexpected_packet, _}) -> 'protocol-error';
+disconnect_reason_code({replaced, _}) -> 'session-taken-over';
+disconnect_reason_code({resumed, _}) -> 'session-taken-over';
+disconnect_reason_code(internal_server_error) -> 'implementation-specific-error';
+disconnect_reason_code(db_failure) -> 'implementation-specific-error';
+disconnect_reason_code(idle_connection) -> 'keep-alive-timeout';
+disconnect_reason_code(queue_full) -> 'quota-exceeded';
+disconnect_reason_code(shutdown) -> 'server-shutting-down';
+disconnect_reason_code(subscribe_forbidden) -> 'topic-filter-invalid';
+disconnect_reason_code(publish_forbidden) -> 'topic-name-invalid';
+disconnect_reason_code(will_topic_forbidden) -> 'topic-name-invalid';
+disconnect_reason_code({payload_format_invalid, _}) -> 'payload-format-invalid';
+disconnect_reason_code(session_expiry_non_zero) -> 'protocol-error';
+disconnect_reason_code(unknown_topic_alias) -> 'protocol-error';
+disconnect_reason_code(_) -> 'unspecified-error'.
+
+-spec connack_reason_code(error_reason()) -> reason_code().
+connack_reason_code({Tag, Code}) when Tag == auth; Tag == code -> Code;
+connack_reason_code({codec, Err}) -> mqtt_codec:error_reason_code(Err);
+connack_reason_code({unexpected_packet, _}) -> 'protocol-error';
+connack_reason_code(internal_server_error) -> 'implementation-specific-error';
+connack_reason_code(db_failure) -> 'implementation-specific-error';
+connack_reason_code(idle_connection) -> 'keep-alive-timeout';
+connack_reason_code(queue_full) -> 'quota-exceeded';
+connack_reason_code(shutdown) -> 'server-shutting-down';
+connack_reason_code(will_topic_forbidden) -> 'topic-name-invalid';
+connack_reason_code({payload_format_invalid, _}) -> 'payload-format-invalid';
+connack_reason_code(session_expiry_non_zero) -> 'protocol-error';
+connack_reason_code(_) -> 'unspecified-error'.
+
+%%%===================================================================
+%%% Configuration processing
+%%%===================================================================
+-spec queue_type(binary()) -> ram | file.
+queue_type(Host) ->
+ gen_mod:get_module_opt(Host, mod_mqtt, queue_type).
+
+-spec queue_limit(binary()) -> non_neg_integer() | unlimited.
+queue_limit(Host) ->
+ gen_mod:get_module_opt(Host, mod_mqtt, max_queue).
+
+-spec session_expiry(binary()) -> seconds().
+session_expiry(Host) ->
+ gen_mod:get_module_opt(Host, mod_mqtt, session_expiry).
+
+-spec topic_alias_maximum(binary()) -> non_neg_integer().
+topic_alias_maximum(Host) ->
+ gen_mod:get_module_opt(Host, mod_mqtt, max_topic_aliases).
+
+%%%===================================================================
+%%% Timings
+%%%===================================================================
+-spec current_time() -> milli_seconds().
+current_time() ->
+ p1_time_compat:monotonic_time(milli_seconds).
+
+-spec unix_time() -> seconds().
+unix_time() ->
+ p1_time_compat:system_time(seconds).
+
+-spec set_keep_alive(state(), seconds()) -> state().
+set_keep_alive(State, 0) ->
+ ?DEBUG("Disabling MQTT keep-alive", []),
+ State#state{timeout = infinity};
+set_keep_alive(State, Secs) ->
+ Secs1 = round(Secs * 1.5),
+ ?DEBUG("Setting MQTT keep-alive to ~B seconds", [Secs1]),
+ set_timeout(State, timer:seconds(Secs1)).
+
+-spec reset_keep_alive(state()) -> state().
+reset_keep_alive(#state{timeout = {MSecs, _}, jid = #jid{}} = State) ->
+ set_timeout(State, MSecs);
+reset_keep_alive(State) ->
+ State.
+
+-spec set_timeout(state(), milli_seconds()) -> state().
+set_timeout(State, MSecs) ->
+ Time = current_time(),
+ State#state{timeout = {MSecs, Time}}.
+
+-spec is_expired(publish()) -> true | {false, publish()}.
+is_expired(#publish{meta = Meta, properties = Props} = Pkt) ->
+ case maps:get(expiry_time, Meta, ?MAX_UINT32) of
+ ?MAX_UINT32 ->
+ {false, Pkt};
+ ExpiryTime ->
+ Left = ExpiryTime - unix_time(),
+ if Left > 0 ->
+ Props1 = Props#{message_expiry_interval => Left},
+ {false, Pkt#publish{properties = Props1}};
+ true ->
+ ?DEBUG("Dropping expired packet:~n~s", [pp(Pkt)]),
+ true
+ end
+ end.
+
+%%%===================================================================
+%%% Authentication
+%%%===================================================================
+-spec parse_credentials(connect()) -> {ok, jid:jid()} | {error, reason_code()}.
+parse_credentials(#connect{client_id = <<>>}) ->
+ parse_credentials(#connect{client_id = p1_rand:get_string()});
+parse_credentials(#connect{username = <<>>, client_id = ClientID}) ->
+ Host = ejabberd_config:get_myname(),
+ JID = case jid:make(ClientID, Host) of
+ error -> jid:make(str:sha(ClientID), Host);
+ J -> J
+ end,
+ parse_credentials(JID, ClientID);
+parse_credentials(#connect{username = User} = Pkt) ->
+ try jid:decode(User) of
+ #jid{luser = <<>>} ->
+ case jid:make(User, ejabberd_config:get_myname()) of
+ error ->
+ {error, 'bad-user-name-or-password'};
+ JID ->
+ parse_credentials(JID, Pkt#connect.client_id)
+ end;
+ JID ->
+ parse_credentials(JID, Pkt#connect.client_id)
+ catch _:{bad_jid, _} ->
+ {error, 'bad-user-name-or-password'}
+ end.
+
+-spec parse_credentials(jid:jid(), binary()) -> {ok, jid:jid()} | {error, reason_code()}.
+parse_credentials(JID, ClientID) ->
+ case gen_mod:is_loaded(JID#jid.lserver, mod_mqtt) of
+ false ->
+ {error, 'server-unavailable'};
+ true ->
+ case jid:replace_resource(JID, ClientID) of
+ error ->
+ {error, 'client-identifier-not-valid'};
+ JID1 ->
+ {ok, JID1}
+ end
+ end.
+
+-spec authenticate(connect(), peername()) -> {ok, jid:jid()} | {error, reason_code()}.
+authenticate(#connect{password = Pass} = Pkt, IP) ->
+ case parse_credentials(Pkt) of
+ {ok, #jid{luser = LUser, lserver = LServer} = JID} ->
+ case ejabberd_auth:check_password_with_authmodule(
+ LUser, <<>>, LServer, Pass) of
+ {true, AuthModule} ->
+ ?INFO_MSG(
+ "Accepted MQTT authentication for ~s "
+ "by ~s backend from ~s",
+ [jid:encode(JID),
+ ejabberd_auth:backend_type(AuthModule),
+ ejabberd_config:may_hide_data(misc:ip_to_list(IP))]),
+ {ok, JID};
+ false ->
+ {error, 'not-authorized'}
+ end;
+ {error, _} = Err ->
+ Err
+ end.
+
+%%%===================================================================
+%%% Validators
+%%%===================================================================
+-spec validate_will(connect(), jid:jid()) -> ok | {error, reason_code()}.
+validate_will(#connect{will = undefined}, _) ->
+ ok;
+validate_will(#connect{will = #publish{topic = Topic, payload = Payload},
+ will_properties = Props}, JID) ->
+ case mod_mqtt:check_publish_access(Topic, jid:tolower(JID)) of
+ deny -> {error, will_topic_forbidden};
+ allow -> validate_payload(Props, Payload, will)
+ end.
+
+-spec validate_publish(publish(), state()) -> ok | {error, error_reason()}.
+validate_publish(#publish{topic = Topic, payload = Payload,
+ properties = Props}, State) ->
+ case validate_topic(Topic, Props, State) of
+ ok -> validate_payload(Props, Payload, publish);
+ Err -> Err
+ end.
+
+-spec validate_subscribe(subscribe()) -> ok | {error, error_reason()}.
+validate_subscribe(#subscribe{filters = Filters}) ->
+ case lists:any(
+ fun({<<"$share/", _/binary>>, _}) -> true;
+ (_) -> false
+ end, Filters) of
+ true ->
+ {error, {code, 'shared-subscriptions-not-supported'}};
+ false ->
+ ok
+ end.
+
+-spec validate_topic(binary(), properties(), state()) -> ok | {error, error_reason()}.
+validate_topic(<<>>, Props, State) ->
+ case maps:get(topic_alias, Props, 0) of
+ 0 ->
+ {error, {code, 'topic-alias-invalid'}};
+ Alias ->
+ case maps:is_key(Alias, State#state.topic_aliases) of
+ true -> ok;
+ false -> {error, unknown_topic_alias}
+ end
+ end;
+validate_topic(_, #{topic_alias := Alias}, State) ->
+ JID = State#state.jid,
+ Max = topic_alias_maximum(JID#jid.lserver),
+ if Alias > Max ->
+ {error, {code, 'topic-alias-invalid'}};
+ true ->
+ ok
+ end;
+validate_topic(_, _, _) ->
+ ok.
+
+-spec validate_payload(properties(), binary(), will | publish) -> ok | {error, error_reason()}.
+validate_payload(#{payload_format_indicator := utf8}, Payload, Type) ->
+ try mqtt_codec:utf8(Payload) of
+ _ -> ok
+ catch _:_ ->
+ {error, {payload_format_invalid, Type}}
+ end;
+validate_payload(_, _, _) ->
+ ok.
+
+%%%===================================================================
+%%% Misc
+%%%===================================================================
+-spec resubscribe(jid:ljid(), map()) -> ok | {error, error_reason()}.
+resubscribe(USR, Subs) ->
+ case maps:fold(
+ fun(TopicFilter, {SubOpts, ID}, ok) ->
+ mod_mqtt:subscribe(USR, TopicFilter, SubOpts, ID);
+ (_, _, {error, _} = Err) ->
+ Err
+ end, ok, Subs) of
+ ok ->
+ ok;
+ {error, _} = Err1 ->
+ unsubscribe(maps:keys(Subs), USR, #{}),
+ Err1
+ end.
+
+-spec publish_will(state()) -> state().
+publish_will(#state{will = #publish{} = Will,
+ jid = #jid{} = JID} = State) ->
+ case publish(State, Will) of
+ {ok, _} ->
+ ?DEBUG("Will of ~s has been published to ~s",
+ [jid:encode(JID), Will#publish.topic]);
+ {error, Why} ->
+ ?WARNING_MSG("Failed to publish will of ~s to ~s: ~s",
+ [jid:encode(JID), Will#publish.topic,
+ format_error(Why)])
+ end,
+ State#state{will = undefined};
+publish_will(State) ->
+ State.
+
+-spec next_id(non_neg_integer()) -> pos_integer().
+next_id(ID) ->
+ (ID rem 65535) + 1.
+
+-spec set_dup_flag(mqtt_packet()) -> mqtt_packet().
+set_dup_flag(#publish{qos = QoS} = Pkt) when QoS>0 ->
+ Pkt#publish{dup = true};
+set_dup_flag(Pkt) ->
+ Pkt.
+
+-spec get_publish_code_props({ok, non_neg_integer()} |
+ {error, error_reason()}) -> {reason_code(), properties()}.
+get_publish_code_props({ok, 0}) ->
+ {'no-matching-subscribers', #{}};
+get_publish_code_props({ok, _}) ->
+ {success, #{}};
+get_publish_code_props({error, Err}) ->
+ Code = publish_reason_code(Err),
+ Reason = format_reason_string(Err),
+ {Code, #{reason_string => Reason}}.
+
+-spec err_args(undefined | jid:jid(), peername(), error_reason()) -> iolist().
+err_args(undefined, IP, Reason) ->
+ [ejabberd_config:may_hide_data(misc:ip_to_list(IP)),
+ format_error(Reason)];
+err_args(JID, IP, Reason) ->
+ [jid:encode(JID),
+ ejabberd_config:may_hide_data(misc:ip_to_list(IP)),
+ format_error(Reason)].
+
+-spec log_disconnection(state(), error_reason()) -> ok.
+log_disconnection(#state{jid = JID, peername = IP}, Reason) ->
+ Msg = case JID of
+ undefined -> "Rejected MQTT connection from ~s: ~s";
+ _ -> "Closing MQTT connection for ~s from ~s: ~s"
+ end,
+ case Reason of
+ {Tag, _} when Tag == replaced; Tag == resumed; Tag == socket ->
+ ?DEBUG(Msg, err_args(JID, IP, Reason));
+ idle_connection ->
+ ?DEBUG(Msg, err_args(JID, IP, Reason));
+ Tag when Tag == session_expired; Tag == shutdown ->
+ ?INFO_MSG(Msg, err_args(JID, IP, Reason));
+ {peer_disconnected, Code, _} ->
+ case mqtt_codec:is_error_code(Code) of
+ true -> ?WARNING_MSG(Msg, err_args(JID, IP, Reason));
+ false -> ?DEBUG(Msg, err_args(JID, IP, Reason))
+ end;
+ _ ->
+ ?WARNING_MSG(Msg, err_args(JID, IP, Reason))
+ end.