summaryrefslogtreecommitdiff
path: root/src/mod_mqtt_mnesia.erl
diff options
context:
space:
mode:
authorEvgeny Khramtsov <ekhramtsov@process-one.net>2019-02-26 13:45:53 +0300
committerEvgeny Khramtsov <ekhramtsov@process-one.net>2019-02-26 13:45:53 +0300
commit0250826cf9e70af75cef06bc9d9a93635d494bfd (patch)
tree6377acb5820dfc2beb6ffe006e513d976158ea72 /src/mod_mqtt_mnesia.erl
parentUpdate deps (diff)
Update mod_mqtt_mnesia.erl
Diffstat (limited to 'src/mod_mqtt_mnesia.erl')
-rw-r--r--src/mod_mqtt_mnesia.erl195
1 files changed, 181 insertions, 14 deletions
diff --git a/src/mod_mqtt_mnesia.erl b/src/mod_mqtt_mnesia.erl
index 3439c930..19ad36cd 100644
--- a/src/mod_mqtt_mnesia.erl
+++ b/src/mod_mqtt_mnesia.erl
@@ -21,12 +21,12 @@
%% API
-export([init/2, publish/6, delete_published/2, lookup_published/2]).
-export([list_topics/1, use_cache/1]).
-%% Unsupported backend API
-export([init/0]).
-export([subscribe/4, unsubscribe/2, find_subscriber/2]).
-export([open_session/1, close_session/1, lookup_session/1]).
-include("logger.hrl").
+-include("mqtt.hrl").
-record(mqtt_pub, {topic_server :: {binary(), binary()},
user :: binary(),
@@ -40,6 +40,16 @@
content_type = <<>> :: binary(),
user_properties = [] :: [{binary(), binary()}]}).
+-record(mqtt_sub, {topic :: {binary(), binary(), binary(), binary()},
+ options :: sub_opts(),
+ id :: non_neg_integer(),
+ pid :: pid(),
+ timestamp :: erlang:timestamp()}).
+
+-record(mqtt_session, {usr :: jid:ljid(),
+ pid :: pid(),
+ timestamp :: erlang:timestamp()}).
+
%%%===================================================================
%%% API
%%%===================================================================
@@ -107,26 +117,183 @@ list_topics(S) ->
{ok, [Topic || {Topic, S1} <- mnesia:dirty_all_keys(mqtt_pub), S1 == S]}.
init() ->
- erlang:nif_error(unsupported_db).
+ case mqtree:whereis(mqtt_sub_index) of
+ undefined ->
+ T = mqtree:new(),
+ mqtree:register(mqtt_sub_index, T);
+ _ ->
+ ok
+ end,
+ try
+ {atomic, ok} = ejabberd_mnesia:create(
+ ?MODULE,
+ mqtt_session,
+ [{ram_copies, [node()]},
+ {attributes, record_info(fields, mqtt_session)}]),
+ {atomic, ok} = ejabberd_mnesia:create(
+ ?MODULE,
+ mqtt_sub,
+ [{ram_copies, [node()]},
+ {type, ordered_set},
+ {attributes, record_info(fields, mqtt_sub)}]),
+ ok
+ catch _:{badmatch, Err} ->
+ {error, Err}
+ end.
-open_session(_) ->
- erlang:nif_error(unsupported_db).
+open_session(USR) ->
+ TS1 = p1_time_compat:unique_timestamp(),
+ P1 = self(),
+ F = fun() ->
+ case mnesia:read(mqtt_session, USR) of
+ [#mqtt_session{pid = P2, timestamp = TS2}] ->
+ if TS1 >= TS2 ->
+ mod_mqtt_session:route(P2, {replaced, P1}),
+ mnesia:write(
+ #mqtt_session{usr = USR,
+ pid = P1,
+ timestamp = TS1});
+ true ->
+ case is_process_dead(P2) of
+ true ->
+ mnesia:write(
+ #mqtt_session{usr = USR,
+ pid = P1,
+ timestamp = TS1});
+ false ->
+ mod_mqtt_session:route(P1, {replaced, P2})
+ end
+ end;
+ [] ->
+ mnesia:write(
+ #mqtt_session{usr = USR,
+ pid = P1,
+ timestamp = TS1})
+ end
+ end,
+ case mnesia:transaction(F) of
+ {atomic, _} -> ok;
+ {aborted, Reason} ->
+ db_fail("Failed to register MQTT session for ~s",
+ Reason, [jid:encode(USR)])
+ end.
-close_session(_) ->
- erlang:nif_error(unsupported_db).
+close_session(USR) ->
+ close_session(USR, self()).
-lookup_session(_) ->
- erlang:nif_error(unsupported_db).
+lookup_session(USR) ->
+ case mnesia:dirty_read(mqtt_session, USR) of
+ [#mqtt_session{pid = Pid}] ->
+ case is_process_dead(Pid) of
+ true ->
+ %% Read-Repair
+ close_session(USR, Pid),
+ {error, notfound};
+ false ->
+ {ok, Pid}
+ end;
+ [] ->
+ {error, notfound}
+ end.
-subscribe(_, _, _, _) ->
- erlang:nif_error(unsupported_db).
+subscribe({U, S, R} = USR, TopicFilter, SubOpts, ID) ->
+ T1 = p1_time_compat:unique_timestamp(),
+ P1 = self(),
+ Key = {TopicFilter, S, U, R},
+ F = fun() ->
+ case mnesia:read(mqtt_sub, Key) of
+ [#mqtt_sub{timestamp = T2}] when T1 < T2 ->
+ ok;
+ _ ->
+ Tree = mqtree:whereis(mqtt_sub_index),
+ mqtree:insert(Tree, TopicFilter),
+ mnesia:write(
+ #mqtt_sub{topic = {TopicFilter, S, U, R},
+ options = SubOpts,
+ id = ID,
+ pid = P1,
+ timestamp = T1})
+ end
+ end,
+ case mnesia:transaction(F) of
+ {atomic, _} -> ok;
+ {abored, Reason} ->
+ db_fail("Failed to subscribe ~s to ~s",
+ Reason, [jid:encode(USR), TopicFilter])
+ end.
-unsubscribe(_, _) ->
- erlang:nif_error(unsupported_db).
+unsubscribe({U, S, R} = USR, Topic) ->
+ Pid = self(),
+ F = fun() ->
+ Tree = mqtree:whereis(mqtt_sub_index),
+ mqtree:delete(Tree, Topic),
+ case mnesia:read(mqtt_sub, {Topic, S, U, R}) of
+ [#mqtt_sub{pid = Pid} = Obj] ->
+ mnesia:delete_object(Obj);
+ _ ->
+ ok
+ end
+ end,
+ case mnesia:transaction(F) of
+ {atomic, _} -> ok;
+ {aborted, Reason} ->
+ db_fail("Failed to unsubscribe ~s from ~s",
+ Reason, [jid:encode(USR), Topic])
+ end.
-find_subscriber(_, _) ->
- erlang:nif_error(unsupported_db).
+find_subscriber(S, Topic) when is_binary(Topic) ->
+ Tree = mqtree:whereis(mqtt_sub_index),
+ case mqtree:match(Tree, Topic) of
+ [Filter|Filters] ->
+ find_subscriber(S, {Filters, {Filter, S, '_', '_'}});
+ [] ->
+ {error, notfound}
+ end;
+find_subscriber(S, {Filters, {Filter, S, _, _} = Prev}) ->
+ case mnesia:dirty_next(mqtt_sub, Prev) of
+ {Filter, S, _, _} = Next ->
+ case mnesia:dirty_read(mqtt_sub, Next) of
+ [#mqtt_sub{options = SubOpts, id = ID, pid = Pid}] ->
+ case is_process_dead(Pid) of
+ true ->
+ find_subscriber(S, {Filters, Next});
+ false ->
+ {ok, {Pid, SubOpts, ID}, {Filters, Next}}
+ end;
+ [] ->
+ find_subscriber(S, {Filters, Next})
+ end;
+ _ ->
+ case Filters of
+ [] ->
+ {error, notfound};
+ [Filter1|Filters1] ->
+ find_subscriber(S, {Filters1, {Filter1, S, '_', '_'}})
+ end
+ end.
%%%===================================================================
%%% Internal functions
%%%===================================================================
+close_session(USR, Pid) ->
+ F = fun() ->
+ case mnesia:read(mqtt_session, USR) of
+ [#mqtt_session{pid = Pid} = Obj] ->
+ mnesia:delete_object(Obj);
+ _ ->
+ ok
+ end
+ end,
+ case mnesia:transaction(F) of
+ {atomic, _} -> ok;
+ {aborted, Reason} ->
+ db_fail("Failed to unregister MQTT session for ~s",
+ Reason, [jid:encode(USR)])
+ end.
+
+is_process_dead(Pid) ->
+ node(Pid) == node() andalso not is_process_alive(Pid).
+
+db_fail(Format, Reason, Args) ->
+ ?ERROR_MSG(Format ++ ": ~p", Args ++ [Reason]),
+ {error, db_failure}.