summaryrefslogtreecommitdiff
path: root/src/odbc/ejabberd_odbc.erl
diff options
context:
space:
mode:
authorEvgeniy Khramtsov <xramtsov@gmail.com>2010-01-31 11:41:28 +0000
committerEvgeniy Khramtsov <xramtsov@gmail.com>2010-01-31 11:41:28 +0000
commit03454c7f1db017cb11712c59b9a3ce3d098c3a5e (patch)
tree30934aabf5629dfdc1b67aeed446d4599e082c58 /src/odbc/ejabberd_odbc.erl
parentregenerated guide.html (diff)
improved SQL reconnect behaviour
SVN Revision: 2947
Diffstat (limited to 'src/odbc/ejabberd_odbc.erl')
-rw-r--r--src/odbc/ejabberd_odbc.erl323
1 files changed, 190 insertions, 133 deletions
diff --git a/src/odbc/ejabberd_odbc.erl b/src/odbc/ejabberd_odbc.erl
index b2c1c20f..de38260c 100644
--- a/src/odbc/ejabberd_odbc.erl
+++ b/src/odbc/ejabberd_odbc.erl
@@ -27,7 +27,9 @@
-module(ejabberd_odbc).
-author('alexey@process-one.net').
--behaviour(gen_server).
+-define(GEN_FSM, p1_fsm).
+
+-behaviour(?GEN_FSM).
%% External exports
-export([start/1, start_link/2,
@@ -39,17 +41,28 @@
escape_like/1,
keep_alive/1]).
-%% gen_server callbacks
+%% gen_fsm callbacks
-export([init/1,
- handle_call/3,
- handle_cast/2,
- code_change/3,
- handle_info/2,
- terminate/2]).
+ handle_event/3,
+ handle_sync_event/4,
+ handle_info/3,
+ terminate/3,
+ code_change/4]).
+
+%% gen_fsm states
+-export([connecting/2,
+ connecting/3,
+ session_established/2,
+ session_established/3]).
-include("ejabberd.hrl").
--record(state, {db_ref, db_type}).
+-record(state, {db_ref,
+ db_type,
+ start_interval,
+ host,
+ max_pending_requests_len,
+ pending_requests}).
-define(STATE_KEY, ejabberd_odbc_state).
-define(NESTING_KEY, ejabberd_odbc_nesting_level).
@@ -62,14 +75,23 @@
-define(KEEPALIVE_TIMEOUT, 60000).
-define(KEEPALIVE_QUERY, "SELECT 1;").
+%%-define(DBGFSM, true).
+
+-ifdef(DBGFSM).
+-define(FSMOPTS, [{debug, [trace]}]).
+-else.
+-define(FSMOPTS, []).
+-endif.
+
%%%----------------------------------------------------------------------
%%% API
%%%----------------------------------------------------------------------
start(Host) ->
- gen_server:start(ejabberd_odbc, [Host], []).
+ ?GEN_FSM:start(ejabberd_odbc, [Host], fsm_limit_opts() ++ ?FSMOPTS).
start_link(Host, StartInterval) ->
- gen_server:start_link(ejabberd_odbc, [Host, StartInterval], []).
+ ?GEN_FSM:start_link(ejabberd_odbc, [Host, StartInterval],
+ fsm_limit_opts() ++ ?FSMOPTS).
sql_query(Host, Query) ->
sql_call(Host, {sql_query, Query}).
@@ -95,12 +117,16 @@ sql_bloc(Host, F) ->
sql_call(Host, Msg) ->
case get(?STATE_KEY) of
undefined ->
- gen_server:call(ejabberd_odbc_sup:get_random_pid(Host),
- {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT);
+ ?GEN_FSM:sync_send_event(ejabberd_odbc_sup:get_random_pid(Host),
+ {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT);
_State ->
nested_op(Msg)
end.
+% perform a harmless query on all opened connexions to avoid connexion close.
+keep_alive(PID) ->
+ ?GEN_FSM:sync_send_event(PID, {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}},
+ ?KEEPALIVE_TIMEOUT).
%% This function is intended to be used from inside an sql_transaction:
sql_query_t(Query) ->
@@ -134,16 +160,8 @@ escape_like(C) -> odbc_queries:escape(C).
%%%----------------------------------------------------------------------
-%%% Callback functions from gen_server
+%%% Callback functions from gen_fsm
%%%----------------------------------------------------------------------
-
-%%----------------------------------------------------------------------
-%% Func: init/1
-%% Returns: {ok, State} |
-%% {ok, State, Timeout} |
-%% ignore |
-%% {stop, Reason}
-%%----------------------------------------------------------------------
init([Host, StartInterval]) ->
case ejabberd_config:get_local_option({odbc_keepalive_interval, Host}) of
KeepaliveInterval when is_integer(KeepaliveInterval) ->
@@ -155,80 +173,114 @@ init([Host, StartInterval]) ->
?ERROR_MSG("Wrong odbc_keepalive_interval definition '~p'"
" for host ~p.~n", [_Other, Host])
end,
- SQLServer = ejabberd_config:get_local_option({odbc_server, Host}),
- case SQLServer of
- %% Default pgsql port
- {pgsql, Server, DB, Username, Password} ->
- pgsql_connect(Server, ?PGSQL_PORT, DB, Username, Password,
- StartInterval);
- {pgsql, Server, Port, DB, Username, Password} when is_integer(Port) ->
- pgsql_connect(Server, Port, DB, Username, Password,
- StartInterval);
- %% Default mysql port
- {mysql, Server, DB, Username, Password} ->
- mysql_connect(Server, ?MYSQL_PORT, DB, Username, Password,
- StartInterval);
- {mysql, Server, Port, DB, Username, Password} when is_integer(Port) ->
- mysql_connect(Server, Port, DB, Username, Password,
- StartInterval);
- _ when is_list(SQLServer) ->
- odbc_connect(SQLServer, StartInterval)
- end.
+ [DBType | _] = db_opts(Host),
+ ?GEN_FSM:send_event(self(), connect),
+ {ok, connecting, #state{db_type = DBType,
+ host = Host,
+ max_pending_requests_len = max_fsm_queue(),
+ pending_requests = {0, queue:new()},
+ start_interval = StartInterval}}.
+
+connecting(connect, #state{host = Host} = State) ->
+ ConnectRes = case db_opts(Host) of
+ [mysql | Args] ->
+ apply(fun mysql_connect/5, Args);
+ [pgsql | Args] ->
+ apply(fun pgsql_connect/5, Args);
+ [odbc | Args] ->
+ apply(fun odbc_connect/1, Args)
+ end,
+ {_, PendingRequests} = State#state.pending_requests,
+ case ConnectRes of
+ {ok, Ref} ->
+ erlang:monitor(process, Ref),
+ queue:filter(
+ fun(Req) ->
+ ?GEN_FSM:send_event(self(), Req),
+ false
+ end, PendingRequests),
+ {next_state, session_established,
+ State#state{db_ref = Ref,
+ pending_requests = {0, queue:new()}}};
+ {error, Reason} ->
+ ?INFO_MSG("~p connection failed:~n"
+ "** Reason: ~p~n"
+ "** Retry after: ~p seconds",
+ [State#state.db_type, Reason,
+ State#state.start_interval div 1000]),
+ ?GEN_FSM:send_event_after(State#state.start_interval,
+ connect),
+ {next_state, connecting, State}
+ end;
+connecting(Event, State) ->
+ ?WARNING_MSG("unexpected event in 'connecting': ~p", [Event]),
+ {next_state, connecting, State}.
+
+connecting({sql_cmd, {sql_query, ?KEEPALIVE_QUERY}}, From, State) ->
+ ?GEN_FSM:reply(From, {error, "SQL connection failed"}),
+ {next_state, connecting, State};
+connecting({sql_cmd, Command} = Req, From, State) ->
+ ?DEBUG("queueing pending request while connecting:~n\t~p", [Req]),
+ {Len, PendingRequests} = State#state.pending_requests,
+ NewPendingRequests =
+ if Len < State#state.max_pending_requests_len ->
+ {Len + 1, queue:in({sql_cmd, Command, From}, PendingRequests)};
+ true ->
+ queue:filter(
+ fun({sql_cmd, _, To}) ->
+ ?GEN_FSM:reply(To,
+ {error, "SQL connection failed"}),
+ false
+ end, PendingRequests),
+ {1, queue:from_list([{sql_cmd, Command, From}])}
+ end,
+ {next_state, connecting,
+ State#state{pending_requests = NewPendingRequests}};
+connecting(Request, {Who, _Ref}, State) ->
+ ?WARNING_MSG("unexpected call ~p from ~p in 'connecting'",
+ [Request, Who]),
+ {reply, {error, badarg}, connecting, State}.
+
+session_established({sql_cmd, Command}, From, State) ->
+ put(?NESTING_KEY, ?TOP_LEVEL_TXN),
+ put(?STATE_KEY, State),
+ abort_on_driver_error(outer_op(Command), From);
+session_established(Request, {Who, _Ref}, State) ->
+ ?WARNING_MSG("unexpected call ~p from ~p in 'session_established'",
+ [Request, Who]),
+ {reply, {error, badarg}, session_established, State}.
-%%----------------------------------------------------------------------
-%% Func: handle_call/3
-%% Returns: {reply, Reply, State} |
-%% {reply, Reply, State, Timeout} |
-%% {noreply, State} |
-%% {noreply, State, Timeout} |
-%% {stop, Reason, Reply, State} | (terminate/2 is called)
-%% {stop, Reason, State} (terminate/2 is called)
-%%----------------------------------------------------------------------
-handle_call({sql_cmd, Command}, _From, State) ->
+session_established({sql_cmd, Command, From}, State) ->
put(?NESTING_KEY, ?TOP_LEVEL_TXN),
put(?STATE_KEY, State),
- abort_on_driver_error(outer_op(Command));
-handle_call(Request, {Who, _Ref}, State) ->
- ?WARNING_MSG("Unexpected call ~p from ~p.", [Request, Who]),
- {reply, ok, State}.
-
-%%----------------------------------------------------------------------
-%% Func: handle_cast/2
-%% Returns: {noreply, State} |
-%% {noreply, State, Timeout} |
-%% {stop, Reason, State} (terminate/2 is called)
-%%----------------------------------------------------------------------
-handle_cast(_Msg, State) ->
- {noreply, State}.
-
-
-code_change(_OldVsn, State, _Extra) ->
- {ok, State}.
-
-%%----------------------------------------------------------------------
-%% Func: handle_info/2
-%% Returns: {noreply, State} |
-%% {noreply, State, Timeout} |
-%% {stop, Reason, State} (terminate/2 is called)
-%%----------------------------------------------------------------------
+ abort_on_driver_error(outer_op(Command), From);
+session_established(Event, State) ->
+ ?WARNING_MSG("unexpected event in 'session_established': ~p", [Event]),
+ {next_state, session_established, State}.
+
+handle_event(_Event, StateName, State) ->
+ {next_state, StateName, State}.
+
+handle_sync_event(_Event, _From, StateName, State) ->
+ {reply, {error, badarg}, StateName, State}.
+
+code_change(_OldVsn, StateName, State, _Extra) ->
+ {ok, StateName, State}.
+
%% We receive the down signal when we loose the MySQL connection (we are
%% monitoring the connection)
-%% => We exit and let the supervisor restart the connection.
-handle_info({'DOWN', _MonitorRef, process, _Pid, _Info}, State) ->
- {stop, connection_dropped, State};
-handle_info(_Info, State) ->
- {noreply, State}.
-
-%%----------------------------------------------------------------------
-%% Func: terminate/2
-%% Purpose: Shutdown the server
-%% Returns: any (ignored by gen_server)
-%%----------------------------------------------------------------------
-terminate(_Reason, State) ->
+handle_info({'DOWN', _MonitorRef, process, _Pid, _Info}, _StateName, State) ->
+ ?GEN_FSM:send_event(self(), connect),
+ {next_state, connecting, State};
+handle_info(Info, StateName, State) ->
+ ?WARNING_MSG("unexpected info in ~p: ~p", [StateName, Info]),
+ {next_state, StateName, State}.
+
+terminate(_Reason, _StateName, State) ->
case State#state.db_type of
mysql ->
- % old versions of mysql driver don't have the stop function
- % so the catch
+ %% old versions of mysql driver don't have the stop function
+ %% so the catch
catch mysql_conn:stop(State#state.db_ref);
_ ->
ok
@@ -367,50 +419,34 @@ sql_query_internal(Query) ->
end.
%% Generate the OTP callback return tuple depending on the driver result.
-abort_on_driver_error({error, "query timed out"} = Reply) ->
+abort_on_driver_error({error, "query timed out"} = Reply, From) ->
%% mysql driver error
- {stop, timeout, Reply, get(?STATE_KEY)};
-abort_on_driver_error({error, "Failed sending data on socket"++_} = Reply) ->
+ ?GEN_FSM:reply(From, Reply),
+ {stop, timeout, get(?STATE_KEY)};
+abort_on_driver_error({error, "Failed sending data on socket" ++ _} = Reply,
+ From) ->
%% mysql driver error
- {stop, closed, Reply, get(?STATE_KEY)};
-abort_on_driver_error(Reply) ->
- {reply, Reply, get(?STATE_KEY)}.
+ ?GEN_FSM:reply(From, Reply),
+ {stop, closed, get(?STATE_KEY)};
+abort_on_driver_error(Reply, From) ->
+ ?GEN_FSM:reply(From, Reply),
+ {next_state, session_established, get(?STATE_KEY)}.
%% == pure ODBC code
%% part of init/1
%% Open an ODBC database connection
-odbc_connect(SQLServer, StartInterval) ->
+odbc_connect(SQLServer) ->
application:start(odbc),
- case odbc:connect(SQLServer,[{scrollable_cursors, off}]) of
- {ok, Ref} ->
- erlang:monitor(process, Ref),
- {ok, #state{db_ref = Ref, db_type = odbc}};
- {error, Reason} ->
- ?ERROR_MSG("ODBC connection (~s) failed: ~p~n",
- [SQLServer, Reason]),
- %% If we can't connect we wait before retrying
- timer:sleep(StartInterval),
- {stop, odbc_connection_failed}
- end.
-
+ odbc:connect(SQLServer, [{scrollable_cursors, off}]).
%% == Native PostgreSQL code
%% part of init/1
%% Open a database connection to PostgreSQL
-pgsql_connect(Server, Port, DB, Username, Password, StartInterval) ->
- case pgsql:connect(Server, DB, Username, Password, Port) of
- {ok, Ref} ->
- erlang:monitor(process, Ref),
- {ok, #state{db_ref = Ref, db_type = pgsql}};
- {error, Reason} ->
- ?ERROR_MSG("PostgreSQL connection failed: ~p~n", [Reason]),
- %% If we can't connect we wait before retrying
- timer:sleep(StartInterval),
- {stop, pgsql_connection_failed}
- end.
+pgsql_connect(Server, Port, DB, Username, Password) ->
+ pgsql:connect(Server, DB, Username, Password, Port).
%% Convert PostgreSQL query result to Erlang ODBC result formalism
pgsql_to_odbc({ok, PGSQLResult}) ->
@@ -441,19 +477,13 @@ pgsql_item_to_odbc(_) ->
%% part of init/1
%% Open a database connection to MySQL
-mysql_connect(Server, Port, DB, Username, Password, StartInterval) ->
+mysql_connect(Server, Port, DB, Username, Password) ->
case mysql_conn:start(Server, Port, Username, Password, DB, fun log/3) of
{ok, Ref} ->
- erlang:monitor(process, Ref),
mysql_conn:fetch(Ref, ["set names 'utf8';"], self()),
- {ok, #state{db_ref = Ref, db_type = mysql}};
- {error, Reason} ->
- ?ERROR_MSG("MySQL connection failed: ~p~n"
- "Waiting ~p seconds before retrying...~n",
- [Reason, StartInterval div 1000]),
- %% If we can't connect we wait before retrying
- timer:sleep(StartInterval),
- {stop, mysql_connection_failed}
+ {ok, Ref};
+ Err ->
+ Err
end.
%% Convert MySQL query result to Erlang ODBC result formalism
@@ -475,11 +505,6 @@ mysql_item_to_odbc(Columns, Recs) ->
[element(2, Column) || Column <- Columns],
[list_to_tuple(Rec) || Rec <- Recs]}.
-% perform a harmless query on all opened connexions to avoid connexion close.
-keep_alive(PID) ->
- gen_server:call(PID, {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}},
- ?KEEPALIVE_TIMEOUT).
-
% log function used by MySQL driver
log(Level, Format, Args) ->
case Level of
@@ -490,3 +515,35 @@ log(Level, Format, Args) ->
error ->
?ERROR_MSG(Format, Args)
end.
+
+db_opts(Host) ->
+ case ejabberd_config:get_local_option({odbc_server, Host}) of
+ %% Default pgsql port
+ {pgsql, Server, DB, User, Pass} ->
+ [pgsql, Server, ?PGSQL_PORT, DB, User, Pass];
+ {pgsql, Server, Port, DB, User, Pass} when is_integer(Port) ->
+ [pgsql, Server, Port, DB, User, Pass];
+ %% Default mysql port
+ {mysql, Server, DB, User, Pass} ->
+ [mysql, Server, ?MYSQL_PORT, DB, User, Pass];
+ {mysql, Server, Port, DB, User, Pass} when is_integer(Port) ->
+ [mysql, Server, Port, DB, User, Pass];
+ SQLServer when is_list(SQLServer) ->
+ [odbc, SQLServer]
+ end.
+
+max_fsm_queue() ->
+ case ejabberd_config:get_local_option(max_fsm_queue) of
+ N when is_integer(N), N>0 ->
+ N;
+ _ ->
+ undefined
+ end.
+
+fsm_limit_opts() ->
+ case max_fsm_queue() of
+ N when is_integer(N) ->
+ [{max_queue, N}];
+ _ ->
+ []
+ end.