diff options
Diffstat (limited to 'src/mod_mqtt_mnesia.erl')
-rw-r--r-- | src/mod_mqtt_mnesia.erl | 299 |
1 files changed, 299 insertions, 0 deletions
diff --git a/src/mod_mqtt_mnesia.erl b/src/mod_mqtt_mnesia.erl new file mode 100644 index 000000000..f5d2dec8e --- /dev/null +++ b/src/mod_mqtt_mnesia.erl @@ -0,0 +1,299 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net> +%%% @copyright (C) 2002-2019 ProcessOne, SARL. All Rights Reserved. +%%% +%%% Licensed under the Apache License, Version 2.0 (the "License"); +%%% you may not use this file except in compliance with the License. +%%% You may obtain a copy of the License at +%%% +%%% http://www.apache.org/licenses/LICENSE-2.0 +%%% +%%% Unless required by applicable law or agreed to in writing, software +%%% distributed under the License is distributed on an "AS IS" BASIS, +%%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%%% See the License for the specific language governing permissions and +%%% limitations under the License. +%%% +%%%------------------------------------------------------------------- +-module(mod_mqtt_mnesia). +-behaviour(mod_mqtt). + +%% API +-export([init/2, publish/6, delete_published/2, lookup_published/2]). +-export([list_topics/1, use_cache/1]). +-export([init/0]). +-export([subscribe/4, unsubscribe/2, find_subscriber/2]). +-export([open_session/1, close_session/1, lookup_session/1]). + +-include("logger.hrl"). +-include("mqtt.hrl"). + +-record(mqtt_pub, {topic_server :: {binary(), binary()}, + user :: binary(), + resource :: binary(), + qos :: 0..2, + payload :: binary(), + expiry :: non_neg_integer(), + payload_format = binary :: binary | utf8, + response_topic = <<>> :: binary(), + correlation_data = <<>> :: binary(), + content_type = <<>> :: binary(), + user_properties = [] :: [{binary(), binary()}]}). + +-record(mqtt_sub, {topic :: {binary(), binary(), binary(), binary()}, + options :: sub_opts(), + id :: non_neg_integer(), + pid :: pid(), + timestamp :: erlang:timestamp()}). + +-record(mqtt_session, {usr :: jid:ljid(), + pid :: pid(), + timestamp :: erlang:timestamp()}). + +%%%=================================================================== +%%% API +%%%=================================================================== +init(_Host, _Opts) -> + case ejabberd_mnesia:create( + ?MODULE, mqtt_pub, + [{disc_only_copies, [node()]}, + {attributes, record_info(fields, mqtt_pub)}]) of + {atomic, _} -> + ok; + Err -> + {error, Err} + end. + +use_cache(Host) -> + case mnesia:table_info(mqtt_pub, storage_type) of + disc_only_copies -> + mod_mqtt_opt:use_cache(Host); + _ -> + false + end. + +publish({U, LServer, R}, Topic, Payload, QoS, Props, ExpiryTime) -> + PayloadFormat = maps:get(payload_format_indicator, Props, binary), + ResponseTopic = maps:get(response_topic, Props, <<"">>), + CorrelationData = maps:get(correlation_data, Props, <<"">>), + ContentType = maps:get(content_type, Props, <<"">>), + UserProps = maps:get(user_property, Props, []), + mnesia:dirty_write(#mqtt_pub{topic_server = {Topic, LServer}, + user = U, + resource = R, + qos = QoS, + payload = Payload, + expiry = ExpiryTime, + payload_format = PayloadFormat, + response_topic = ResponseTopic, + correlation_data = CorrelationData, + content_type = ContentType, + user_properties = UserProps}). + +delete_published({_, S, _}, Topic) -> + mnesia:dirty_delete(mqtt_pub, {Topic, S}). + +lookup_published({_, S, _}, Topic) -> + case mnesia:dirty_read(mqtt_pub, {Topic, S}) of + [#mqtt_pub{qos = QoS, + payload = Payload, + expiry = ExpiryTime, + payload_format = PayloadFormat, + response_topic = ResponseTopic, + correlation_data = CorrelationData, + content_type = ContentType, + user_properties = UserProps}] -> + Props = #{payload_format => PayloadFormat, + response_topic => ResponseTopic, + correlation_data => CorrelationData, + content_type => ContentType, + user_property => UserProps}, + {ok, {Payload, QoS, Props, ExpiryTime}}; + [] -> + {error, notfound} + end. + +list_topics(S) -> + {ok, [Topic || {Topic, S1} <- mnesia:dirty_all_keys(mqtt_pub), S1 == S]}. + +init() -> + case mqtree:whereis(mqtt_sub_index) of + undefined -> + T = mqtree:new(), + mqtree:register(mqtt_sub_index, T); + _ -> + ok + end, + try + {atomic, ok} = ejabberd_mnesia:create( + ?MODULE, + mqtt_session, + [{ram_copies, [node()]}, + {attributes, record_info(fields, mqtt_session)}]), + {atomic, ok} = ejabberd_mnesia:create( + ?MODULE, + mqtt_sub, + [{ram_copies, [node()]}, + {type, ordered_set}, + {attributes, record_info(fields, mqtt_sub)}]), + ok + catch _:{badmatch, Err} -> + {error, Err} + end. + +open_session(USR) -> + TS1 = misc:unique_timestamp(), + P1 = self(), + F = fun() -> + case mnesia:read(mqtt_session, USR) of + [#mqtt_session{pid = P2, timestamp = TS2}] -> + if TS1 >= TS2 -> + mod_mqtt_session:route(P2, {replaced, P1}), + mnesia:write( + #mqtt_session{usr = USR, + pid = P1, + timestamp = TS1}); + true -> + case is_process_dead(P2) of + true -> + mnesia:write( + #mqtt_session{usr = USR, + pid = P1, + timestamp = TS1}); + false -> + mod_mqtt_session:route(P1, {replaced, P2}) + end + end; + [] -> + mnesia:write( + #mqtt_session{usr = USR, + pid = P1, + timestamp = TS1}) + end + end, + case mnesia:transaction(F) of + {atomic, _} -> ok; + {aborted, Reason} -> + db_fail("Failed to register MQTT session for ~ts", + Reason, [jid:encode(USR)]) + end. + +close_session(USR) -> + close_session(USR, self()). + +lookup_session(USR) -> + case mnesia:dirty_read(mqtt_session, USR) of + [#mqtt_session{pid = Pid}] -> + case is_process_dead(Pid) of + true -> + %% Read-Repair + close_session(USR, Pid), + {error, notfound}; + false -> + {ok, Pid} + end; + [] -> + {error, notfound} + end. + +subscribe({U, S, R} = USR, TopicFilter, SubOpts, ID) -> + T1 = misc:unique_timestamp(), + P1 = self(), + Key = {TopicFilter, S, U, R}, + F = fun() -> + case mnesia:read(mqtt_sub, Key) of + [#mqtt_sub{timestamp = T2}] when T1 < T2 -> + ok; + _ -> + Tree = mqtree:whereis(mqtt_sub_index), + mqtree:insert(Tree, TopicFilter), + mnesia:write( + #mqtt_sub{topic = {TopicFilter, S, U, R}, + options = SubOpts, + id = ID, + pid = P1, + timestamp = T1}) + end + end, + case mnesia:transaction(F) of + {atomic, _} -> ok; + {aborted, Reason} -> + db_fail("Failed to subscribe ~ts to ~ts", + Reason, [jid:encode(USR), TopicFilter]) + end. + +unsubscribe({U, S, R} = USR, Topic) -> + Pid = self(), + F = fun() -> + Tree = mqtree:whereis(mqtt_sub_index), + mqtree:delete(Tree, Topic), + case mnesia:read(mqtt_sub, {Topic, S, U, R}) of + [#mqtt_sub{pid = Pid} = Obj] -> + mnesia:delete_object(Obj); + _ -> + ok + end + end, + case mnesia:transaction(F) of + {atomic, _} -> ok; + {aborted, Reason} -> + db_fail("Failed to unsubscribe ~ts from ~ts", + Reason, [jid:encode(USR), Topic]) + end. + +find_subscriber(S, Topic) when is_binary(Topic) -> + Tree = mqtree:whereis(mqtt_sub_index), + case mqtree:match(Tree, Topic) of + [Filter|Filters] -> + find_subscriber(S, {Filters, {Filter, S, '_', '_'}}); + [] -> + {error, notfound} + end; +find_subscriber(S, {Filters, {Filter, S, _, _} = Prev}) -> + case mnesia:dirty_next(mqtt_sub, Prev) of + {Filter, S, _, _} = Next -> + case mnesia:dirty_read(mqtt_sub, Next) of + [#mqtt_sub{options = SubOpts, id = ID, pid = Pid}] -> + case is_process_dead(Pid) of + true -> + find_subscriber(S, {Filters, Next}); + false -> + {ok, {Pid, SubOpts, ID}, {Filters, Next}} + end; + [] -> + find_subscriber(S, {Filters, Next}) + end; + _ -> + case Filters of + [] -> + {error, notfound}; + [Filter1|Filters1] -> + find_subscriber(S, {Filters1, {Filter1, S, '_', '_'}}) + end + end. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +close_session(USR, Pid) -> + F = fun() -> + case mnesia:read(mqtt_session, USR) of + [#mqtt_session{pid = Pid} = Obj] -> + mnesia:delete_object(Obj); + _ -> + ok + end + end, + case mnesia:transaction(F) of + {atomic, _} -> ok; + {aborted, Reason} -> + db_fail("Failed to unregister MQTT session for ~ts", + Reason, [jid:encode(USR)]) + end. + +is_process_dead(Pid) -> + node(Pid) == node() andalso not is_process_alive(Pid). + +db_fail(Format, Reason, Args) -> + ?ERROR_MSG(Format ++ ": ~p", Args ++ [Reason]), + {error, db_failure}. |