aboutsummaryrefslogtreecommitdiff
path: root/src/ejabberd_sql.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/ejabberd_sql.erl')
-rw-r--r--src/ejabberd_sql.erl1025
1 files changed, 1025 insertions, 0 deletions
diff --git a/src/ejabberd_sql.erl b/src/ejabberd_sql.erl
new file mode 100644
index 000000000..4bee08f7e
--- /dev/null
+++ b/src/ejabberd_sql.erl
@@ -0,0 +1,1025 @@
+%%%----------------------------------------------------------------------
+%%% File : ejabberd_odbc.erl
+%%% Author : Alexey Shchepin <alexey@process-one.net>
+%%% Purpose : Serve ODBC connection
+%%% Created : 8 Dec 2004 by Alexey Shchepin <alexey@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
+%%%----------------------------------------------------------------------
+
+-module(ejabberd_sql).
+
+-behaviour(ejabberd_config).
+
+-author('alexey@process-one.net').
+
+-define(GEN_FSM, p1_fsm).
+
+-behaviour(?GEN_FSM).
+
+%% External exports
+-export([start/1, start_link/2,
+ sql_query/2,
+ sql_query_t/1,
+ sql_transaction/2,
+ sql_bloc/2,
+ escape/1,
+ escape_like/1,
+ escape_like_arg/1,
+ to_bool/1,
+ sqlite_db/1,
+ sqlite_file/1,
+ encode_term/1,
+ decode_term/1,
+ odbc_config/0,
+ freetds_config/0,
+ odbcinst_config/0,
+ init_mssql/1,
+ keep_alive/1]).
+
+%% gen_fsm callbacks
+-export([init/1, handle_event/3, handle_sync_event/4,
+ handle_info/3, terminate/3, print_state/1,
+ code_change/4]).
+
+-export([connecting/2, connecting/3,
+ session_established/2, session_established/3,
+ opt_type/1]).
+
+-include("ejabberd.hrl").
+-include("logger.hrl").
+-include("ejabberd_sql_pt.hrl").
+
+-record(state,
+ {db_ref = self() :: pid(),
+ db_type = odbc :: pgsql | mysql | sqlite | odbc | mssql,
+ db_version = undefined :: undefined | non_neg_integer(),
+ start_interval = 0 :: non_neg_integer(),
+ host = <<"">> :: binary(),
+ max_pending_requests_len :: non_neg_integer(),
+ pending_requests = {0, queue:new()} :: {non_neg_integer(), ?TQUEUE}}).
+
+-define(STATE_KEY, ejabberd_sql_state).
+
+-define(NESTING_KEY, ejabberd_sql_nesting_level).
+
+-define(TOP_LEVEL_TXN, 0).
+
+-define(PGSQL_PORT, 5432).
+
+-define(MYSQL_PORT, 3306).
+
+-define(MSSQL_PORT, 1433).
+
+-define(MAX_TRANSACTION_RESTARTS, 10).
+
+-define(TRANSACTION_TIMEOUT, 60000).
+
+-define(KEEPALIVE_TIMEOUT, 60000).
+
+-define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]).
+
+-define(PREPARE_KEY, ejabberd_sql_prepare).
+
+%%-define(DBGFSM, true).
+
+-ifdef(DBGFSM).
+
+-define(FSMOPTS, [{debug, [trace]}]).
+
+-else.
+
+-define(FSMOPTS, []).
+
+-endif.
+
+%%%----------------------------------------------------------------------
+%%% API
+%%%----------------------------------------------------------------------
+start(Host) ->
+ (?GEN_FSM):start(ejabberd_sql, [Host],
+ fsm_limit_opts() ++ (?FSMOPTS)).
+
+start_link(Host, StartInterval) ->
+ (?GEN_FSM):start_link(ejabberd_sql,
+ [Host, StartInterval],
+ fsm_limit_opts() ++ (?FSMOPTS)).
+
+-type sql_query() :: [sql_query() | binary()] | #sql_query{} |
+ fun(() -> any()) | fun((atom(), _) -> any()).
+-type sql_query_result() :: {updated, non_neg_integer()} |
+ {error, binary()} |
+ {selected, [binary()],
+ [[binary()]]} |
+ {selected, [any()]}.
+
+-spec sql_query(binary(), sql_query()) -> sql_query_result().
+
+sql_query(Host, Query) ->
+ check_error(sql_call(Host, {sql_query, Query}), Query).
+
+%% SQL transaction based on a list of queries
+%% This function automatically
+-spec sql_transaction(binary(), [sql_query()] | fun(() -> any())) ->
+ {atomic, any()} |
+ {aborted, any()}.
+
+sql_transaction(Host, Queries)
+ when is_list(Queries) ->
+ F = fun () ->
+ lists:foreach(fun (Query) -> sql_query_t(Query) end,
+ Queries)
+ end,
+ sql_transaction(Host, F);
+%% SQL transaction, based on a erlang anonymous function (F = fun)
+sql_transaction(Host, F) when is_function(F) ->
+ sql_call(Host, {sql_transaction, F}).
+
+%% SQL bloc, based on a erlang anonymous function (F = fun)
+sql_bloc(Host, F) -> sql_call(Host, {sql_bloc, F}).
+
+sql_call(Host, Msg) ->
+ case get(?STATE_KEY) of
+ undefined ->
+ case ejabberd_sql_sup:get_random_pid(Host) of
+ none -> {error, <<"Unknown Host">>};
+ Pid ->
+ (?GEN_FSM):sync_send_event(Pid,{sql_cmd, Msg,
+ p1_time_compat:monotonic_time(milli_seconds)},
+ ?TRANSACTION_TIMEOUT)
+ end;
+ _State -> nested_op(Msg)
+ end.
+
+keep_alive(PID) ->
+ (?GEN_FSM):sync_send_event(PID,
+ {sql_cmd, {sql_query, ?KEEPALIVE_QUERY},
+ p1_time_compat:monotonic_time(milli_seconds)},
+ ?KEEPALIVE_TIMEOUT).
+
+-spec sql_query_t(sql_query()) -> sql_query_result().
+
+%% This function is intended to be used from inside an sql_transaction:
+sql_query_t(Query) ->
+ QRes = sql_query_internal(Query),
+ case QRes of
+ {error, Reason} -> throw({aborted, Reason});
+ Rs when is_list(Rs) ->
+ case lists:keysearch(error, 1, Rs) of
+ {value, {error, Reason}} -> throw({aborted, Reason});
+ _ -> QRes
+ end;
+ _ -> QRes
+ end.
+
+%% Escape character that will confuse an SQL engine
+escape(S) ->
+ << <<(sql_queries:escape(Char))/binary>> || <<Char>> <= S >>.
+
+%% Escape character that will confuse an SQL engine
+%% Percent and underscore only need to be escaped for pattern matching like
+%% statement
+escape_like(S) when is_binary(S) ->
+ << <<(escape_like(C))/binary>> || <<C>> <= S >>;
+escape_like($%) -> <<"\\%">>;
+escape_like($_) -> <<"\\_">>;
+escape_like(C) when is_integer(C), C >= 0, C =< 255 -> sql_queries:escape(C).
+
+escape_like_arg(S) when is_binary(S) ->
+ << <<(escape_like_arg(C))/binary>> || <<C>> <= S >>;
+escape_like_arg($%) -> <<"\\%">>;
+escape_like_arg($_) -> <<"\\_">>;
+escape_like_arg($\\) -> <<"\\\\">>;
+escape_like_arg(C) when is_integer(C), C >= 0, C =< 255 -> <<C>>.
+
+to_bool(<<"t">>) -> true;
+to_bool(<<"true">>) -> true;
+to_bool(<<"1">>) -> true;
+to_bool(true) -> true;
+to_bool(1) -> true;
+to_bool(_) -> false.
+
+encode_term(Term) ->
+ escape(list_to_binary(
+ erl_prettypr:format(erl_syntax:abstract(Term),
+ [{paper, 65535}, {ribbon, 65535}]))).
+
+decode_term(Bin) ->
+ Str = binary_to_list(<<Bin/binary, ".">>),
+ {ok, Tokens, _} = erl_scan:string(Str),
+ {ok, Term} = erl_parse:parse_term(Tokens),
+ Term.
+
+-spec sqlite_db(binary()) -> atom().
+sqlite_db(Host) ->
+ list_to_atom("ejabberd_sqlite_" ++ binary_to_list(Host)).
+
+-spec sqlite_file(binary()) -> string().
+sqlite_file(Host) ->
+ case ejabberd_config:get_option({sql_database, Host},
+ fun iolist_to_binary/1) of
+ undefined ->
+ {ok, Cwd} = file:get_cwd(),
+ filename:join([Cwd, "sqlite", atom_to_list(node()),
+ binary_to_list(Host), "ejabberd.db"]);
+ File ->
+ binary_to_list(File)
+ end.
+
+%%%----------------------------------------------------------------------
+%%% Callback functions from gen_fsm
+%%%----------------------------------------------------------------------
+init([Host, StartInterval]) ->
+ case ejabberd_config:get_option(
+ {sql_keepalive_interval, Host},
+ fun(I) when is_integer(I), I>0 -> I end) of
+ undefined ->
+ ok;
+ KeepaliveInterval ->
+ timer:apply_interval(KeepaliveInterval * 1000, ?MODULE,
+ keep_alive, [self()])
+ end,
+ [DBType | _] = db_opts(Host),
+ (?GEN_FSM):send_event(self(), connect),
+ ejabberd_sql_sup:add_pid(Host, self()),
+ {ok, connecting,
+ #state{db_type = DBType, host = Host,
+ max_pending_requests_len = max_fsm_queue(),
+ pending_requests = {0, queue:new()},
+ start_interval = StartInterval}}.
+
+connecting(connect, #state{host = Host} = State) ->
+ ConnectRes = case db_opts(Host) of
+ [mysql | Args] -> apply(fun mysql_connect/5, Args);
+ [pgsql | Args] -> apply(fun pgsql_connect/5, Args);
+ [sqlite | Args] -> apply(fun sqlite_connect/1, Args);
+ [mssql | Args] -> apply(fun odbc_connect/1, Args);
+ [odbc | Args] -> apply(fun odbc_connect/1, Args)
+ end,
+ {_, PendingRequests} = State#state.pending_requests,
+ case ConnectRes of
+ {ok, Ref} ->
+ erlang:monitor(process, Ref),
+ lists:foreach(fun (Req) ->
+ (?GEN_FSM):send_event(self(), Req)
+ end,
+ queue:to_list(PendingRequests)),
+ State1 = State#state{db_ref = Ref,
+ pending_requests = {0, queue:new()}},
+ State2 = get_db_version(State1),
+ {next_state, session_established, State2};
+ {error, Reason} ->
+ ?INFO_MSG("~p connection failed:~n** Reason: ~p~n** "
+ "Retry after: ~p seconds",
+ [State#state.db_type, Reason,
+ State#state.start_interval div 1000]),
+ (?GEN_FSM):send_event_after(State#state.start_interval,
+ connect),
+ {next_state, connecting, State}
+ end;
+connecting(Event, State) ->
+ ?WARNING_MSG("unexpected event in 'connecting': ~p",
+ [Event]),
+ {next_state, connecting, State}.
+
+connecting({sql_cmd, {sql_query, ?KEEPALIVE_QUERY},
+ _Timestamp},
+ From, State) ->
+ (?GEN_FSM):reply(From,
+ {error, <<"SQL connection failed">>}),
+ {next_state, connecting, State};
+connecting({sql_cmd, Command, Timestamp} = Req, From,
+ State) ->
+ ?DEBUG("queuing pending request while connecting:~n\t~p",
+ [Req]),
+ {Len, PendingRequests} = State#state.pending_requests,
+ NewPendingRequests = if Len <
+ State#state.max_pending_requests_len ->
+ {Len + 1,
+ queue:in({sql_cmd, Command, From, Timestamp},
+ PendingRequests)};
+ true ->
+ lists:foreach(fun ({sql_cmd, _, To,
+ _Timestamp}) ->
+ (?GEN_FSM):reply(To,
+ {error,
+ <<"SQL connection failed">>})
+ end,
+ queue:to_list(PendingRequests)),
+ {1,
+ queue:from_list([{sql_cmd, Command, From,
+ Timestamp}])}
+ end,
+ {next_state, connecting,
+ State#state{pending_requests = NewPendingRequests}};
+connecting(Request, {Who, _Ref}, State) ->
+ ?WARNING_MSG("unexpected call ~p from ~p in 'connecting'",
+ [Request, Who]),
+ {reply, {error, badarg}, connecting, State}.
+
+session_established({sql_cmd, Command, Timestamp}, From,
+ State) ->
+ run_sql_cmd(Command, From, State, Timestamp);
+session_established(Request, {Who, _Ref}, State) ->
+ ?WARNING_MSG("unexpected call ~p from ~p in 'session_establ"
+ "ished'",
+ [Request, Who]),
+ {reply, {error, badarg}, session_established, State}.
+
+session_established({sql_cmd, Command, From, Timestamp},
+ State) ->
+ run_sql_cmd(Command, From, State, Timestamp);
+session_established(Event, State) ->
+ ?WARNING_MSG("unexpected event in 'session_established': ~p",
+ [Event]),
+ {next_state, session_established, State}.
+
+handle_event(_Event, StateName, State) ->
+ {next_state, StateName, State}.
+
+handle_sync_event(_Event, _From, StateName, State) ->
+ {reply, {error, badarg}, StateName, State}.
+
+code_change(_OldVsn, StateName, State, _Extra) ->
+ {ok, StateName, State}.
+
+%% We receive the down signal when we loose the MySQL connection (we are
+%% monitoring the connection)
+handle_info({'DOWN', _MonitorRef, process, _Pid, _Info},
+ _StateName, State) ->
+ (?GEN_FSM):send_event(self(), connect),
+ {next_state, connecting, State};
+handle_info(Info, StateName, State) ->
+ ?WARNING_MSG("unexpected info in ~p: ~p",
+ [StateName, Info]),
+ {next_state, StateName, State}.
+
+terminate(_Reason, _StateName, State) ->
+ ejabberd_sql_sup:remove_pid(State#state.host, self()),
+ case State#state.db_type of
+ mysql -> catch p1_mysql_conn:stop(State#state.db_ref);
+ sqlite -> catch sqlite3:close(sqlite_db(State#state.host));
+ _ -> ok
+ end,
+ ok.
+
+%%----------------------------------------------------------------------
+%% Func: print_state/1
+%% Purpose: Prepare the state to be printed on error log
+%% Returns: State to print
+%%----------------------------------------------------------------------
+print_state(State) -> State.
+
+%%%----------------------------------------------------------------------
+%%% Internal functions
+%%%----------------------------------------------------------------------
+
+run_sql_cmd(Command, From, State, Timestamp) ->
+ case p1_time_compat:monotonic_time(milli_seconds) - Timestamp of
+ Age when Age < (?TRANSACTION_TIMEOUT) ->
+ put(?NESTING_KEY, ?TOP_LEVEL_TXN),
+ put(?STATE_KEY, State),
+ abort_on_driver_error(outer_op(Command), From);
+ Age ->
+ ?ERROR_MSG("Database was not available or too slow, "
+ "discarding ~p milliseconds old request~n~p~n",
+ [Age, Command]),
+ {next_state, session_established, State}
+ end.
+
+%% Only called by handle_call, only handles top level operations.
+%% @spec outer_op(Op) -> {error, Reason} | {aborted, Reason} | {atomic, Result}
+outer_op({sql_query, Query}) ->
+ sql_query_internal(Query);
+outer_op({sql_transaction, F}) ->
+ outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, <<"">>);
+outer_op({sql_bloc, F}) -> execute_bloc(F).
+
+%% Called via sql_query/transaction/bloc from client code when inside a
+%% nested operation
+nested_op({sql_query, Query}) ->
+ sql_query_internal(Query);
+nested_op({sql_transaction, F}) ->
+ NestingLevel = get(?NESTING_KEY),
+ if NestingLevel =:= (?TOP_LEVEL_TXN) ->
+ outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, <<"">>);
+ true -> inner_transaction(F)
+ end;
+nested_op({sql_bloc, F}) -> execute_bloc(F).
+
+%% Never retry nested transactions - only outer transactions
+inner_transaction(F) ->
+ PreviousNestingLevel = get(?NESTING_KEY),
+ case get(?NESTING_KEY) of
+ ?TOP_LEVEL_TXN ->
+ {backtrace, T} = process_info(self(), backtrace),
+ ?ERROR_MSG("inner transaction called at outer txn "
+ "level. Trace: ~s",
+ [T]),
+ erlang:exit(implementation_faulty);
+ _N -> ok
+ end,
+ put(?NESTING_KEY, PreviousNestingLevel + 1),
+ Result = (catch F()),
+ put(?NESTING_KEY, PreviousNestingLevel),
+ case Result of
+ {aborted, Reason} -> {aborted, Reason};
+ {'EXIT', Reason} -> {'EXIT', Reason};
+ {atomic, Res} -> {atomic, Res};
+ Res -> {atomic, Res}
+ end.
+
+outer_transaction(F, NRestarts, _Reason) ->
+ PreviousNestingLevel = get(?NESTING_KEY),
+ case get(?NESTING_KEY) of
+ ?TOP_LEVEL_TXN -> ok;
+ _N ->
+ {backtrace, T} = process_info(self(), backtrace),
+ ?ERROR_MSG("outer transaction called at inner txn "
+ "level. Trace: ~s",
+ [T]),
+ erlang:exit(implementation_faulty)
+ end,
+ sql_query_internal([<<"begin;">>]),
+ put(?NESTING_KEY, PreviousNestingLevel + 1),
+ Result = (catch F()),
+ put(?NESTING_KEY, PreviousNestingLevel),
+ case Result of
+ {aborted, Reason} when NRestarts > 0 ->
+ sql_query_internal([<<"rollback;">>]),
+ outer_transaction(F, NRestarts - 1, Reason);
+ {aborted, Reason} when NRestarts =:= 0 ->
+ ?ERROR_MSG("SQL transaction restarts exceeded~n** "
+ "Restarts: ~p~n** Last abort reason: "
+ "~p~n** Stacktrace: ~p~n** When State "
+ "== ~p",
+ [?MAX_TRANSACTION_RESTARTS, Reason,
+ erlang:get_stacktrace(), get(?STATE_KEY)]),
+ sql_query_internal([<<"rollback;">>]),
+ {aborted, Reason};
+ {'EXIT', Reason} ->
+ sql_query_internal([<<"rollback;">>]), {aborted, Reason};
+ Res -> sql_query_internal([<<"commit;">>]), {atomic, Res}
+ end.
+
+execute_bloc(F) ->
+ case catch F() of
+ {aborted, Reason} -> {aborted, Reason};
+ {'EXIT', Reason} -> {aborted, Reason};
+ Res -> {atomic, Res}
+ end.
+
+execute_fun(F) when is_function(F, 0) ->
+ F();
+execute_fun(F) when is_function(F, 2) ->
+ State = get(?STATE_KEY),
+ F(State#state.db_type, State#state.db_version).
+
+sql_query_internal([{_, _} | _] = Queries) ->
+ State = get(?STATE_KEY),
+ case select_sql_query(Queries, State) of
+ undefined ->
+ {error, <<"no matching query for the current DBMS found">>};
+ Query ->
+ sql_query_internal(Query)
+ end;
+sql_query_internal(#sql_query{} = Query) ->
+ State = get(?STATE_KEY),
+ Res =
+ try
+ case State#state.db_type of
+ odbc ->
+ generic_sql_query(Query);
+ mssql ->
+ generic_sql_query(Query);
+ pgsql ->
+ Key = {?PREPARE_KEY, Query#sql_query.hash},
+ case get(Key) of
+ undefined ->
+ case pgsql_prepare(Query, State) of
+ {ok, _, _, _} ->
+ put(Key, prepared);
+ {error, Error} ->
+ ?ERROR_MSG("PREPARE failed for SQL query "
+ "at ~p: ~p",
+ [Query#sql_query.loc, Error]),
+ put(Key, ignore)
+ end;
+ _ ->
+ ok
+ end,
+ case get(Key) of
+ prepared ->
+ pgsql_execute_sql_query(Query, State);
+ _ ->
+ generic_sql_query(Query)
+ end;
+ mysql ->
+ generic_sql_query(Query);
+ sqlite ->
+ generic_sql_query(Query)
+ end
+ catch
+ Class:Reason ->
+ ST = erlang:get_stacktrace(),
+ ?ERROR_MSG("Internal error while processing SQL query: ~p",
+ [{Class, Reason, ST}]),
+ {error, <<"internal error">>}
+ end,
+ case Res of
+ {error, <<"No SQL-driver information available.">>} ->
+ {updated, 0};
+ _Else -> Res
+ end;
+sql_query_internal(F) when is_function(F) ->
+ case catch execute_fun(F) of
+ {'EXIT', Reason} -> {error, Reason};
+ Res -> Res
+ end;
+sql_query_internal(Query) ->
+ State = get(?STATE_KEY),
+ ?DEBUG("SQL: \"~s\"", [Query]),
+ Res = case State#state.db_type of
+ odbc ->
+ to_odbc(odbc:sql_query(State#state.db_ref, [Query],
+ (?TRANSACTION_TIMEOUT) - 1000));
+ mssql ->
+ to_odbc(odbc:sql_query(State#state.db_ref, [Query],
+ (?TRANSACTION_TIMEOUT) - 1000));
+ pgsql ->
+ pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query));
+ mysql ->
+ R = mysql_to_odbc(p1_mysql_conn:squery(State#state.db_ref,
+ [Query], self(),
+ [{timeout, (?TRANSACTION_TIMEOUT) - 1000},
+ {result_type, binary}])),
+ %% ?INFO_MSG("MySQL, Received result~n~p~n", [R]),
+ R;
+ sqlite ->
+ Host = State#state.host,
+ sqlite_to_odbc(Host, sqlite3:sql_exec(sqlite_db(Host), Query))
+ end,
+ case Res of
+ {error, <<"No SQL-driver information available.">>} ->
+ {updated, 0};
+ _Else -> Res
+ end.
+
+select_sql_query(Queries, State) ->
+ select_sql_query(
+ Queries, State#state.db_type, State#state.db_version, undefined).
+
+select_sql_query([], _Type, _Version, undefined) ->
+ undefined;
+select_sql_query([], _Type, _Version, Query) ->
+ Query;
+select_sql_query([{any, Query} | _], _Type, _Version, _) ->
+ Query;
+select_sql_query([{Type, Query} | _], Type, _Version, _) ->
+ Query;
+select_sql_query([{{Type, _Version1}, Query1} | Rest], Type, undefined, _) ->
+ select_sql_query(Rest, Type, undefined, Query1);
+select_sql_query([{{Type, Version1}, Query1} | Rest], Type, Version, Query) ->
+ if
+ Version >= Version1 ->
+ Query1;
+ true ->
+ select_sql_query(Rest, Type, Version, Query)
+ end;
+select_sql_query([{_, _} | Rest], Type, Version, Query) ->
+ select_sql_query(Rest, Type, Version, Query).
+
+generic_sql_query(SQLQuery) ->
+ sql_query_format_res(
+ sql_query_internal(generic_sql_query_format(SQLQuery)),
+ SQLQuery).
+
+generic_sql_query_format(SQLQuery) ->
+ Args = (SQLQuery#sql_query.args)(generic_escape()),
+ (SQLQuery#sql_query.format_query)(Args).
+
+generic_escape() ->
+ #sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end,
+ integer = fun(X) -> integer_to_binary(X) end,
+ boolean = fun(true) -> <<"1">>;
+ (false) -> <<"0">>
+ end
+ }.
+
+pgsql_prepare(SQLQuery, State) ->
+ Escape = #sql_escape{_ = fun(X) -> X end},
+ N = length((SQLQuery#sql_query.args)(Escape)),
+ Args = [<<$$, (integer_to_binary(I))/binary>> || I <- lists:seq(1, N)],
+ Query = (SQLQuery#sql_query.format_query)(Args),
+ pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query).
+
+pgsql_execute_escape() ->
+ #sql_escape{string = fun(X) -> X end,
+ integer = fun(X) -> [integer_to_binary(X)] end,
+ boolean = fun(true) -> "1";
+ (false) -> "0"
+ end
+ }.
+
+pgsql_execute_sql_query(SQLQuery, State) ->
+ Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()),
+ ExecuteRes =
+ pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Args),
+% {T, ExecuteRes} =
+% timer:tc(pgsql, execute, [State#state.db_ref, SQLQuery#sql_query.hash, Args]),
+% io:format("T ~s ~p~n", [SQLQuery#sql_query.hash, T]),
+ Res = pgsql_execute_to_odbc(ExecuteRes),
+ sql_query_format_res(Res, SQLQuery).
+
+
+sql_query_format_res({selected, _, Rows}, SQLQuery) ->
+ Res =
+ lists:flatmap(
+ fun(Row) ->
+ try
+ [(SQLQuery#sql_query.format_res)(Row)]
+ catch
+ Class:Reason ->
+ ST = erlang:get_stacktrace(),
+ ?ERROR_MSG("Error while processing "
+ "SQL query result: ~p~n"
+ "row: ~p",
+ [{Class, Reason, ST}, Row]),
+ []
+ end
+ end, Rows),
+ {selected, Res};
+sql_query_format_res(Res, _SQLQuery) ->
+ Res.
+
+%% Generate the OTP callback return tuple depending on the driver result.
+abort_on_driver_error({error, <<"query timed out">>} =
+ Reply,
+ From) ->
+ (?GEN_FSM):reply(From, Reply),
+ {stop, timeout, get(?STATE_KEY)};
+abort_on_driver_error({error,
+ <<"Failed sending data on socket", _/binary>>} =
+ Reply,
+ From) ->
+ (?GEN_FSM):reply(From, Reply),
+ {stop, closed, get(?STATE_KEY)};
+abort_on_driver_error(Reply, From) ->
+ (?GEN_FSM):reply(From, Reply),
+ {next_state, session_established, get(?STATE_KEY)}.
+
+%% == pure ODBC code
+
+%% part of init/1
+%% Open an ODBC database connection
+odbc_connect(SQLServer) ->
+ ejabberd:start_app(odbc),
+ odbc:connect(binary_to_list(SQLServer),
+ [{scrollable_cursors, off},
+ {tuple_row, off},
+ {binary_strings, on}]).
+
+%% == Native SQLite code
+
+%% part of init/1
+%% Open a database connection to SQLite
+
+sqlite_connect(Host) ->
+ File = sqlite_file(Host),
+ case filelib:ensure_dir(File) of
+ ok ->
+ case sqlite3:open(sqlite_db(Host), [{file, File}]) of
+ {ok, Ref} ->
+ sqlite3:sql_exec(
+ sqlite_db(Host), "pragma foreign_keys = on"),
+ {ok, Ref};
+ {error, {already_started, Ref}} ->
+ {ok, Ref};
+ {error, Reason} ->
+ {error, Reason}
+ end;
+ Err ->
+ Err
+ end.
+
+%% Convert SQLite query result to Erlang ODBC result formalism
+sqlite_to_odbc(Host, ok) ->
+ {updated, sqlite3:changes(sqlite_db(Host))};
+sqlite_to_odbc(Host, {rowid, _}) ->
+ {updated, sqlite3:changes(sqlite_db(Host))};
+sqlite_to_odbc(_Host, [{columns, Columns}, {rows, TRows}]) ->
+ Rows = [lists:map(
+ fun(I) when is_integer(I) ->
+ jlib:integer_to_binary(I);
+ (B) ->
+ B
+ end, tuple_to_list(Row)) || Row <- TRows],
+ {selected, [list_to_binary(C) || C <- Columns], Rows};
+sqlite_to_odbc(_Host, {error, _Code, Reason}) ->
+ {error, Reason};
+sqlite_to_odbc(_Host, _) ->
+ {updated, undefined}.
+
+%% == Native PostgreSQL code
+
+%% part of init/1
+%% Open a database connection to PostgreSQL
+pgsql_connect(Server, Port, DB, Username, Password) ->
+ case pgsql:connect([{host, Server},
+ {database, DB},
+ {user, Username},
+ {password, Password},
+ {port, Port},
+ {as_binary, true}]) of
+ {ok, Ref} ->
+ pgsql:squery(Ref, [<<"alter database ">>, DB, <<" set ">>,
+ <<"standard_conforming_strings='off';">>]),
+ pgsql:squery(Ref, [<<"set standard_conforming_strings to 'off';">>]),
+ {ok, Ref};
+ Err ->
+ Err
+ end.
+
+%% Convert PostgreSQL query result to Erlang ODBC result formalism
+pgsql_to_odbc({ok, PGSQLResult}) ->
+ case PGSQLResult of
+ [Item] -> pgsql_item_to_odbc(Item);
+ Items -> [pgsql_item_to_odbc(Item) || Item <- Items]
+ end.
+
+pgsql_item_to_odbc({<<"SELECT", _/binary>>, Rows,
+ Recs}) ->
+ {selected, [element(1, Row) || Row <- Rows], Recs};
+pgsql_item_to_odbc({<<"FETCH", _/binary>>, Rows,
+ Recs}) ->
+ {selected, [element(1, Row) || Row <- Rows], Recs};
+pgsql_item_to_odbc(<<"INSERT ", OIDN/binary>>) ->
+ [_OID, N] = str:tokens(OIDN, <<" ">>),
+ {updated, jlib:binary_to_integer(N)};
+pgsql_item_to_odbc(<<"DELETE ", N/binary>>) ->
+ {updated, jlib:binary_to_integer(N)};
+pgsql_item_to_odbc(<<"UPDATE ", N/binary>>) ->
+ {updated, jlib:binary_to_integer(N)};
+pgsql_item_to_odbc({error, Error}) -> {error, Error};
+pgsql_item_to_odbc(_) -> {updated, undefined}.
+
+pgsql_execute_to_odbc({ok, {<<"SELECT", _/binary>>, Rows}}) ->
+ {selected, [], [[Field || {_, Field} <- Row] || Row <- Rows]};
+pgsql_execute_to_odbc({ok, {'INSERT', N}}) ->
+ {updated, N};
+pgsql_execute_to_odbc({ok, {'DELETE', N}}) ->
+ {updated, N};
+pgsql_execute_to_odbc({ok, {'UPDATE', N}}) ->
+ {updated, N};
+pgsql_execute_to_odbc({error, Error}) -> {error, Error};
+pgsql_execute_to_odbc(_) -> {updated, undefined}.
+
+
+%% == Native MySQL code
+
+%% part of init/1
+%% Open a database connection to MySQL
+mysql_connect(Server, Port, DB, Username, Password) ->
+ case p1_mysql_conn:start(binary_to_list(Server), Port,
+ binary_to_list(Username),
+ binary_to_list(Password),
+ binary_to_list(DB), fun log/3)
+ of
+ {ok, Ref} ->
+ p1_mysql_conn:fetch(
+ Ref, [<<"set names 'utf8mb4' collate 'utf8mb4_bin';">>], self()),
+ {ok, Ref};
+ Err -> Err
+ end.
+
+%% Convert MySQL query result to Erlang ODBC result formalism
+mysql_to_odbc({updated, MySQLRes}) ->
+ {updated, p1_mysql:get_result_affected_rows(MySQLRes)};
+mysql_to_odbc({data, MySQLRes}) ->
+ mysql_item_to_odbc(p1_mysql:get_result_field_info(MySQLRes),
+ p1_mysql:get_result_rows(MySQLRes));
+mysql_to_odbc({error, MySQLRes})
+ when is_binary(MySQLRes) ->
+ {error, MySQLRes};
+mysql_to_odbc({error, MySQLRes})
+ when is_list(MySQLRes) ->
+ {error, list_to_binary(MySQLRes)};
+mysql_to_odbc({error, MySQLRes}) ->
+ {error, p1_mysql:get_result_reason(MySQLRes)};
+mysql_to_odbc(ok) ->
+ ok.
+
+
+%% When tabular data is returned, convert it to the ODBC formalism
+mysql_item_to_odbc(Columns, Recs) ->
+ {selected, [element(2, Column) || Column <- Columns], Recs}.
+
+to_odbc({selected, Columns, Recs}) ->
+ Rows = [lists:map(
+ fun(I) when is_integer(I) ->
+ jlib:integer_to_binary(I);
+ (B) ->
+ B
+ end, Row) || Row <- Recs],
+ {selected, [list_to_binary(C) || C <- Columns], Rows};
+to_odbc({error, Reason}) when is_list(Reason) ->
+ {error, list_to_binary(Reason)};
+to_odbc(Res) ->
+ Res.
+
+get_db_version(#state{db_type = pgsql} = State) ->
+ case pgsql:squery(State#state.db_ref,
+ <<"select current_setting('server_version_num')">>) of
+ {ok, [{_, _, [[SVersion]]}]} ->
+ case catch binary_to_integer(SVersion) of
+ Version when is_integer(Version) ->
+ State#state{db_version = Version};
+ Error ->
+ ?WARNING_MSG("error getting pgsql version: ~p", [Error]),
+ State
+ end;
+ Res ->
+ ?WARNING_MSG("error getting pgsql version: ~p", [Res]),
+ State
+ end;
+get_db_version(State) ->
+ State.
+
+log(Level, Format, Args) ->
+ case Level of
+ debug -> ?DEBUG(Format, Args);
+ normal -> ?INFO_MSG(Format, Args);
+ error -> ?ERROR_MSG(Format, Args)
+ end.
+
+db_opts(Host) ->
+ Type = ejabberd_config:get_option({sql_type, Host},
+ fun(mysql) -> mysql;
+ (pgsql) -> pgsql;
+ (sqlite) -> sqlite;
+ (mssql) -> mssql;
+ (odbc) -> odbc
+ end, odbc),
+ Server = ejabberd_config:get_option({sql_server, Host},
+ fun iolist_to_binary/1,
+ <<"localhost">>),
+ case Type of
+ odbc ->
+ [odbc, Server];
+ sqlite ->
+ [sqlite, Host];
+ _ ->
+ Port = ejabberd_config:get_option(
+ {sql_port, Host},
+ fun(P) when is_integer(P), P > 0, P < 65536 -> P end,
+ case Type of
+ mssql -> ?MSSQL_PORT;
+ mysql -> ?MYSQL_PORT;
+ pgsql -> ?PGSQL_PORT
+ end),
+ DB = ejabberd_config:get_option({sql_database, Host},
+ fun iolist_to_binary/1,
+ <<"ejabberd">>),
+ User = ejabberd_config:get_option({sql_username, Host},
+ fun iolist_to_binary/1,
+ <<"ejabberd">>),
+ Pass = ejabberd_config:get_option({sql_password, Host},
+ fun iolist_to_binary/1,
+ <<"">>),
+ case Type of
+ mssql ->
+ [mssql, <<"DSN=", Host/binary, ";UID=", User/binary,
+ ";PWD=", Pass/binary>>];
+ _ ->
+ [Type, Server, Port, DB, User, Pass]
+ end
+ end.
+
+init_mssql(Host) ->
+ Server = ejabberd_config:get_option({sql_server, Host},
+ fun iolist_to_binary/1,
+ <<"localhost">>),
+ Port = ejabberd_config:get_option(
+ {sql_port, Host},
+ fun(P) when is_integer(P), P > 0, P < 65536 -> P end,
+ ?MSSQL_PORT),
+ DB = ejabberd_config:get_option({sql_database, Host},
+ fun iolist_to_binary/1,
+ <<"ejabberd">>),
+ FreeTDS = io_lib:fwrite("[~s]~n"
+ "\thost = ~s~n"
+ "\tport = ~p~n"
+ "\ttds version = 7.1~n",
+ [Host, Server, Port]),
+ ODBCINST = io_lib:fwrite("[freetds]~n"
+ "Description = MSSQL connection~n"
+ "Driver = libtdsodbc.so~n"
+ "Setup = libtdsS.so~n"
+ "UsageCount = 1~n"
+ "FileUsage = 1~n", []),
+ ODBCINI = io_lib:fwrite("[~s]~n"
+ "Description = MS SQL~n"
+ "Driver = freetds~n"
+ "Servername = ~s~n"
+ "Database = ~s~n"
+ "Port = ~p~n",
+ [Host, Host, DB, Port]),
+ ?DEBUG("~s:~n~s", [freetds_config(), FreeTDS]),
+ ?DEBUG("~s:~n~s", [odbcinst_config(), ODBCINST]),
+ ?DEBUG("~s:~n~s", [odbc_config(), ODBCINI]),
+ case filelib:ensure_dir(freetds_config()) of
+ ok ->
+ try
+ ok = file:write_file(freetds_config(), FreeTDS, [append]),
+ ok = file:write_file(odbcinst_config(), ODBCINST),
+ ok = file:write_file(odbc_config(), ODBCINI, [append]),
+ os:putenv("ODBCSYSINI", tmp_dir()),
+ os:putenv("FREETDS", freetds_config()),
+ os:putenv("FREETDSCONF", freetds_config()),
+ ok
+ catch error:{badmatch, {error, Reason} = Err} ->
+ ?ERROR_MSG("failed to create temporary files in ~s: ~s",
+ [tmp_dir(), file:format_error(Reason)]),
+ Err
+ end;
+ {error, Reason} = Err ->
+ ?ERROR_MSG("failed to create temporary directory ~s: ~s",
+ [tmp_dir(), file:format_error(Reason)]),
+ Err
+ end.
+
+tmp_dir() ->
+ filename:join(["/tmp", "ejabberd"]).
+
+odbc_config() ->
+ filename:join(tmp_dir(), "odbc.ini").
+
+freetds_config() ->
+ filename:join(tmp_dir(), "freetds.conf").
+
+odbcinst_config() ->
+ filename:join(tmp_dir(), "odbcinst.ini").
+
+max_fsm_queue() ->
+ ejabberd_config:get_option(
+ max_fsm_queue,
+ fun(N) when is_integer(N), N > 0 -> N end).
+
+fsm_limit_opts() ->
+ case max_fsm_queue() of
+ N when is_integer(N) -> [{max_queue, N}];
+ _ -> []
+ end.
+
+check_error({error, Why} = Err, #sql_query{} = Query) ->
+ ?ERROR_MSG("SQL query '~s' at ~p failed: ~p",
+ [Query#sql_query.hash, Query#sql_query.loc, Why]),
+ Err;
+check_error({error, Why} = Err, Query) ->
+ case catch iolist_to_binary(Query) of
+ SQuery when is_binary(SQuery) ->
+ ?ERROR_MSG("SQL query '~s' failed: ~p", [SQuery, Why]);
+ _ ->
+ ?ERROR_MSG("SQL query ~p failed: ~p", [Query, Why])
+ end,
+ Err;
+check_error(Result, _Query) ->
+ Result.
+
+opt_type(max_fsm_queue) ->
+ fun (N) when is_integer(N), N > 0 -> N end;
+opt_type(sql_database) -> fun iolist_to_binary/1;
+opt_type(sql_keepalive_interval) ->
+ fun (I) when is_integer(I), I > 0 -> I end;
+opt_type(sql_password) -> fun iolist_to_binary/1;
+opt_type(sql_port) ->
+ fun (P) when is_integer(P), P > 0, P < 65536 -> P end;
+opt_type(sql_server) -> fun iolist_to_binary/1;
+opt_type(sql_type) ->
+ fun (mysql) -> mysql;
+ (pgsql) -> pgsql;
+ (sqlite) -> sqlite;
+ (mssql) -> mssql;
+ (odbc) -> odbc
+ end;
+opt_type(sql_username) -> fun iolist_to_binary/1;
+opt_type(_) ->
+ [max_fsm_queue, sql_database, sql_keepalive_interval,
+ sql_password, sql_port, sql_server, sql_type,
+ sql_username].