aboutsummaryrefslogtreecommitdiff
path: root/src/mod_mqtt_session.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/mod_mqtt_session.erl')
-rw-r--r--src/mod_mqtt_session.erl108
1 files changed, 62 insertions, 46 deletions
diff --git a/src/mod_mqtt_session.erl b/src/mod_mqtt_session.erl
index bbcf9258a..dd7a7c47f 100644
--- a/src/mod_mqtt_session.erl
+++ b/src/mod_mqtt_session.erl
@@ -17,7 +17,7 @@
%%%-------------------------------------------------------------------
-module(mod_mqtt_session).
-behaviour(p1_server).
--define(VSN, 1).
+-define(VSN, 2).
-vsn(?VSN).
%% API
@@ -33,20 +33,25 @@
-record(state, {vsn = ?VSN :: integer(),
version :: undefined | mqtt_version(),
socket :: undefined | socket(),
- peername :: peername(),
+ peername :: undefined | peername(),
timeout = infinity :: timer(),
jid :: undefined | jid:jid(),
session_expiry = 0 :: seconds(),
will :: undefined | publish(),
will_delay = 0 :: seconds(),
stop_reason :: undefined | error_reason(),
- acks = #{} :: map(),
- subscriptions = #{} :: map(),
- topic_aliases = #{} :: map(),
+ acks = #{} :: acks(),
+ subscriptions = #{} :: subscriptions(),
+ topic_aliases = #{} :: topic_aliases(),
id = 0 :: non_neg_integer(),
in_flight :: undefined | publish() | pubrel(),
codec :: mqtt_codec:state(),
- queue :: undefined | p1_queue:queue()}).
+ queue :: undefined | p1_queue:queue(publish()),
+ tls :: boolean()}).
+
+-type acks() :: #{non_neg_integer() => pubrec()}.
+-type subscriptions() :: #{binary() => {sub_opts(), non_neg_integer()}}.
+-type topic_aliases() :: #{non_neg_integer() => binary()}.
-type error_reason() :: {auth, reason_code()} |
{code, reason_code()} |
@@ -64,8 +69,9 @@
session_expiry_non_zero | unknown_topic_alias.
-type state() :: #state{}.
--type sockmod() :: gen_tcp | fast_tls | mod_mqtt_ws.
--type socket() :: {sockmod(), inet:socket() | fast_tls:tls_socket() | mod_mqtt_ws:socket()}.
+-type socket() :: {gen_tcp, inet:socket()} |
+ {fast_tls, fast_tls:tls_socket()} |
+ {mod_mqtt_ws, mod_mqtt_ws:socket()}.
-type peername() :: {inet:ip_address(), inet:port_number()}.
-type seconds() :: non_neg_integer().
-type milli_seconds() :: non_neg_integer().
@@ -153,13 +159,9 @@ format_error(Reason) ->
%%%===================================================================
init([SockMod, Socket, ListenOpts]) ->
MaxSize = proplists:get_value(max_payload_size, ListenOpts, infinity),
- SockMod1 = case {SockMod, proplists:get_bool(tls, ListenOpts)} of
- {gen_tcp, true} -> fast_tls;
- {gen_tcp, false} -> gen_tcp;
- {_, _} -> SockMod
- end,
- State1 = #state{socket = {SockMod1, Socket},
+ State1 = #state{socket = {SockMod, Socket},
id = p1_rand:uniform(65535),
+ tls = proplists:get_bool(tls, ListenOpts),
codec = mqtt_codec:new(MaxSize)},
Timeout = timer:seconds(30),
State2 = set_timeout(State1, Timeout),
@@ -188,14 +190,14 @@ handle_call({get_state, Pid}, From, State) ->
noreply(State3)
end;
handle_call(Request, From, State) ->
- ?WARNING_MSG("Got unexpected call from ~p: ~p", [From, Request]),
+ ?WARNING_MSG("Unexpected call from ~p: ~p", [From, Request]),
noreply(State).
-handle_cast(accept, #state{socket = {_, Sock} = Socket} = State) ->
+handle_cast(accept, #state{socket = {_, Sock}} = State) ->
case peername(State) of
{ok, IPPort} ->
State1 = State#state{peername = IPPort},
- case starttls(Socket) of
+ case starttls(State) of
{ok, Socket1} ->
State2 = State1#state{socket = Socket1},
handle_info({tcp, Sock, <<>>}, State2);
@@ -206,7 +208,7 @@ handle_cast(accept, #state{socket = {_, Sock} = Socket} = State) ->
stop(State, {socket, Why})
end;
handle_cast(Msg, State) ->
- ?WARNING_MSG("Got unexpected cast: ~p", [Msg]),
+ ?WARNING_MSG("Unexpected cast: ~p", [Msg]),
noreply(State).
handle_info(Msg, #state{stop_reason = {resumed, Pid} = Reason} = State) ->
@@ -277,7 +279,7 @@ handle_info({Ref, badarg}, State) when is_reference(Ref) ->
%% TODO: figure out from where this messages comes from
noreply(State);
handle_info(Info, State) ->
- ?WARNING_MSG("Got unexpected info: ~p", [Info]),
+ ?WARNING_MSG("Unexpected info: ~p", [Info]),
noreply(State).
-spec handle_packet(mqtt_packet(), state()) -> {ok, state()} |
@@ -310,7 +312,7 @@ handle_packet(#pubrec{id = ID, code = Code}, State) ->
{ok, State};
false ->
Code1 = 'packet-identifier-not-found',
- ?DEBUG("Got unexpected PUBREC with id=~B, "
+ ?DEBUG("Unexpected PUBREC with id=~B, "
"sending PUBREL with error code '~s'", [ID, Code1]),
send(State, #pubrel{id = ID, code = Code1})
end;
@@ -326,7 +328,7 @@ handle_packet(#pubrel{id = ID}, State) ->
send(State#state{acks = Acks}, #pubcomp{id = ID});
error ->
Code = 'packet-identifier-not-found',
- ?DEBUG("Got unexpected PUBREL with id=~B, "
+ ?DEBUG("Unexpected PUBREL with id=~B, "
"sending PUBCOMP with error code '~s'", [ID, Code]),
Pubcomp = #pubcomp{id = ID, code = Code},
send(State, Pubcomp)
@@ -416,13 +418,27 @@ stop(#state{session_expiry = SessExp} = State, Reason) ->
noreply(State4)
end.
--spec upgrade_state(term()) -> state().
+%% Here is the code upgrading state between different
+%% code versions. This is needed when doing session resumption from
+%% remote node running the version of the code with incompatible #state{}
+%% record fields. Also used by code_change/3 callback.
+-spec upgrade_state(tuple()) -> state().
upgrade_state(State) ->
- %% Here will be the code upgrading state between different
- %% code versions. This is needed when doing session resumption from
- %% remote node running the version of the code with incompatible #state{}
- %% record fields. Also used by code_change/3 callback.
- %% Use element(2, State) for vsn comparison.
+ case element(2, State) of
+ ?VSN ->
+ State;
+ VSN when VSN > ?VSN ->
+ erlang:error({downgrade_not_supported, State});
+ VSN ->
+ State1 = upgrade_state(State, VSN),
+ upgrade_state(setelement(2, State1, VSN+1))
+ end.
+
+-spec upgrade_state(tuple(), 1..?VSN) -> tuple().
+upgrade_state(OldState, 1) ->
+ %% Appending 'tls' field
+ erlang:append_element(OldState, false);
+upgrade_state(State, _VSN) ->
State.
%%%===================================================================
@@ -673,13 +689,13 @@ get_connack_properties(#state{session_expiry = SessExp, jid = JID},
server_keep_alive => KeepAlive}.
-spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer()) ->
- {[reason_code()], map(), properties()}.
+ {[reason_code()], subscriptions(), properties()}.
subscribe(TopicFilters, USR, SubID) ->
subscribe(TopicFilters, USR, SubID, [], #{}, ok).
-spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer(),
- [reason_code()], map(), ok | {error, error_reason()}) ->
- {[reason_code()], map(), properties()}.
+ [reason_code()], subscriptions(), ok | {error, error_reason()}) ->
+ {[reason_code()], subscriptions(), properties()}.
subscribe([{TopicFilter, SubOpts}|TopicFilters], USR, SubID, Codes, Subs, Err) ->
case mod_mqtt:subscribe(USR, TopicFilter, SubOpts, SubID) of
ok ->
@@ -698,15 +714,15 @@ subscribe([], _USR, _SubID, Codes, Subs, Err) ->
end,
{lists:reverse(Codes), Subs, Props}.
--spec unsubscribe([binary()], jid:ljid(), map()) ->
- {[reason_code()], map(), properties()}.
+-spec unsubscribe([binary()], jid:ljid(), subscriptions()) ->
+ {[reason_code()], subscriptions(), properties()}.
unsubscribe(TopicFilters, USR, Subs) ->
unsubscribe(TopicFilters, USR, [], Subs, ok).
-spec unsubscribe([binary()], jid:ljid(),
- [reason_code()], map(),
+ [reason_code()], subscriptions(),
ok | {error, error_reason()}) ->
- {[reason_code()], map(), properties()}.
+ {[reason_code()], subscriptions(), properties()}.
unsubscribe([TopicFilter|TopicFilters], USR, Codes, Subs, Err) ->
case mod_mqtt:unsubscribe(USR, TopicFilter) of
ok ->
@@ -728,7 +744,7 @@ unsubscribe([], _USR, Codes, Subs, Err) ->
end,
{lists:reverse(Codes), Subs, Props}.
--spec select_retained(jid:ljid(), map(), map()) -> [{publish(), seconds()}].
+-spec select_retained(jid:ljid(), subscriptions(), subscriptions()) -> [{publish(), seconds()}].
select_retained(USR, NewSubs, OldSubs) ->
lists:flatten(
maps:fold(
@@ -915,8 +931,8 @@ check_sock_result({_, Sock}, {error, Why}) ->
self() ! {tcp_closed, Sock},
?DEBUG("MQTT socket error: ~p", [format_inet_error(Why)]).
--spec starttls(socket()) -> {ok, socket()} | {error, error_reason()}.
-starttls({fast_tls, Socket}) ->
+-spec starttls(state()) -> {ok, socket()} | {error, error_reason()}.
+starttls(#state{socket = {gen_tcp, Socket}, tls = true}) ->
case ejabberd_pkix:get_certfile() of
{ok, Cert} ->
case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of
@@ -928,7 +944,7 @@ starttls({fast_tls, Socket}) ->
error ->
{error, {tls, no_certfile}}
end;
-starttls(Socket) ->
+starttls(#state{socket = Socket}) ->
{ok, Socket}.
-spec recv_data(socket(), binary()) -> {ok, binary()} | {error, error_reason()}.
@@ -961,8 +977,8 @@ format_inet_error(Reason) ->
end.
-spec format_tls_error(atom() | binary()) -> string() | binary().
-format_tls_error(no_cerfile) ->
- "certificate not found";
+format_tls_error(no_certfile) ->
+ "certificate not configured";
format_tls_error(Reason) when is_atom(Reason) ->
format_inet_error(Reason);
format_tls_error(Reason) ->
@@ -1050,19 +1066,19 @@ connack_reason_code(_) -> 'unspecified-error'.
%%%===================================================================
-spec queue_type(binary()) -> ram | file.
queue_type(Host) ->
- gen_mod:get_module_opt(Host, mod_mqtt, queue_type).
+ mod_mqtt_opt:queue_type(Host).
-spec queue_limit(binary()) -> non_neg_integer() | unlimited.
queue_limit(Host) ->
- gen_mod:get_module_opt(Host, mod_mqtt, max_queue).
+ mod_mqtt_opt:max_queue(Host).
-spec session_expiry(binary()) -> seconds().
session_expiry(Host) ->
- gen_mod:get_module_opt(Host, mod_mqtt, session_expiry).
+ mod_mqtt_opt:session_expiry(Host).
-spec topic_alias_maximum(binary()) -> non_neg_integer().
topic_alias_maximum(Host) ->
- gen_mod:get_module_opt(Host, mod_mqtt, max_topic_aliases).
+ mod_mqtt_opt:max_topic_aliases(Host).
%%%===================================================================
%%% Timings
@@ -1177,7 +1193,7 @@ authenticate(#connect{password = Pass} = Pkt, IP) ->
%%%===================================================================
%%% Validators
%%%===================================================================
--spec validate_will(connect(), jid:jid()) -> ok | {error, reason_code()}.
+-spec validate_will(connect(), jid:jid()) -> ok | {error, error_reason()}.
validate_will(#connect{will = undefined}, _) ->
ok;
validate_will(#connect{will = #publish{topic = Topic, payload = Payload},
@@ -1242,7 +1258,7 @@ validate_payload(_, _, _) ->
%%%===================================================================
%%% Misc
%%%===================================================================
--spec resubscribe(jid:ljid(), map()) -> ok | {error, error_reason()}.
+-spec resubscribe(jid:ljid(), subscriptions()) -> ok | {error, error_reason()}.
resubscribe(USR, Subs) ->
case maps:fold(
fun(TopicFilter, {SubOpts, ID}, ok) ->