diff options
Diffstat (limited to 'src/mod_mqtt_session.erl')
-rw-r--r-- | src/mod_mqtt_session.erl | 108 |
1 files changed, 62 insertions, 46 deletions
diff --git a/src/mod_mqtt_session.erl b/src/mod_mqtt_session.erl index bbcf9258a..dd7a7c47f 100644 --- a/src/mod_mqtt_session.erl +++ b/src/mod_mqtt_session.erl @@ -17,7 +17,7 @@ %%%------------------------------------------------------------------- -module(mod_mqtt_session). -behaviour(p1_server). --define(VSN, 1). +-define(VSN, 2). -vsn(?VSN). %% API @@ -33,20 +33,25 @@ -record(state, {vsn = ?VSN :: integer(), version :: undefined | mqtt_version(), socket :: undefined | socket(), - peername :: peername(), + peername :: undefined | peername(), timeout = infinity :: timer(), jid :: undefined | jid:jid(), session_expiry = 0 :: seconds(), will :: undefined | publish(), will_delay = 0 :: seconds(), stop_reason :: undefined | error_reason(), - acks = #{} :: map(), - subscriptions = #{} :: map(), - topic_aliases = #{} :: map(), + acks = #{} :: acks(), + subscriptions = #{} :: subscriptions(), + topic_aliases = #{} :: topic_aliases(), id = 0 :: non_neg_integer(), in_flight :: undefined | publish() | pubrel(), codec :: mqtt_codec:state(), - queue :: undefined | p1_queue:queue()}). + queue :: undefined | p1_queue:queue(publish()), + tls :: boolean()}). + +-type acks() :: #{non_neg_integer() => pubrec()}. +-type subscriptions() :: #{binary() => {sub_opts(), non_neg_integer()}}. +-type topic_aliases() :: #{non_neg_integer() => binary()}. -type error_reason() :: {auth, reason_code()} | {code, reason_code()} | @@ -64,8 +69,9 @@ session_expiry_non_zero | unknown_topic_alias. -type state() :: #state{}. --type sockmod() :: gen_tcp | fast_tls | mod_mqtt_ws. --type socket() :: {sockmod(), inet:socket() | fast_tls:tls_socket() | mod_mqtt_ws:socket()}. +-type socket() :: {gen_tcp, inet:socket()} | + {fast_tls, fast_tls:tls_socket()} | + {mod_mqtt_ws, mod_mqtt_ws:socket()}. -type peername() :: {inet:ip_address(), inet:port_number()}. -type seconds() :: non_neg_integer(). -type milli_seconds() :: non_neg_integer(). @@ -153,13 +159,9 @@ format_error(Reason) -> %%%=================================================================== init([SockMod, Socket, ListenOpts]) -> MaxSize = proplists:get_value(max_payload_size, ListenOpts, infinity), - SockMod1 = case {SockMod, proplists:get_bool(tls, ListenOpts)} of - {gen_tcp, true} -> fast_tls; - {gen_tcp, false} -> gen_tcp; - {_, _} -> SockMod - end, - State1 = #state{socket = {SockMod1, Socket}, + State1 = #state{socket = {SockMod, Socket}, id = p1_rand:uniform(65535), + tls = proplists:get_bool(tls, ListenOpts), codec = mqtt_codec:new(MaxSize)}, Timeout = timer:seconds(30), State2 = set_timeout(State1, Timeout), @@ -188,14 +190,14 @@ handle_call({get_state, Pid}, From, State) -> noreply(State3) end; handle_call(Request, From, State) -> - ?WARNING_MSG("Got unexpected call from ~p: ~p", [From, Request]), + ?WARNING_MSG("Unexpected call from ~p: ~p", [From, Request]), noreply(State). -handle_cast(accept, #state{socket = {_, Sock} = Socket} = State) -> +handle_cast(accept, #state{socket = {_, Sock}} = State) -> case peername(State) of {ok, IPPort} -> State1 = State#state{peername = IPPort}, - case starttls(Socket) of + case starttls(State) of {ok, Socket1} -> State2 = State1#state{socket = Socket1}, handle_info({tcp, Sock, <<>>}, State2); @@ -206,7 +208,7 @@ handle_cast(accept, #state{socket = {_, Sock} = Socket} = State) -> stop(State, {socket, Why}) end; handle_cast(Msg, State) -> - ?WARNING_MSG("Got unexpected cast: ~p", [Msg]), + ?WARNING_MSG("Unexpected cast: ~p", [Msg]), noreply(State). handle_info(Msg, #state{stop_reason = {resumed, Pid} = Reason} = State) -> @@ -277,7 +279,7 @@ handle_info({Ref, badarg}, State) when is_reference(Ref) -> %% TODO: figure out from where this messages comes from noreply(State); handle_info(Info, State) -> - ?WARNING_MSG("Got unexpected info: ~p", [Info]), + ?WARNING_MSG("Unexpected info: ~p", [Info]), noreply(State). -spec handle_packet(mqtt_packet(), state()) -> {ok, state()} | @@ -310,7 +312,7 @@ handle_packet(#pubrec{id = ID, code = Code}, State) -> {ok, State}; false -> Code1 = 'packet-identifier-not-found', - ?DEBUG("Got unexpected PUBREC with id=~B, " + ?DEBUG("Unexpected PUBREC with id=~B, " "sending PUBREL with error code '~s'", [ID, Code1]), send(State, #pubrel{id = ID, code = Code1}) end; @@ -326,7 +328,7 @@ handle_packet(#pubrel{id = ID}, State) -> send(State#state{acks = Acks}, #pubcomp{id = ID}); error -> Code = 'packet-identifier-not-found', - ?DEBUG("Got unexpected PUBREL with id=~B, " + ?DEBUG("Unexpected PUBREL with id=~B, " "sending PUBCOMP with error code '~s'", [ID, Code]), Pubcomp = #pubcomp{id = ID, code = Code}, send(State, Pubcomp) @@ -416,13 +418,27 @@ stop(#state{session_expiry = SessExp} = State, Reason) -> noreply(State4) end. --spec upgrade_state(term()) -> state(). +%% Here is the code upgrading state between different +%% code versions. This is needed when doing session resumption from +%% remote node running the version of the code with incompatible #state{} +%% record fields. Also used by code_change/3 callback. +-spec upgrade_state(tuple()) -> state(). upgrade_state(State) -> - %% Here will be the code upgrading state between different - %% code versions. This is needed when doing session resumption from - %% remote node running the version of the code with incompatible #state{} - %% record fields. Also used by code_change/3 callback. - %% Use element(2, State) for vsn comparison. + case element(2, State) of + ?VSN -> + State; + VSN when VSN > ?VSN -> + erlang:error({downgrade_not_supported, State}); + VSN -> + State1 = upgrade_state(State, VSN), + upgrade_state(setelement(2, State1, VSN+1)) + end. + +-spec upgrade_state(tuple(), 1..?VSN) -> tuple(). +upgrade_state(OldState, 1) -> + %% Appending 'tls' field + erlang:append_element(OldState, false); +upgrade_state(State, _VSN) -> State. %%%=================================================================== @@ -673,13 +689,13 @@ get_connack_properties(#state{session_expiry = SessExp, jid = JID}, server_keep_alive => KeepAlive}. -spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer()) -> - {[reason_code()], map(), properties()}. + {[reason_code()], subscriptions(), properties()}. subscribe(TopicFilters, USR, SubID) -> subscribe(TopicFilters, USR, SubID, [], #{}, ok). -spec subscribe([{binary(), sub_opts()}], jid:ljid(), non_neg_integer(), - [reason_code()], map(), ok | {error, error_reason()}) -> - {[reason_code()], map(), properties()}. + [reason_code()], subscriptions(), ok | {error, error_reason()}) -> + {[reason_code()], subscriptions(), properties()}. subscribe([{TopicFilter, SubOpts}|TopicFilters], USR, SubID, Codes, Subs, Err) -> case mod_mqtt:subscribe(USR, TopicFilter, SubOpts, SubID) of ok -> @@ -698,15 +714,15 @@ subscribe([], _USR, _SubID, Codes, Subs, Err) -> end, {lists:reverse(Codes), Subs, Props}. --spec unsubscribe([binary()], jid:ljid(), map()) -> - {[reason_code()], map(), properties()}. +-spec unsubscribe([binary()], jid:ljid(), subscriptions()) -> + {[reason_code()], subscriptions(), properties()}. unsubscribe(TopicFilters, USR, Subs) -> unsubscribe(TopicFilters, USR, [], Subs, ok). -spec unsubscribe([binary()], jid:ljid(), - [reason_code()], map(), + [reason_code()], subscriptions(), ok | {error, error_reason()}) -> - {[reason_code()], map(), properties()}. + {[reason_code()], subscriptions(), properties()}. unsubscribe([TopicFilter|TopicFilters], USR, Codes, Subs, Err) -> case mod_mqtt:unsubscribe(USR, TopicFilter) of ok -> @@ -728,7 +744,7 @@ unsubscribe([], _USR, Codes, Subs, Err) -> end, {lists:reverse(Codes), Subs, Props}. --spec select_retained(jid:ljid(), map(), map()) -> [{publish(), seconds()}]. +-spec select_retained(jid:ljid(), subscriptions(), subscriptions()) -> [{publish(), seconds()}]. select_retained(USR, NewSubs, OldSubs) -> lists:flatten( maps:fold( @@ -915,8 +931,8 @@ check_sock_result({_, Sock}, {error, Why}) -> self() ! {tcp_closed, Sock}, ?DEBUG("MQTT socket error: ~p", [format_inet_error(Why)]). --spec starttls(socket()) -> {ok, socket()} | {error, error_reason()}. -starttls({fast_tls, Socket}) -> +-spec starttls(state()) -> {ok, socket()} | {error, error_reason()}. +starttls(#state{socket = {gen_tcp, Socket}, tls = true}) -> case ejabberd_pkix:get_certfile() of {ok, Cert} -> case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of @@ -928,7 +944,7 @@ starttls({fast_tls, Socket}) -> error -> {error, {tls, no_certfile}} end; -starttls(Socket) -> +starttls(#state{socket = Socket}) -> {ok, Socket}. -spec recv_data(socket(), binary()) -> {ok, binary()} | {error, error_reason()}. @@ -961,8 +977,8 @@ format_inet_error(Reason) -> end. -spec format_tls_error(atom() | binary()) -> string() | binary(). -format_tls_error(no_cerfile) -> - "certificate not found"; +format_tls_error(no_certfile) -> + "certificate not configured"; format_tls_error(Reason) when is_atom(Reason) -> format_inet_error(Reason); format_tls_error(Reason) -> @@ -1050,19 +1066,19 @@ connack_reason_code(_) -> 'unspecified-error'. %%%=================================================================== -spec queue_type(binary()) -> ram | file. queue_type(Host) -> - gen_mod:get_module_opt(Host, mod_mqtt, queue_type). + mod_mqtt_opt:queue_type(Host). -spec queue_limit(binary()) -> non_neg_integer() | unlimited. queue_limit(Host) -> - gen_mod:get_module_opt(Host, mod_mqtt, max_queue). + mod_mqtt_opt:max_queue(Host). -spec session_expiry(binary()) -> seconds(). session_expiry(Host) -> - gen_mod:get_module_opt(Host, mod_mqtt, session_expiry). + mod_mqtt_opt:session_expiry(Host). -spec topic_alias_maximum(binary()) -> non_neg_integer(). topic_alias_maximum(Host) -> - gen_mod:get_module_opt(Host, mod_mqtt, max_topic_aliases). + mod_mqtt_opt:max_topic_aliases(Host). %%%=================================================================== %%% Timings @@ -1177,7 +1193,7 @@ authenticate(#connect{password = Pass} = Pkt, IP) -> %%%=================================================================== %%% Validators %%%=================================================================== --spec validate_will(connect(), jid:jid()) -> ok | {error, reason_code()}. +-spec validate_will(connect(), jid:jid()) -> ok | {error, error_reason()}. validate_will(#connect{will = undefined}, _) -> ok; validate_will(#connect{will = #publish{topic = Topic, payload = Payload}, @@ -1242,7 +1258,7 @@ validate_payload(_, _, _) -> %%%=================================================================== %%% Misc %%%=================================================================== --spec resubscribe(jid:ljid(), map()) -> ok | {error, error_reason()}. +-spec resubscribe(jid:ljid(), subscriptions()) -> ok | {error, error_reason()}. resubscribe(USR, Subs) -> case maps:fold( fun(TopicFilter, {SubOpts, ID}, ok) -> |