diff options
author | Evgeny Khramtsov <ekhramtsov@process-one.net> | 2019-02-25 11:42:09 +0300 |
---|---|---|
committer | Evgeny Khramtsov <ekhramtsov@process-one.net> | 2019-02-25 11:42:09 +0300 |
commit | a3df791373c30ccc79a6082f4c910a378d726cdc (patch) | |
tree | e7efece6aaaec749f0291f0845abdd4a75b7a059 /src/mod_mqtt_session.erl | |
parent | mod_muc_admin: Fix indentation (diff) |
Add MQTT support
Diffstat (limited to 'src/mod_mqtt_session.erl')
-rw-r--r-- | src/mod_mqtt_session.erl | 1318 |
1 files changed, 1318 insertions, 0 deletions
diff --git a/src/mod_mqtt_session.erl b/src/mod_mqtt_session.erl new file mode 100644 index 000000000..3df36b8fb --- /dev/null +++ b/src/mod_mqtt_session.erl @@ -0,0 +1,1318 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net> +%%% @copyright (C) 2002-2019 ProcessOne, SARL. All Rights Reserved. +%%% +%%% Licensed under the Apache License, Version 2.0 (the "License"); +%%% you may not use this file except in compliance with the License. +%%% You may obtain a copy of the License at +%%% +%%% http://www.apache.org/licenses/LICENSE-2.0 +%%% +%%% Unless required by applicable law or agreed to in writing, software +%%% distributed under the License is distributed on an "AS IS" BASIS, +%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%%% See the License for the specific language governing permissions and +%%% limitations under the License. +%%% +%%%------------------------------------------------------------------- +-module(mod_mqtt_session). +-behaviour(p1_server). +-define(VSN, 1). +-vsn(?VSN). + +%% API +-export([start/3, start_link/3, accept/1, route/2]). +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-include("logger.hrl"). +-include("mqtt.hrl"). +-include("xmpp.hrl"). + +-record(state, {vsn = ?VSN :: integer(), + version :: undefined | mqtt_version(), + socket :: undefined | socket(), + peername :: peername(), + timeout = infinity :: timer(), + jid :: undefined | jid:jid(), + session_expiry = 0 :: seconds(), + will :: undefined | publish(), + will_delay = 0 :: seconds(), + stop_reason :: undefined | error_reason(), + acks = #{} :: map(), + subscriptions = #{} :: map(), + topic_aliases = #{} :: map(), + id = 0 :: non_neg_integer(), + in_flight :: undefined | publish() | pubrel(), + codec :: mqtt_codec:state(), + queue :: undefined | p1_queue:queue()}). + +-type error_reason() :: {auth, reason_code()} | + {code, reason_code()} | + {peer_disconnected, reason_code(), binary()} | + {socket, socket_error_reason()} | + {codec, mqtt_codec:error_reason()} | + {unexpected_packet, atom()} | + {tls, inet:posix() | atom() | binary()} | + {replaced, pid()} | {resumed, pid()} | + subscribe_forbidden | publish_forbidden | + will_topic_forbidden | internal_server_error | + session_expired | idle_connection | + queue_full | shutdown | db_failure | + {payload_format_invalid, will | publish} | + session_expiry_non_zero | unknown_topic_alias. + +-type state() :: #state{}. +-type sockmod() :: gen_tcp | fast_tls. +-type socket() :: {sockmod(), inet:socket() | fast_tls:tls_socket()}. +-type peername() :: {inet:ip_address(), inet:port_number()}. +-type seconds() :: non_neg_integer(). +-type milli_seconds() :: non_neg_integer(). +-type timer() :: infinity | {milli_seconds(), integer()}. +-type socket_error_reason() :: closed | timeout | inet:posix(). + +-define(CALL_TIMEOUT, timer:minutes(1)). +-define(RELAY_TIMEOUT, timer:minutes(1)). +-define(MAX_UINT32, 4294967295). + +%%%=================================================================== +%%% API +%%%=================================================================== +start(SockMod, Socket, ListenOpts) -> + p1_server:start(?MODULE, [SockMod, Socket, ListenOpts], + ejabberd_config:fsm_limit_opts(ListenOpts)). + +start_link(SockMod, Socket, ListenOpts) -> + p1_server:start_link(?MODULE, [SockMod, Socket, ListenOpts], + ejabberd_config:fsm_limit_opts(ListenOpts)). + +-spec accept(pid()) -> ok. +accept(Pid) -> + p1_server:cast(Pid, accept). + +-spec route(pid(), term()) -> boolean(). +route(Pid, Term) -> + ejabberd_cluster:send(Pid, Term). + +-spec format_error(error_reason()) -> string(). +format_error(session_expired) -> + "Disconnected session is expired"; +format_error(idle_connection) -> + "Idle connection"; +format_error(queue_full) -> + "Message queue is overloaded"; +format_error(internal_server_error) -> + "Internal server error"; +format_error(db_failure) -> + "Database failure"; +format_error(shutdown) -> + "System shutting down"; +format_error(subscribe_forbidden) -> + "Subscribing to this topic is forbidden by service policy"; +format_error(publish_forbidden) -> + "Publishing to this topic is forbidden by service policy"; +format_error(will_topic_forbidden) -> + "Publishing to this will topic is forbidden by service policy"; +format_error(session_expiry_non_zero) -> + "Session Expiry Interval in DISCONNECT packet should have been zero"; +format_error(unknown_topic_alias) -> + "No mapping found for this Topic Alias"; +format_error({payload_format_invalid, will}) -> + "Will payload format doesn't match its indicator"; +format_error({payload_format_invalid, publish}) -> + "PUBLISH payload format doesn't match its indicator"; +format_error({peer_disconnected, Code, <<>>}) -> + format("Peer disconnected with reason: ~s", + [mqtt_codec:format_reason_code(Code)]); +format_error({peer_disconnected, Code, Reason}) -> + format("Peer disconnected with reason: ~s (~s)", [Reason, Code]); +format_error({replaced, Pid}) -> + format("Replaced by ~p at ~s", [Pid, node(Pid)]); +format_error({resumed, Pid}) -> + format("Resumed by ~p at ~s", [Pid, node(Pid)]); +format_error({unexpected_packet, Name}) -> + format("Unexpected ~s packet", [string:to_upper(atom_to_list(Name))]); +format_error({tls, Reason}) -> + format("TLS failed: ~s", [format_tls_error(Reason)]); +format_error({socket, A}) -> + format("Connection failed: ~s", [format_inet_error(A)]); +format_error({code, Code}) -> + format("Protocol error: ~s", [mqtt_codec:format_reason_code(Code)]); +format_error({auth, Code}) -> + format("Authentication failed: ~s", [mqtt_codec:format_reason_code(Code)]); +format_error({codec, CodecError}) -> + format("Protocol error: ~s", [mqtt_codec:format_error(CodecError)]); +format_error(A) when is_atom(A) -> + atom_to_list(A); +format_error(Reason) -> + format("Unrecognized error: ~w", [Reason]). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([SockMod, Socket, ListenOpts]) -> + MaxSize = proplists:get_value(max_payload_size, ListenOpts, infinity), + SockMod1 = case proplists:get_bool(tls, ListenOpts) of + true -> fast_tls; + false -> SockMod + end, + State1 = #state{socket = {SockMod1, Socket}, + id = p1_rand:uniform(65535), + codec = mqtt_codec:new(MaxSize)}, + Timeout = timer:seconds(30), + State2 = set_timeout(State1, Timeout), + {ok, State2, Timeout}. + +handle_call({get_state, _}, From, #state{stop_reason = {resumed, Pid}} = State) -> + p1_server:reply(From, {error, {resumed, Pid}}), + noreply(State); +handle_call({get_state, Pid}, From, State) -> + case stop(State, {resumed, Pid}) of + {stop, Status, State1} -> + {stop, Status, State1#state{stop_reason = {replaced, Pid}}}; + {noreply, State1, _} -> + ?DEBUG("Transfering MQTT session state to ~p at ~s", [Pid, node(Pid)]), + Q1 = p1_queue:file_to_ram(State1#state.queue), + p1_server:reply(From, {ok, State1#state{queue = Q1}}), + SessionExpiry = timer:seconds(State1#state.session_expiry), + State2 = set_timeout(State1, min(SessionExpiry, ?RELAY_TIMEOUT)), + State3 = State2#state{queue = undefined, + stop_reason = {resumed, Pid}, + acks = #{}, + will = undefined, + session_expiry = 0, + topic_aliases = #{}, + subscriptions = #{}}, + noreply(State3) + end; +handle_call(Request, From, State) -> + ?WARNING_MSG("Got unexpected call from ~p: ~p", [From, Request]), + noreply(State). + +handle_cast(accept, #state{socket = {_, TCPSock} = Socket} = State) -> + case inet:peername(TCPSock) of + {ok, IPPort} -> + State1 = State#state{peername = IPPort}, + case starttls(Socket) of + {ok, Socket1} -> + State2 = State1#state{socket = Socket1}, + handle_info({tcp, TCPSock, <<>>}, State2); + {error, Why} -> + stop(State1, Why) + end; + {error, Why} -> + stop(State, {socket, Why}) + end; +handle_cast(Msg, State) -> + ?WARNING_MSG("Got unexpected cast: ~p", [Msg]), + noreply(State). + +handle_info(Msg, #state{stop_reason = {resumed, Pid} = Reason} = State) -> + case Msg of + {#publish{}, _} -> + ?DEBUG("Relaying delayed publish to ~p at ~s", [Pid, node(Pid)]), + ejabberd_cluster:send(Pid, Msg), + noreply(State); + timeout -> + stop(State, Reason); + _ -> + noreply(State) + end; +handle_info({#publish{meta = Meta} = Pkt, ExpiryTime}, State) -> + ID = next_id(State#state.id), + Meta1 = Meta#{expiry_time => ExpiryTime}, + Pkt1 = Pkt#publish{id = ID, meta = Meta1}, + State1 = State#state{id = ID}, + case send(State1, Pkt1) of + {ok, State2} -> noreply(State2); + {error, State2, Reason} -> stop(State2, Reason) + end; +handle_info({tcp, TCPSock, TCPData}, + #state{codec = Codec, socket = Socket} = State) -> + case recv_data(Socket, TCPData) of + {ok, Data} -> + case mqtt_codec:decode(Codec, Data) of + {ok, Pkt, Codec1} -> + ?DEBUG("Got MQTT packet:~n~s", [pp(Pkt)]), + State1 = State#state{codec = Codec1}, + case handle_packet(Pkt, State1) of + {ok, State2} -> + handle_info({tcp, TCPSock, <<>>}, State2); + {error, State2, Reason} -> + stop(State2, Reason) + end; + {more, Codec1} -> + State1 = State#state{codec = Codec1}, + State2 = reset_keep_alive(State1), + activate(Socket), + noreply(State2); + {error, Why} -> + stop(State, {codec, Why}) + end; + {error, Why} -> + stop(State, Why) + end; +handle_info({tcp_closed, _Sock}, State) -> + ?DEBUG("MQTT connection reset by peer", []), + stop(State, {socket, closed}); +handle_info({tcp_error, _Sock, Reason}, State) -> + ?DEBUG("MQTT connection error: ~s", [format_inet_error(Reason)]), + stop(State, {socket, Reason}); +handle_info(timeout, #state{socket = Socket} = State) -> + case Socket of + undefined -> + ?DEBUG("MQTT session expired", []), + stop(State#state{session_expiry = 0}, session_expired); + _ -> + ?DEBUG("MQTT connection timed out", []), + stop(State, idle_connection) + end; +handle_info({replaced, Pid}, State) -> + stop(State#state{session_expiry = 0}, {replaced, Pid}); +handle_info({timeout, _TRef, publish_will}, State) -> + noreply(publish_will(State)); +handle_info({Ref, badarg}, State) when is_reference(Ref) -> + %% TODO: figure out from where this messages comes from + noreply(State); +handle_info(Info, State) -> + ?WARNING_MSG("Got unexpected info: ~p", [Info]), + noreply(State). + +-spec handle_packet(mqtt_packet(), state()) -> {ok, state()} | + {error, state(), error_reason()}. +handle_packet(#connect{proto_level = Version} = Pkt, State) -> + handle_connect(Pkt, State#state{version = Version}); +handle_packet(#publish{} = Pkt, State) -> + handle_publish(Pkt, State); +handle_packet(#puback{id = ID}, #state{in_flight = #publish{qos = 1, id = ID}} = State) -> + resend(State#state{in_flight = undefined}); +handle_packet(#puback{id = ID, code = Code}, State) -> + ?DEBUG("Ignoring unexpected PUBACK with id=~B and code '~s'", [ID, Code]), + {ok, State}; +handle_packet(#pubrec{id = ID, code = Code}, + #state{in_flight = #publish{qos = 2, id = ID}} = State) -> + case mqtt_codec:is_error_code(Code) of + true -> + ?DEBUG("Got PUBREC with error code '~s', " + "aborting acknowledgement", [Code]), + resend(State#state{in_flight = undefined}); + false -> + Pubrel = #pubrel{id = ID}, + send(State#state{in_flight = Pubrel}, Pubrel) + end; +handle_packet(#pubrec{id = ID, code = Code}, State) -> + case mqtt_codec:is_error_code(Code) of + true -> + ?DEBUG("Ignoring unexpected PUBREC with id=~B and code '~s'", + [ID, Code]), + {ok, State}; + false -> + Code1 = 'packet-identifier-not-found', + ?DEBUG("Got unexpected PUBREC with id=~B, " + "sending PUBREL with error code '~s'", [ID, Code1]), + send(State, #pubrel{id = ID, code = Code1}) + end; +handle_packet(#pubcomp{id = ID}, #state{in_flight = #pubrel{id = ID}} = State) -> + resend(State#state{in_flight = undefined}); +handle_packet(#pubcomp{id = ID}, State) -> + ?DEBUG("Ignoring unexpected PUBCOMP with id=~B: most likely " + "it's a repeated response to duplicated PUBREL", [ID]), + {ok, State}; +handle_packet(#pubrel{id = ID}, State) -> + case maps:take(ID, State#state.acks) of + {_, Acks} -> + send(State#state{acks = Acks}, #pubcomp{id = ID}); + error -> + Code = 'packet-identifier-not-found', + ?DEBUG("Got unexpected PUBREL with id=~B, " + "sending PUBCOMP with error code '~s'", [ID, Code]), + Pubcomp = #pubcomp{id = ID, code = Code}, + send(State, Pubcomp) + end; +handle_packet(#subscribe{} = Pkt, State) -> + handle_subscribe(Pkt, State); +handle_packet(#unsubscribe{} = Pkt, State) -> + handle_unsubscribe(Pkt, State); +handle_packet(#pingreq{}, State) -> + send(State, #pingresp{}); +handle_packet(#disconnect{properties = #{session_expiry_interval := SE}}, + #state{session_expiry = 0} = State) when SE>0 -> + %% Protocol violation + {error, State, session_expiry_non_zero}; +handle_packet(#disconnect{code = Code, properties = Props}, + #state{jid = #jid{lserver = Server}} = State) -> + Reason = maps:get(reason_string, Props, <<>>), + Expiry = case maps:get(session_expiry_interval, Props, undefined) of + undefined -> State#state.session_expiry; + SE -> min(SE, session_expiry(Server)) + end, + State1 = State#state{session_expiry = Expiry}, + State2 = case Code of + 'normal-disconnection' -> State1#state{will = undefined}; + _ -> State1 + end, + {error, State2, {peer_disconnected, Code, Reason}}; +handle_packet(Pkt, State) -> + ?WARNING_MSG("Unexpected packet:~n~s~n** when state:~n~s", + [pp(Pkt), pp(State)]), + {error, State, {unexpected_packet, element(1, Pkt)}}. + +terminate(_, #state{peername = undefined}) -> + ok; +terminate(Reason, State) -> + Reason1 = case Reason of + shutdown -> shutdown; + {shutdown, _} -> shutdown; + normal -> State#state.stop_reason; + {process_limit, _} -> queue_full; + _ -> internal_server_error + end, + case State#state.jid of + #jid{} -> unregister_session(State, Reason1); + undefined -> log_disconnection(State, Reason1) + end, + State1 = disconnect(State, Reason1), + publish_will(State1). + +code_change(_OldVsn, State, _Extra) -> + {ok, upgrade_state(State)}. + +%%%=================================================================== +%%% State transitions +%%%=================================================================== +-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. +noreply(#state{timeout = infinity} = State) -> + {noreply, State, infinity}; +noreply(#state{timeout = {MSecs, StartTime}} = State) -> + CurrentTime = current_time(), + Timeout = max(0, MSecs - CurrentTime + StartTime), + {noreply, State, Timeout}. + +-spec stop(state(), error_reason()) -> {noreply, state(), infinity} | + {stop, normal, state()}. +stop(#state{session_expiry = 0} = State, Reason) -> + {stop, normal, State#state{stop_reason = Reason}}; +stop(#state{session_expiry = SessExp} = State, Reason) -> + case State#state.socket of + undefined -> + noreply(State); + _ -> + WillDelay = State#state.will_delay, + log_disconnection(State, Reason), + State1 = disconnect(State, Reason), + State2 = if WillDelay == 0 -> + publish_will(State1); + WillDelay < SessExp -> + erlang:start_timer( + timer:seconds(WillDelay), self(), publish_will), + State1; + true -> + State1 + end, + State3 = set_timeout(State2, timer:seconds(SessExp)), + State4 = State3#state{stop_reason = Reason}, + noreply(State4) + end. + +-spec upgrade_state(term()) -> state(). +upgrade_state(State) -> + %% Here will be the code upgrading state between different + %% code versions. This is needed when doing session resumption from + %% remote node running the version of the code with incompatible #state{} + %% record fields. Also used by code_change/3 callback. + %% Use element(2, State) for vsn comparison. + State. + +%%%=================================================================== +%%% Session management +%%%=================================================================== +-spec open_session(state(), jid(), boolean()) -> {ok, boolean(), state()} | + {error, state(), error_reason()}. +open_session(State, JID, _CleanStart = false) -> + USR = {_, S, _} = jid:tolower(JID), + case mod_mqtt:lookup_session(USR) of + {ok, Pid} -> + try p1_server:call(Pid, {get_state, self()}, ?CALL_TIMEOUT) of + {ok, State1} -> + State2 = upgrade_state(State1), + Q1 = case queue_type(S) of + ram -> State2#state.queue; + _ -> p1_queue:ram_to_file(State2#state.queue) + end, + Q2 = p1_queue:set_limit(Q1, queue_limit(S)), + State3 = State#state{queue = Q2, + acks = State2#state.acks, + subscriptions = State2#state.subscriptions, + id = State2#state.id, + in_flight = State2#state.in_flight}, + ?DEBUG("Resumed state from ~p at ~s:~n~s", + [Pid, node(Pid), pp(State3)]), + register_session(State3, JID, Pid); + {error, Why} -> + {error, State, Why} + catch exit:{Why, {p1_server, _, _}} -> + ?WARNING_MSG("Failed to copy session state from ~p at ~s: ~s", + [Pid, node(Pid), format_exit_reason(Why)]), + register_session(State, JID, undefined) + end; + {error, notfound} -> + register_session(State, JID, undefined); + {error, Why} -> + {error, State, Why} + end; +open_session(State, JID, _CleanStart = true) -> + register_session(State, JID, undefined). + +-spec register_session(state(), jid(), undefined | pid()) -> + {ok, boolean(), state()} | {error, state(), error_reason()}. +register_session(#state{peername = IP} = State, JID, Parent) -> + USR = {_, S, _} = jid:tolower(JID), + case mod_mqtt:open_session(USR) of + ok -> + case resubscribe(USR, State#state.subscriptions) of + ok -> + ?INFO_MSG("~s for ~s from ~s", + [if is_pid(Parent) -> + io_lib:format( + "Reopened MQTT session via ~p", + [Parent]); + true -> + "Opened MQTT session" + end, + jid:encode(JID), + ejabberd_config:may_hide_data( + misc:ip_to_list(IP))]), + Q = case State#state.queue of + undefined -> + p1_queue:new(queue_type(S), queue_limit(S)); + Q1 -> + Q1 + end, + {ok, is_pid(Parent), State#state{jid = JID, queue = Q}}; + {error, Why} -> + mod_mqtt:close_session(USR), + {error, State#state{session_expiry = 0}, Why} + end; + {error, Reason} -> + ?ERROR_MSG("Failed to register MQTT session for ~s from ~s: ~s", + err_args(JID, IP, Reason)), + {error, State, Reason} + end. + +-spec unregister_session(state(), error_reason()) -> ok. +unregister_session(#state{jid = #jid{} = JID, peername = IP} = State, Reason) -> + Msg = "Closing MQTT session for ~s from ~s: ~s", + case Reason of + {Tag, _} when Tag == replaced; Tag == resumed -> + ?DEBUG(Msg, err_args(JID, IP, Reason)); + {socket, _} -> + ?INFO_MSG(Msg, err_args(JID, IP, Reason)); + Tag when Tag == idle_connection; Tag == session_expired; Tag == shutdown -> + ?INFO_MSG(Msg, err_args(JID, IP, Reason)); + {peer_disconnected, Code, _} -> + case mqtt_codec:is_error_code(Code) of + true -> ?WARNING_MSG(Msg, err_args(JID, IP, Reason)); + false -> ?INFO_MSG(Msg, err_args(JID, IP, Reason)) + end; + _ -> + ?WARNING_MSG(Msg, err_args(JID, IP, Reason)) + end, + USR = jid:tolower(JID), + unsubscribe(maps:keys(State#state.subscriptions), USR, #{}), + case mod_mqtt:close_session(USR) of + ok -> ok; + {error, Why} -> + ?ERROR_MSG( + "Failed to close MQTT session for ~s from ~s: ~s", + err_args(JID, IP, Why)) + end; +unregister_session(_, _) -> + ok. + +%%%=================================================================== +%%% CONNECT/PUBLISH/SUBSCRIBE/UNSUBSCRIBE handlers +%%%=================================================================== +-spec handle_connect(connect(), state()) -> {ok, state()} | + {error, state(), error_reason()}. +handle_connect(#connect{clean_start = CleanStart} = Pkt, + #state{jid = undefined, peername = IP} = State) -> + case authenticate(Pkt, IP) of + {ok, JID} -> + case validate_will(Pkt, JID) of + ok -> + case open_session(State, JID, CleanStart) of + {ok, SessionPresent, State1} -> + State2 = set_session_properties(State1, Pkt), + ConnackProps = get_connack_properties(State2, Pkt), + Connack = #connack{session_present = SessionPresent, + properties = ConnackProps}, + case send(State2, Connack) of + {ok, State3} -> resend(State3); + {error, _, _} = Err -> Err + end; + {error, _, _} = Err -> + Err + end; + {error, Reason} -> + {error, State, Reason} + end; + {error, Code} -> + {error, State, {auth, Code}} + end. + +-spec handle_publish(publish(), state()) -> {ok, state()} | + {error, state(), error_reason()}. +handle_publish(#publish{qos = QoS, id = ID} = Publish, State) -> + case QoS == 2 andalso maps:is_key(ID, State#state.acks) of + true -> + send(State, maps:get(ID, State#state.acks)); + false -> + case validate_publish(Publish, State) of + ok -> + State1 = store_topic_alias(State, Publish), + Ret = publish(State1, Publish), + {Code, Props} = get_publish_code_props(Ret), + case Ret of + {ok, _} when QoS == 2 -> + Pkt = #pubrec{id = ID, code = Code, + properties = Props}, + Acks = maps:put(ID, Pkt, State1#state.acks), + State2 = State1#state{acks = Acks}, + send(State2, Pkt); + {error, _} when QoS == 2 -> + Pkt = #pubrec{id = ID, code = Code, + properties = Props}, + send(State1, Pkt); + _ when QoS == 1 -> + Pkt = #puback{id = ID, code = Code, + properties = Props}, + send(State1, Pkt); + _ -> + {ok, State1} + end; + {error, Why} -> + {error, State, Why} + end + end. + +-spec handle_subscribe(subscribe(), state()) -> + {ok, state()} | {error, state(), error_reason()}. +handle_subscribe(#subscribe{id = ID, filters = TopicFilters} = Pkt, State) -> + case validate_subscribe(Pkt) of + ok -> + USR = jid:tolower(State#state.jid), + SubID = maps:get(subscription_identifier, Pkt#subscribe.properties, 0), + OldSubs = State#state.subscriptions, + {Codes, NewSubs, Props} = subscribe(TopicFilters, USR, SubID), + Subs = maps:merge(OldSubs, NewSubs), + State1 = State#state{subscriptions = Subs}, + Suback = #suback{id = ID, codes = Codes, properties = Props}, + case send(State1, Suback) of + {ok, State2} -> + Pubs = select_retained(USR, NewSubs, OldSubs), + send_retained(State2, Pubs); + {error, _, _} = Err -> + Err + end; + {error, Why} -> + {error, State, Why} + end. + +-spec handle_unsubscribe(unsubscribe(), state()) -> + {ok, state()} | {error, state(), error_reason()}. +handle_unsubscribe(#unsubscribe{id = ID, filters = TopicFilters}, State) -> + USR = jid:tolower(State#state.jid), + {Codes, Subs, Props} = unsubscribe(TopicFilters, USR, State#state.subscriptions), + State1 = State#state{subscriptions = Subs}, + Unsuback = #unsuback{id = ID, codes = Codes, properties = Props}, + send(State1, Unsuback). + +%%%=================================================================== +%%% Aux functions for CONNECT/PUBLISH/SUBSCRIBE/UNSUBSCRIBE handlers +%%%=================================================================== +-spec set_session_properties(state(), connect()) -> state(). +set_session_properties(#state{version = Version, + jid = #jid{lserver = Server}} = State, + #connect{clean_start = CleanStart, + keep_alive = KeepAlive, + properties = Props} = Pkt) -> + SEMin = case CleanStart of + false when Version == ?MQTT_VERSION_4 -> infinity; + _ -> maps:get(session_expiry_interval, Props, 0) + end, + SEConfig = session_expiry(Server), + State1 = State#state{session_expiry = min(SEMin, SEConfig)}, + State2 = set_will_properties(State1, Pkt), + set_keep_alive(State2, KeepAlive). + +-spec set_will_properties(state(), connect()) -> state(). +set_will_properties(State, #connect{will = #publish{} = Will, + will_properties = Props}) -> + {WillDelay, Props1} = case maps:take(will_delay_interval, Props) of + error -> {0, Props}; + Ret -> Ret + end, + State#state{will = Will#publish{properties = Props1}, + will_delay = WillDelay}; +set_will_properties(State, _) -> + State. + +-spec get_connack_properties(state(), connect()) -> properties(). +get_connack_properties(#state{session_expiry = SessExp, jid = JID}, + #connect{client_id = ClientID, + keep_alive = KeepAlive}) -> + Props1 = case ClientID of + <<>> -> #{assigned_client_identifier => JID#jid.lresource}; + _ -> #{} + end, + Props1#{session_expiry_interval => SessExp, + shared_subscription_available => false, + topic_alias_maximum => topic_alias_maximum(JID#jid.lserver), + server_keep_alive => KeepAlive}. + +-spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer()) -> + {[reason_code()], map(), properties()}. +subscribe(TopicFilters, USR, SubID) -> + subscribe(TopicFilters, USR, SubID, [], #{}, ok). + +-spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer(), + [reason_code()], map(), ok | {error, error_reason()}) -> + {[reason_code()], map(), properties()}. +subscribe([{TopicFilter, SubOpts}|TopicFilters], USR, SubID, Codes, Subs, Err) -> + case mod_mqtt:subscribe(USR, TopicFilter, SubOpts, SubID) of + ok -> + Code = subscribe_reason_code(SubOpts#sub_opts.qos), + subscribe(TopicFilters, USR, SubID, [Code|Codes], + maps:put(TopicFilter, {SubOpts, SubID}, Subs), Err); + {error, Why} = Err1 -> + Code = subscribe_reason_code(Why), + subscribe(TopicFilters, USR, SubID, [Code|Codes], Subs, Err1) + end; +subscribe([], _USR, _SubID, Codes, Subs, Err) -> + Props = case Err of + ok -> #{}; + {error, Why} -> + #{reason_string => format_reason_string(Why)} + end, + {lists:reverse(Codes), Subs, Props}. + +-spec unsubscribe([binary()], jid:ljid(), map()) -> + {[reason_code()], map(), properties()}. +unsubscribe(TopicFilters, USR, Subs) -> + unsubscribe(TopicFilters, USR, [], Subs, ok). + +-spec unsubscribe([binary()], jid:ljid(), + [reason_code()], map(), + ok | {error, error_reason()}) -> + {[reason_code()], map(), properties()}. +unsubscribe([TopicFilter|TopicFilters], USR, Codes, Subs, Err) -> + case mod_mqtt:unsubscribe(USR, TopicFilter) of + ok -> + unsubscribe(TopicFilters, USR, [success|Codes], + maps:remove(TopicFilter, Subs), Err); + {error, notfound} -> + unsubscribe(TopicFilters, USR, + ['no-subscription-existed'|Codes], + maps:remove(TopicFilter, Subs), Err); + {error, Why} = Err1 -> + Code = unsubscribe_reason_code(Why), + unsubscribe(TopicFilters, USR, [Code|Codes], Subs, Err1) + end; +unsubscribe([], _USR, Codes, Subs, Err) -> + Props = case Err of + ok -> #{}; + {error, Why} -> + #{reason_string => format_reason_string(Why)} + end, + {lists:reverse(Codes), Subs, Props}. + +-spec select_retained(jid:ljid(), map(), map()) -> [{publish(), seconds()}]. +select_retained(USR, NewSubs, OldSubs) -> + lists:flatten( + maps:fold( + fun(_Filter, {#sub_opts{retain_handling = 2}, _SubID}, Acc) -> + Acc; + (Filter, {#sub_opts{retain_handling = 1, qos = QoS}, SubID}, Acc) -> + case maps:is_key(Filter, OldSubs) of + true -> Acc; + false -> [mod_mqtt:select_retained(USR, Filter, QoS, SubID)|Acc] + end; + (Filter, {#sub_opts{qos = QoS}, SubID}, Acc) -> + [mod_mqtt:select_retained(USR, Filter, QoS, SubID)|Acc] + end, [], NewSubs)). + +-spec send_retained(state(), [{publish(), seconds()}]) -> + {ok, state()} | {error, state(), error_reason()}. +send_retained(State, [{#publish{meta = Meta} = Pub, Expiry}|Pubs]) -> + I = next_id(State#state.id), + Meta1 = Meta#{expiry_time => Expiry}, + Pub1 = Pub#publish{id = I, retain = true, meta = Meta1}, + case send(State#state{id = I}, Pub1) of + {ok, State1} -> + send_retained(State1, Pubs); + Err -> + Err + end; +send_retained(State, []) -> + {ok, State}. + +-spec publish(state(), publish()) -> {ok, non_neg_integer()} | + {error, error_reason()}. +publish(State, #publish{topic = Topic, properties = Props} = Pkt) -> + MessageExpiry = maps:get(message_expiry_interval, Props, ?MAX_UINT32), + ExpiryTime = min(unix_time() + MessageExpiry, ?MAX_UINT32), + USR = jid:tolower(State#state.jid), + Props1 = maps:filter( + fun(payload_format_indicator, _) -> true; + (content_type, _) -> true; + (response_topic, _) -> true; + (correlation_data, _) -> true; + (user_property, _) -> true; + (_, _) -> false + end, Props), + Topic1 = case Topic of + <<>> -> + Alias = maps:get(topic_alias, Props), + maps:get(Alias, State#state.topic_aliases); + _ -> + Topic + end, + Pkt1 = Pkt#publish{topic = Topic1, properties = Props1}, + mod_mqtt:publish(USR, Pkt1, ExpiryTime). + +-spec store_topic_alias(state(), publish()) -> state(). +store_topic_alias(State, #publish{topic = <<_, _/binary>> = Topic, + properties = #{topic_alias := Alias}}) -> + Aliases = maps:put(Alias, Topic, State#state.topic_aliases), + State#state{topic_aliases = Aliases}; +store_topic_alias(State, _) -> + State. + +%%%=================================================================== +%%% Socket management +%%%=================================================================== +-spec send(state(), mqtt_packet()) -> {ok, state()} | + {error, state(), error_reason()}. +send(State, #publish{} = Pkt) -> + case is_expired(Pkt) of + {false, Pkt1} -> + case State#state.in_flight == undefined andalso + p1_queue:is_empty(State#state.queue) of + true -> + Dup = case Pkt1#publish.qos of + 0 -> undefined; + _ -> Pkt1 + end, + State1 = State#state{in_flight = Dup}, + {ok, do_send(State1, Pkt1)}; + false -> + ?DEBUG("Queueing packet:~n~s~n** when state:~n~s", + [pp(Pkt), pp(State)]), + try p1_queue:in(Pkt, State#state.queue) of + Q -> + State1 = State#state{queue = Q}, + {ok, State1} + catch error:full -> + Q = p1_queue:clear(State#state.queue), + State1 = State#state{queue = Q, session_expiry = 0}, + {error, State1, queue_full} + end + end; + true -> + {ok, State} + end; +send(State, Pkt) -> + {ok, do_send(State, Pkt)}. + +-spec resend(state()) -> {ok, state()} | {error, state(), error_reason()}. +resend(#state{in_flight = undefined} = State) -> + case p1_queue:out(State#state.queue) of + {{value, #publish{qos = QoS} = Pkt}, Q} -> + case is_expired(Pkt) of + true -> + resend(State#state{queue = Q}); + {false, Pkt1} when QoS > 0 -> + State1 = State#state{in_flight = Pkt1, queue = Q}, + {ok, do_send(State1, Pkt1)}; + {false, Pkt1} -> + State1 = do_send(State#state{queue = Q}, Pkt1), + resend(State1) + end; + {empty, _} -> + {ok, State} + end; +resend(#state{in_flight = Pkt} = State) -> + {ok, do_send(State, set_dup_flag(Pkt))}. + +-spec do_send(state(), mqtt_packet()) -> state(). +do_send(#state{socket = {SockMod, Sock} = Socket} = State, Pkt) -> + ?DEBUG("Send MQTT packet:~n~s", [pp(Pkt)]), + Data = mqtt_codec:encode(State#state.version, Pkt), + Res = SockMod:send(Sock, Data), + check_sock_result(Socket, Res), + State; +do_send(State, _Pkt) -> + State. + +-spec activate(socket()) -> ok. +activate({SockMod, Sock} = Socket) -> + Res = case SockMod of + gen_tcp -> inet:setopts(Sock, [{active, once}]); + _ -> SockMod:setopts(Sock, [{active, once}]) + end, + check_sock_result(Socket, Res). + +-spec disconnect(state(), error_reason()) -> state(). +disconnect(#state{socket = {SockMod, Sock}} = State, Err) -> + State1 = case Err of + {auth, Code} -> + do_send(State, #connack{code = Code}); + {codec, {Tag, _, _}} when Tag == unsupported_protocol_version; + Tag == unsupported_protocol_name -> + do_send(State#state{version = ?MQTT_VERSION_4}, + #connack{code = connack_reason_code(Err)}); + _ when State#state.version == undefined -> + State; + {Tag, _} when Tag == socket; Tag == tls -> + State; + {peer_disconnected, _, _} -> + State; + _ -> + Props = #{reason_string => format_reason_string(Err)}, + case State#state.jid of + undefined -> + Code = connack_reason_code(Err), + Pkt = #connack{code = Code, properties = Props}, + do_send(State, Pkt); + _ when State#state.version == ?MQTT_VERSION_5 -> + Code = disconnect_reason_code(Err), + Pkt = #disconnect{code = Code, properties = Props}, + do_send(State, Pkt); + _ -> + State + end + end, + SockMod:close(Sock), + State1#state{socket = undefined, + version = undefined, + codec = mqtt_codec:renew(State#state.codec)}; +disconnect(State, _) -> + State. + +-spec check_sock_result(socket(), ok | {error, inet:posix()}) -> ok. +check_sock_result(_, ok) -> + ok; +check_sock_result({_, Sock}, {error, Why}) -> + self() ! {tcp_closed, Sock}, + ?DEBUG("MQTT socket error: ~p", [format_inet_error(Why)]). + +-spec starttls(socket()) -> {ok, socket()} | {error, error_reason()}. +starttls({fast_tls, Socket}) -> + case ejabberd_pkix:get_certfile() of + {ok, Cert} -> + case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of + {ok, TLSSock} -> + {ok, {fast_tls, TLSSock}}; + {error, Why} -> + {error, {tls, Why}} + end; + error -> + {error, {tls, no_certfile}} + end; +starttls(Socket) -> + {ok, Socket}. + +-spec recv_data(socket(), binary()) -> {ok, binary()} | {error, error_reason()}. +recv_data({fast_tls, Sock}, Data) -> + case fast_tls:recv_data(Sock, Data) of + {ok, _} = OK -> OK; + {error, E} when is_atom(E) -> {error, {socket, E}}; + {error, E} when is_binary(E) -> {error, {tls, E}}; + {error, _} = Err -> Err + end; +recv_data(_, Data) -> + {ok, Data}. + +%%%=================================================================== +%%% Formatters +%%%=================================================================== +-spec pp(any()) -> iolist(). +pp(Term) -> + io_lib_pretty:print(Term, fun pp/2). + +-spec format_inet_error(socket_error_reason()) -> string(). +format_inet_error(closed) -> + "connection closed"; +format_inet_error(timeout) -> + format_inet_error(etimedout); +format_inet_error(Reason) -> + case inet:format_error(Reason) of + "unknown POSIX error" -> atom_to_list(Reason); + Txt -> Txt + end. + +-spec format_tls_error(atom() | binary()) -> string() | binary(). +format_tls_error(no_cerfile) -> + "certificate not found"; +format_tls_error(Reason) when is_atom(Reason) -> + format_inet_error(Reason); +format_tls_error(Reason) -> + Reason. + +-spec format_exit_reason(term()) -> string(). +format_exit_reason(noproc) -> + "process is dead"; +format_exit_reason(normal) -> + "process has exited"; +format_exit_reason(killed) -> + "process has been killed"; +format_exit_reason(timeout) -> + "remote call to process timed out"; +format_exit_reason(Why) -> + format("unexpected error: ~p", [Why]). + +%% Same as format_error/1, but hides sensitive data +%% and returns result as binary +-spec format_reason_string(error_reason()) -> binary(). +format_reason_string({resumed, _}) -> + <<"Resumed by another connection">>; +format_reason_string({replaced, _}) -> + <<"Replaced by another connection">>; +format_reason_string(Err) -> + list_to_binary(format_error(Err)). + +-spec format(io:format(), list()) -> string(). +format(Fmt, Args) -> + lists:flatten(io_lib:format(Fmt, Args)). + +-spec pp(atom(), non_neg_integer()) -> [atom()] | no. +pp(state, 17) -> record_info(fields, state); +pp(Rec, Size) -> mqtt_codec:pp(Rec, Size). + +-spec publish_reason_code(error_reason()) -> reason_code(). +publish_reason_code(publish_forbidden) -> 'topic-name-invalid'; +publish_reason_code(_) -> 'implementation-specific-error'. + +-spec subscribe_reason_code(qos() | error_reason()) -> reason_code(). +subscribe_reason_code(0) -> 'granted-qos-0'; +subscribe_reason_code(1) -> 'granted-qos-1'; +subscribe_reason_code(2) -> 'granted-qos-2'; +subscribe_reason_code(subscribe_forbidden) -> 'topic-filter-invalid'; +subscribe_reason_code(_) -> 'implementation-specific-error'. + +-spec unsubscribe_reason_code(error_reason()) -> reason_code(). +unsubscribe_reason_code(_) -> 'implementation-specific-error'. + +-spec disconnect_reason_code(error_reason()) -> reason_code(). +disconnect_reason_code({code, Code}) -> Code; +disconnect_reason_code({codec, Err}) -> mqtt_codec:error_reason_code(Err); +disconnect_reason_code({unexpected_packet, _}) -> 'protocol-error'; +disconnect_reason_code({replaced, _}) -> 'session-taken-over'; +disconnect_reason_code({resumed, _}) -> 'session-taken-over'; +disconnect_reason_code(internal_server_error) -> 'implementation-specific-error'; +disconnect_reason_code(db_failure) -> 'implementation-specific-error'; +disconnect_reason_code(idle_connection) -> 'keep-alive-timeout'; +disconnect_reason_code(queue_full) -> 'quota-exceeded'; +disconnect_reason_code(shutdown) -> 'server-shutting-down'; +disconnect_reason_code(subscribe_forbidden) -> 'topic-filter-invalid'; +disconnect_reason_code(publish_forbidden) -> 'topic-name-invalid'; +disconnect_reason_code(will_topic_forbidden) -> 'topic-name-invalid'; +disconnect_reason_code({payload_format_invalid, _}) -> 'payload-format-invalid'; +disconnect_reason_code(session_expiry_non_zero) -> 'protocol-error'; +disconnect_reason_code(unknown_topic_alias) -> 'protocol-error'; +disconnect_reason_code(_) -> 'unspecified-error'. + +-spec connack_reason_code(error_reason()) -> reason_code(). +connack_reason_code({Tag, Code}) when Tag == auth; Tag == code -> Code; +connack_reason_code({codec, Err}) -> mqtt_codec:error_reason_code(Err); +connack_reason_code({unexpected_packet, _}) -> 'protocol-error'; +connack_reason_code(internal_server_error) -> 'implementation-specific-error'; +connack_reason_code(db_failure) -> 'implementation-specific-error'; +connack_reason_code(idle_connection) -> 'keep-alive-timeout'; +connack_reason_code(queue_full) -> 'quota-exceeded'; +connack_reason_code(shutdown) -> 'server-shutting-down'; +connack_reason_code(will_topic_forbidden) -> 'topic-name-invalid'; +connack_reason_code({payload_format_invalid, _}) -> 'payload-format-invalid'; +connack_reason_code(session_expiry_non_zero) -> 'protocol-error'; +connack_reason_code(_) -> 'unspecified-error'. + +%%%=================================================================== +%%% Configuration processing +%%%=================================================================== +-spec queue_type(binary()) -> ram | file. +queue_type(Host) -> + gen_mod:get_module_opt(Host, mod_mqtt, queue_type). + +-spec queue_limit(binary()) -> non_neg_integer() | unlimited. +queue_limit(Host) -> + gen_mod:get_module_opt(Host, mod_mqtt, max_queue). + +-spec session_expiry(binary()) -> seconds(). +session_expiry(Host) -> + gen_mod:get_module_opt(Host, mod_mqtt, session_expiry). + +-spec topic_alias_maximum(binary()) -> non_neg_integer(). +topic_alias_maximum(Host) -> + gen_mod:get_module_opt(Host, mod_mqtt, max_topic_aliases). + +%%%=================================================================== +%%% Timings +%%%=================================================================== +-spec current_time() -> milli_seconds(). +current_time() -> + p1_time_compat:monotonic_time(milli_seconds). + +-spec unix_time() -> seconds(). +unix_time() -> + p1_time_compat:system_time(seconds). + +-spec set_keep_alive(state(), seconds()) -> state(). +set_keep_alive(State, 0) -> + ?DEBUG("Disabling MQTT keep-alive", []), + State#state{timeout = infinity}; +set_keep_alive(State, Secs) -> + Secs1 = round(Secs * 1.5), + ?DEBUG("Setting MQTT keep-alive to ~B seconds", [Secs1]), + set_timeout(State, timer:seconds(Secs1)). + +-spec reset_keep_alive(state()) -> state(). +reset_keep_alive(#state{timeout = {MSecs, _}, jid = #jid{}} = State) -> + set_timeout(State, MSecs); +reset_keep_alive(State) -> + State. + +-spec set_timeout(state(), milli_seconds()) -> state(). +set_timeout(State, MSecs) -> + Time = current_time(), + State#state{timeout = {MSecs, Time}}. + +-spec is_expired(publish()) -> true | {false, publish()}. +is_expired(#publish{meta = Meta, properties = Props} = Pkt) -> + case maps:get(expiry_time, Meta, ?MAX_UINT32) of + ?MAX_UINT32 -> + {false, Pkt}; + ExpiryTime -> + Left = ExpiryTime - unix_time(), + if Left > 0 -> + Props1 = Props#{message_expiry_interval => Left}, + {false, Pkt#publish{properties = Props1}}; + true -> + ?DEBUG("Dropping expired packet:~n~s", [pp(Pkt)]), + true + end + end. + +%%%=================================================================== +%%% Authentication +%%%=================================================================== +-spec parse_credentials(connect()) -> {ok, jid:jid()} | {error, reason_code()}. +parse_credentials(#connect{client_id = <<>>}) -> + parse_credentials(#connect{client_id = p1_rand:get_string()}); +parse_credentials(#connect{username = <<>>, client_id = ClientID}) -> + Host = ejabberd_config:get_myname(), + JID = case jid:make(ClientID, Host) of + error -> jid:make(str:sha(ClientID), Host); + J -> J + end, + parse_credentials(JID, ClientID); +parse_credentials(#connect{username = User} = Pkt) -> + try jid:decode(User) of + #jid{luser = <<>>} -> + case jid:make(User, ejabberd_config:get_myname()) of + error -> + {error, 'bad-user-name-or-password'}; + JID -> + parse_credentials(JID, Pkt#connect.client_id) + end; + JID -> + parse_credentials(JID, Pkt#connect.client_id) + catch _:{bad_jid, _} -> + {error, 'bad-user-name-or-password'} + end. + +-spec parse_credentials(jid:jid(), binary()) -> {ok, jid:jid()} | {error, reason_code()}. +parse_credentials(JID, ClientID) -> + case gen_mod:is_loaded(JID#jid.lserver, mod_mqtt) of + false -> + {error, 'server-unavailable'}; + true -> + case jid:replace_resource(JID, ClientID) of + error -> + {error, 'client-identifier-not-valid'}; + JID1 -> + {ok, JID1} + end + end. + +-spec authenticate(connect(), peername()) -> {ok, jid:jid()} | {error, reason_code()}. +authenticate(#connect{password = Pass} = Pkt, IP) -> + case parse_credentials(Pkt) of + {ok, #jid{luser = LUser, lserver = LServer} = JID} -> + case ejabberd_auth:check_password_with_authmodule( + LUser, <<>>, LServer, Pass) of + {true, AuthModule} -> + ?INFO_MSG( + "Accepted MQTT authentication for ~s " + "by ~s backend from ~s", + [jid:encode(JID), + ejabberd_auth:backend_type(AuthModule), + ejabberd_config:may_hide_data(misc:ip_to_list(IP))]), + {ok, JID}; + false -> + {error, 'not-authorized'} + end; + {error, _} = Err -> + Err + end. + +%%%=================================================================== +%%% Validators +%%%=================================================================== +-spec validate_will(connect(), jid:jid()) -> ok | {error, reason_code()}. +validate_will(#connect{will = undefined}, _) -> + ok; +validate_will(#connect{will = #publish{topic = Topic, payload = Payload}, + will_properties = Props}, JID) -> + case mod_mqtt:check_publish_access(Topic, jid:tolower(JID)) of + deny -> {error, will_topic_forbidden}; + allow -> validate_payload(Props, Payload, will) + end. + +-spec validate_publish(publish(), state()) -> ok | {error, error_reason()}. +validate_publish(#publish{topic = Topic, payload = Payload, + properties = Props}, State) -> + case validate_topic(Topic, Props, State) of + ok -> validate_payload(Props, Payload, publish); + Err -> Err + end. + +-spec validate_subscribe(subscribe()) -> ok | {error, error_reason()}. +validate_subscribe(#subscribe{filters = Filters}) -> + case lists:any( + fun({<<"$share/", _/binary>>, _}) -> true; + (_) -> false + end, Filters) of + true -> + {error, {code, 'shared-subscriptions-not-supported'}}; + false -> + ok + end. + +-spec validate_topic(binary(), properties(), state()) -> ok | {error, error_reason()}. +validate_topic(<<>>, Props, State) -> + case maps:get(topic_alias, Props, 0) of + 0 -> + {error, {code, 'topic-alias-invalid'}}; + Alias -> + case maps:is_key(Alias, State#state.topic_aliases) of + true -> ok; + false -> {error, unknown_topic_alias} + end + end; +validate_topic(_, #{topic_alias := Alias}, State) -> + JID = State#state.jid, + Max = topic_alias_maximum(JID#jid.lserver), + if Alias > Max -> + {error, {code, 'topic-alias-invalid'}}; + true -> + ok + end; +validate_topic(_, _, _) -> + ok. + +-spec validate_payload(properties(), binary(), will | publish) -> ok | {error, error_reason()}. +validate_payload(#{payload_format_indicator := utf8}, Payload, Type) -> + try mqtt_codec:utf8(Payload) of + _ -> ok + catch _:_ -> + {error, {payload_format_invalid, Type}} + end; +validate_payload(_, _, _) -> + ok. + +%%%=================================================================== +%%% Misc +%%%=================================================================== +-spec resubscribe(jid:ljid(), map()) -> ok | {error, error_reason()}. +resubscribe(USR, Subs) -> + case maps:fold( + fun(TopicFilter, {SubOpts, ID}, ok) -> + mod_mqtt:subscribe(USR, TopicFilter, SubOpts, ID); + (_, _, {error, _} = Err) -> + Err + end, ok, Subs) of + ok -> + ok; + {error, _} = Err1 -> + unsubscribe(maps:keys(Subs), USR, #{}), + Err1 + end. + +-spec publish_will(state()) -> state(). +publish_will(#state{will = #publish{} = Will, + jid = #jid{} = JID} = State) -> + case publish(State, Will) of + {ok, _} -> + ?DEBUG("Will of ~s has been published to ~s", + [jid:encode(JID), Will#publish.topic]); + {error, Why} -> + ?WARNING_MSG("Failed to publish will of ~s to ~s: ~s", + [jid:encode(JID), Will#publish.topic, + format_error(Why)]) + end, + State#state{will = undefined}; +publish_will(State) -> + State. + +-spec next_id(non_neg_integer()) -> pos_integer(). +next_id(ID) -> + (ID rem 65535) + 1. + +-spec set_dup_flag(mqtt_packet()) -> mqtt_packet(). +set_dup_flag(#publish{qos = QoS} = Pkt) when QoS>0 -> + Pkt#publish{dup = true}; +set_dup_flag(Pkt) -> + Pkt. + +-spec get_publish_code_props({ok, non_neg_integer()} | + {error, error_reason()}) -> {reason_code(), properties()}. +get_publish_code_props({ok, 0}) -> + {'no-matching-subscribers', #{}}; +get_publish_code_props({ok, _}) -> + {success, #{}}; +get_publish_code_props({error, Err}) -> + Code = publish_reason_code(Err), + Reason = format_reason_string(Err), + {Code, #{reason_string => Reason}}. + +-spec err_args(undefined | jid:jid(), peername(), error_reason()) -> iolist(). +err_args(undefined, IP, Reason) -> + [ejabberd_config:may_hide_data(misc:ip_to_list(IP)), + format_error(Reason)]; +err_args(JID, IP, Reason) -> + [jid:encode(JID), + ejabberd_config:may_hide_data(misc:ip_to_list(IP)), + format_error(Reason)]. + +-spec log_disconnection(state(), error_reason()) -> ok. +log_disconnection(#state{jid = JID, peername = IP}, Reason) -> + Msg = case JID of + undefined -> "Rejected MQTT connection from ~s: ~s"; + _ -> "Closing MQTT connection for ~s from ~s: ~s" + end, + case Reason of + {Tag, _} when Tag == replaced; Tag == resumed; Tag == socket -> + ?DEBUG(Msg, err_args(JID, IP, Reason)); + idle_connection -> + ?DEBUG(Msg, err_args(JID, IP, Reason)); + Tag when Tag == session_expired; Tag == shutdown -> + ?INFO_MSG(Msg, err_args(JID, IP, Reason)); + {peer_disconnected, Code, _} -> + case mqtt_codec:is_error_code(Code) of + true -> ?WARNING_MSG(Msg, err_args(JID, IP, Reason)); + false -> ?DEBUG(Msg, err_args(JID, IP, Reason)) + end; + _ -> + ?WARNING_MSG(Msg, err_args(JID, IP, Reason)) + end. |