%%%-------------------------------------------------------------------
%%% File : ejabberd_redis.erl
%%% Author : Evgeny Khramtsov <ekhramtsov@process-one.net>
%%% Created : 8 May 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2020 ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%----------------------------------------------------------------------
-module(ejabberd_redis).
-ifndef(GEN_SERVER).
-define(GEN_SERVER, gen_server).
-endif.
-behaviour(?GEN_SERVER).
-compile({no_auto_import, [get/1, put/2]}).
%% API
-export([start_link/1, get_proc/1, get_connection/1, q/1, qp/1, format_error/1]).
%% Commands
-export([multi/1, get/1, set/2, del/1, info/1,
sadd/2, srem/2, smembers/1, sismember/2, scard/1,
hget/2, hset/3, hdel/2, hlen/1, hgetall/1, hkeys/1,
subscribe/1, publish/2, script_load/1, evalsha/3]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-define(SERVER, ?MODULE).
-define(PROCNAME, 'ejabberd_redis_client').
-define(TR_STACK, redis_transaction_stack).
-define(DEFAULT_MAX_QUEUE, 10000).
-define(MAX_RETRIES, 1).
-define(CALL_TIMEOUT, 60*1000). %% 60 seconds
-include("logger.hrl").
-include("ejabberd_stacktrace.hrl").
-record(state, {connection :: pid() | undefined,
num :: pos_integer(),
subscriptions = #{} :: subscriptions(),
pending_q :: queue()}).
-type queue() :: p1_queue:queue({{pid(), term()}, integer()}).
-type subscriptions() :: #{binary() => [pid()]}.
-type error_reason() :: binary() | timeout | disconnected | overloaded.
-type redis_error() :: {error, error_reason()}.
-type redis_reply() :: undefined | binary() | [binary()].
-type redis_command() :: [iodata() | integer()].
-type redis_pipeline() :: [redis_command()].
-type redis_info() :: server | clients | memory | persistence |
stats | replication | cpu | commandstats |
cluster | keyspace | default | all.
-type state() :: #state{}.
-export_type([error_reason/0]).
%%%===================================================================
%%% API
%%%===================================================================
start_link(I) ->
?GEN_SERVER:start_link({local, get_proc(I)}, ?MODULE, [I], []).
get_proc(I) ->
misc:binary_to_atom(
iolist_to_binary(
[atom_to_list(?MODULE), $_, integer_to_list(I)])).
get_connection(I) ->
misc:binary_to_atom(
iolist_to_binary(
[atom_to_list(?MODULE), "_connection_", integer_to_list(I)])).
-spec q(redis_command()) -> {ok, redis_reply()} | redis_error().
q(Command) ->
call(get_rnd_id(), {q, Command}, ?MAX_RETRIES).
-spec qp(redis_pipeline()) -> [{ok, redis_reply()} | redis_error()] | redis_error().
qp(Pipeline) ->
call(get_rnd_id(), {qp, Pipeline}, ?MAX_RETRIES).
-spec multi(fun(() -> any())) -> {ok, redis_reply()} | redis_error().
multi(F) ->
case erlang:get(?TR_STACK) of
undefined ->
erlang:put(?TR_STACK, []),
try F() of
_ ->
Stack = erlang:erase(?TR_STACK),
Command = [["MULTI"]|lists:reverse([["EXEC"]|Stack])],
case qp(Command) of
{error, _} = Err -> Err;
Result -> get_result(Result)
end
catch ?EX_RULE(E, R, St) ->
erlang:erase(?TR_STACK),
erlang:raise(E, R, ?EX_STACK(St))
end;
_ ->
erlang:error(nested_transaction)
end.
-spec format_error(atom() | binary()) -> binary().
format_error(Reason) when is_atom(Reason) ->
format_error(misc:atom_to_binary(Reason));
format_error(Reason) ->
Reason.
%%%===================================================================
%%% Redis commands API
%%%===================================================================
-spec get(iodata()) -> {ok, undefined | binary()} | redis_error().
get(Key) ->
case erlang:get(?TR_STACK) of
undefined ->
q([<<"GET">>, Key]);
_ ->
erlang:error(transaction_unsupported)
end.
-spec set(iodata(), iodata()) -> ok | redis_error() | queued.
set(Key, Val) ->
Cmd = [<<"SET">>, Key, Val],
case erlang:get(?TR_STACK) of
undefined ->
case q(Cmd) of
{ok, <<"OK">>} -> ok;
{error, _} = Err -> Err
end;
Stack ->
tr_enq(Cmd, Stack)
end.
-spec del(list()) -> {ok, non_neg_integer()} | redis_error() | queued.
del([]) ->
reply(0);
del(Keys) ->
Cmd = [<<"DEL">>|Keys],
case erlang:get(?TR_STACK) of
undefined ->
case q(Cmd) of
{ok, N} -> {ok, binary_to_integer(N)};
{error, _} = Err -> Err
end;
Stack ->
tr_enq(Cmd, Stack)
end.
-spec sadd(iodata(), list()) -> {ok, non_neg_integer()} | redis_error() | queued.
sadd(_Set, []) ->
reply(0);
sadd(Set, Members) ->
Cmd = [<<"SADD">>, Set|Members],
case erlang:get(?TR_STACK) of
undefined ->
case q(Cmd) of
{ok, N} -> {ok, binary_to_integer(N)};
{error, _} = Err -> Err
end;
Stack ->
tr_enq(Cmd, Stack)
end.
-spec srem(iodata(), list()) -> {ok, non_neg_integer()} | redis_error() | queued.
srem(_Set, []) ->
reply(0);
srem(Set, Members) ->
Cmd = [<<"SREM">>, Set|Members],
case erlang:get(?TR_STACK) of
undefined ->
case q(Cmd) of
{ok, N} -> {ok, binary_to_integer(N)};
{error, _} = Err -> Err
end;
Stack ->
tr_enq(Cmd, Stack)
end.
-spec smembers(iodata()) -> {ok, [binary()]} | redis_error().
smembers(Set) ->
case erlang:get(?TR_STACK) of
undefined ->
q([<<"SMEMBERS">>, Set]);
_ ->
erlang:error(transaction_unsupported)
end.
-spec sismember(iodata(), iodata()) -> boolean() | redis_error().
sismember(Set, Member) ->
case erlang:get(?TR_STACK) of
undefined ->
case q([<<"SISMEMBER">>, Set, Member]) of
{ok, Flag} -> {ok, dec_bool(Flag)};
{error, _} = Err -> Err
end;
_ ->
erlang:error(transaction_unsupported)
end.
-spec scard(iodata()) -> {ok, non_neg_integer()} | redis_error().
scard(Set) ->
case erlang:get(?TR_STACK) of
undefined ->
case q([<<"SCARD">>, Set]) of
{ok, N} ->
{ok, binary_to_integer(N)};
{error, _} = Err ->
Err
end;
_ ->
erlang:error(transaction_unsupported)
end.
-spec hget(iodata(), iodata()) -> {ok, undefined | binary()} | redis_error().
hget(Key, Field) ->
case erlang:get(?TR_STACK) of
undefined ->
q([<<"HGET">>, Key, Field]);
_ ->
erlang:error(transaction_unsupported)
end.
-spec hset(iodata(), iodata(), iodata()) -> {ok, boolean()} | redis_error() | queued.
hset(Key, Field, Val) ->
Cmd = [<<"HSET">>, Key, Field, Val],
case erlang:get(?TR_STACK) of
undefined ->
case q(Cmd) of
{ok, Flag} -> {ok, dec_bool(Flag)};
{error, _} = Err -> Err
end;
Stack ->
tr_enq(Cmd, Stack)
end.
-spec hdel(iodata(), list()) -> {ok, non_neg_integer()} | redis_error() | queued.
hdel(_Key, []) ->
reply(0);
hdel(Key, Fields) ->
Cmd = [<<"HDEL">>, Key|Fields],
case erlang:get(?TR_STACK) of
undefined ->
case q(Cmd) of
{ok, N} -> {ok, binary_to_integer(N)};
{error, _} = Err -> Err
end;
Stack ->
tr_enq(Cmd, Stack)
end.
-spec hgetall(iodata()) -> {ok, [{binary(), binary()}]} | redis_error().
hgetall(Key) ->
case erlang:get(?TR_STACK) of
undefined ->
case q([<<"HGETALL">>, Key]) of
{ok, Pairs} -> {ok, decode_pairs(Pairs)};
{error, _} = Err -> Err
end;
_ ->
erlang:error(transaction_unsupported)
end.
-spec hlen(iodata()) -> {ok, non_neg_integer()} | redis_error().
hlen(Key) ->
case erlang:get(?TR_STACK) of
undefined ->
case q([<<"HLEN">>, Key]) of
{ok, N} -> {ok, binary_to_integer(N)};
{error, _} = Err -> Err
end;
_ ->
erlang:error(transaction_unsupported)
end.
-spec hkeys(iodata()) -> {ok, [binary()]} | redis_error().
hkeys(Key) ->
case erlang:get(?TR_STACK) of
undefined ->
q([<<"HKEYS">>, Key]);
_ ->
erlang:error(transaction_unsupported)
end.
-spec subscribe([binary()]) -> ok | redis_error().
subscribe(Channels) ->
try gen_server_call(get_proc(1), {subscribe, self(), Channels})
catch exit:{Why, {?GEN_SERVER, call, _}} ->
Reason = case Why of
timeout -> timeout;
_ -> disconnected
end,
{error, Reason}
end.
-spec publish(iodata(), iodata()) -> {ok, non_neg_integer()} | redis_error() | queued.
publish(Channel, Data) ->
Cmd = [<<"PUBLISH">>, Channel, Data],
case erlang:get(?TR_STACK) of
undefined ->
case q(Cmd) of
{ok, N} -> {ok, binary_to_integer(N)};
{error, _} = Err -> Err
end;
Stack ->
tr_enq(Cmd, Stack)
end.
-spec script_load(iodata()) -> {ok, binary()} | redis_error().
script_load(Data) ->
case erlang:get(?TR_STACK) of
undefined ->
q([<<"SCRIPT">>, <<"LOAD">>, Data]);
_ ->
erlang:error(transaction_unsupported)
end.
-spec evalsha(binary(), [iodata()], [iodata() | integer()]) -> {ok, binary()} | redis_error().
evalsha(SHA, Keys, Args) ->
case erlang:get(?TR_STACK) of
undefined ->
q([<<"EVALSHA">>, SHA, length(Keys)|Keys ++ Args]);
_ ->
erlang:error(transaction_unsupported)
end.
-spec info(redis_info()) -> {ok, [{atom(), binary()}]} | redis_error().
info(Type) ->
case erlang:get(?TR_STACK) of
undefined ->
case q([<<"INFO">>, misc:atom_to_binary(Type)]) of
{ok, Info} ->
Lines = binary:split(Info, <<"\r\n">>, [global]),
KVs = [binary:split(Line, <<":">>) || Line <- Lines],
{ok, [{misc:binary_to_atom(K), V} || [K, V] <- KVs]};
{error, _} = Err ->
Err
end;
_ ->
erlang:error(transaction_unsupported)
end.
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
init([I]) ->
process_flag(trap_exit, true),
QueueType = get_queue_type(),
Limit = max_fsm_queue(),
self() ! connect,
{ok, #state{num = I, pending_q = p1_queue:new(QueueType, Limit)}}.
handle_call(connect, From, #state{connection = undefined,
pending_q = Q} = State) ->
CurrTime = erlang:monotonic_time(millisecond),
Q2 = try p1_queue:in({From, CurrTime}, Q)
catch error:full ->
Q1 = clean_queue(Q, CurrTime),
p1_queue:in({From, CurrTime}, Q1)
end,
{noreply, State#state{pending_q = Q2}};
handle_call(connect, From, #state{connection = Pid} = State) ->
case is_process_alive(Pid) of
true ->
{reply, ok, State};
false ->
self() ! connect,
handle_call(connect, From, State#state{connection = undefined})
end;
handle_call({subscribe, Caller, Channels}, _From,
#state{connection = Pid, subscriptions = Subs} = State) ->
Subs1 = lists:foldl(
fun(Channel, Acc) ->
Callers = maps:get(Channel, Acc, []) -- [Caller],
maps:put(Channel, [Caller|Callers], Acc)
end, Subs, Channels),
eredis_subscribe(Pid, Channels),
{reply, ok, State#state{subscriptions = Subs1}};
handle_call(Request, _From, State) ->
?WARNING_MSG("Unexpected call: ~p", [Request]),
{noreply, State}.
handle_cast(_Msg, State) ->
{noreply, State}.
handle_info(connect, #state{connection = undefined} = State) ->
NewState = case connect(State) of
{ok, Connection} ->
Q1 = flush_queue(State#state.pending_q),
re_subscribe(Connection, State#state.subscriptions),
State#state{connection = Connection, pending_q = Q1};
{error, _} ->
State
end,
{noreply, NewState};
handle_info(connect, State) ->
%% Already connected
{noreply, State};
handle_info({'EXIT', Pid, _}, State) ->
case State#state.connection of
Pid ->
self() ! connect,
{noreply, State#state{connection = undefined}};
_ ->
{noreply, State}
end;
handle_info({subscribed, Channel, Pid}, State) ->
case State#state.connection of
Pid ->
case maps:is_key(Channel, State#state.subscriptions) of
true -> eredis_sub:ack_message(Pid);
false ->
?WARNING_MSG("Got subscription ack for unknown channel ~ts",
[Channel])
end;
_ ->
ok
end,
{noreply, State};
handle_info({message, Channel, Data, Pid}, State) ->
case State#state.connection of
Pid ->
lists:foreach(
fun(Subscriber) ->
erlang:send(Subscriber, {redis_message, Channel, Data})
end, maps:get(Channel, State#state.subscriptions, [])),
eredis_sub:ack_message(Pid);
_ ->
ok
end,
{noreply, State};
handle_info(Info, State) ->
?WARNING_MSG("Unexpected info = ~p", [Info]),
{noreply, State}.
terminate(_Reason, _State) ->
ok.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
-spec connect(state()) -> {ok, pid()} | {error, any()}.
connect(#state{num = Num}) ->
Server = ejabberd_option:redis_server(),
Port = ejabberd_option:redis_port(),
DB = ejabberd_option:redis_db(),
Pass = ejabberd_option:redis_password(),
ConnTimeout = ejabberd_option:redis_connect_timeout(),
try case do_connect(Num, Server, Port, Pass, DB, ConnTimeout) of
{ok, Client} ->
?DEBUG("Connection #~p established to Redis at ~ts:~p",
[Num, Server, Port]),
register(get_connection(Num), Client),
{ok, Client};
{error, Why} ->
erlang:error(Why)
end
catch _:Reason ->
Timeout = p1_rand:uniform(
min(10, ejabberd_redis_sup:get_pool_size())),
?ERROR_MSG("Redis connection #~p at ~ts:~p has failed: ~p; "
"reconnecting in ~p seconds",
[Num, Server, Port, Reason, Timeout]),
erlang:send_after(timer:seconds(Timeout), self(), connect),
{error, Reason}
end.
do_connect(1, Server, Port, Pass, _DB, _ConnTimeout) ->
%% First connection in the pool is always a subscriber
Res = eredis_sub:start_link(Server, Port, Pass, no_reconnect, infinity, drop),
case Res of
{ok, Pid} -> eredis_sub:controlling_process(Pid);
_ -> ok
end,
Res;
do_connect(_, Server, Port, Pass, DB, ConnTimeout) ->
eredis:start_link(Server, Port, DB, Pass, no_reconnect, ConnTimeout).
-spec call(pos_integer(), {q, redis_command()}, integer()) ->
{ok, redis_reply()} | redis_error();
(pos_integer(), {qp, redis_pipeline()}, integer()) ->
[{ok, redis_reply()} | redis_error()] | redis_error().
call(I, {F, Cmd}, Retries) ->
?DEBUG("Redis query: ~p", [Cmd]),
Conn = get_connection(I),
Res = try eredis:F(Conn, Cmd, ?CALL_TIMEOUT) of
{error, Reason} when is_atom(Reason) ->
try exit(whereis(Conn), kill) catch _:_ -> ok end,
{error, disconnected};
Other ->
Other
catch exit:{timeout, _} -> {error, timeout};
exit:{_, {gen_server, call, _}} -> {error, disconnected}
end,
case Res of
{error, disconnected} when Retries > 0 ->
try gen_server_call(get_proc(I), connect) of
ok -> call(I, {F, Cmd}, Retries-1);
{error, _} = Err -> Err
catch exit:{Why, {?GEN_SERVER, call, _}} ->
Reason1 = case Why of
timeout -> timeout;
_ -> disconnected
end,
log_error(Cmd, Reason1),
{error, Reason1}
end;
{error, Reason1} ->
log_error(Cmd, Reason1),
Res;
_ ->
Res
end.
gen_server_call(Proc, Msg) ->
case ejabberd_redis_sup:start() of
ok ->
?GEN_SERVER:call(Proc, Msg, ?CALL_TIMEOUT);
{error, _} ->
{error, disconnected}
end.
-spec log_error(redis_command() | redis_pipeline(), atom() | binary()) -> ok.
log_error(Cmd, Reason) ->
?ERROR_MSG("Redis request has failed:~n"
"** request = ~p~n"
"** response = ~ts",
[Cmd, format_error(Reason)]).
-spec get_rnd_id() -> pos_integer().
get_rnd_id() ->
p1_rand:round_robin(ejabberd_redis_sup:get_pool_size() - 1) + 2.
-spec get_result([{ok, redis_reply()} | redis_error()]) ->
{ok, redis_reply()} | redis_error().
get_result([{error, _} = Err|_]) ->
Err;
get_result([{ok, _} = OK]) ->
OK;
get_result([_|T]) ->
get_result(T).
-spec tr_enq([iodata()], list()) -> queued.
tr_enq(Cmd, Stack) ->
erlang:put(?TR_STACK, [Cmd|Stack]),
queued.
-spec decode_pairs([binary()]) -> [{binary(), binary()}].
decode_pairs(Pairs) ->
decode_pairs(Pairs, []).
-spec decode_pairs([binary()], [{binary(), binary()}]) -> [{binary(), binary()}].
decode_pairs([Field, Val|Pairs], Acc) ->
decode_pairs(Pairs, [{Field, Val}|Acc]);
decode_pairs([], Acc) ->
lists:reverse(Acc).
dec_bool(<<$1>>) -> true;
dec_bool(<<$0>>) -> false.
-spec reply(T) -> {ok, T} | queued.
reply(Val) ->
case erlang:get(?TR_STACK) of
undefined -> {ok, Val};
_ -> queued
end.
-spec max_fsm_queue() -> pos_integer().
max_fsm_queue() ->
proplists:get_value(max_queue, fsm_limit_opts(), ?DEFAULT_MAX_QUEUE).
fsm_limit_opts() ->
ejabberd_config:fsm_limit_opts([]).
get_queue_type() ->
ejabberd_option:redis_queue_type().
-spec flush_queue(queue()) -> queue().
flush_queue(Q) ->
CurrTime = erlang:monotonic_time(millisecond),
p1_queue:dropwhile(
fun({From, Time}) ->
if (CurrTime - Time) >= ?CALL_TIMEOUT ->
ok;
true ->
?GEN_SERVER:reply(From, ok)
end,
true
end, Q).
-spec clean_queue(queue(), integer()) -> queue().
clean_queue(Q, CurrTime) ->
Q1 = p1_queue:dropwhile(
fun({_From, Time}) ->
(CurrTime - Time) >= ?CALL_TIMEOUT
end, Q),
Len = p1_queue:len(Q1),
Limit = p1_queue:get_limit(Q1),
if Len >= Limit ->
?ERROR_MSG("Redis request queue is overloaded", []),
p1_queue:dropwhile(
fun({From, _Time}) ->
?GEN_SERVER:reply(From, {error, overloaded}),
true
end, Q1);
true ->
Q1
end.
re_subscribe(Pid, Subs) ->
case maps:keys(Subs) of
[] -> ok;
Channels -> eredis_subscribe(Pid, Channels)
end.
eredis_subscribe(Pid, Channels) ->
?DEBUG("Redis query: ~p", [[<<"SUBSCRIBE">>|Channels]]),
eredis_sub:subscribe(Pid, Channels).