diff options
Diffstat (limited to 'src/ejabberd_odbc.erl')
-rw-r--r-- | src/ejabberd_odbc.erl | 417 |
1 files changed, 372 insertions, 45 deletions
diff --git a/src/ejabberd_odbc.erl b/src/ejabberd_odbc.erl index 09f17a635..e98665862 100644 --- a/src/ejabberd_odbc.erl +++ b/src/ejabberd_odbc.erl @@ -5,7 +5,7 @@ %%% Created : 8 Dec 2004 by Alexey Shchepin <alexey@process-one.net> %%% %%% -%%% ejabberd, Copyright (C) 2002-2015 ProcessOne +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne %%% %%% This program is free software; you can redistribute it and/or %%% modify it under the terms of the GNU General Public License as @@ -25,6 +25,8 @@ -module(ejabberd_odbc). +-behaviour(ejabberd_config). + -author('alexey@process-one.net'). -define(GEN_FSM, p1_fsm). @@ -39,11 +41,16 @@ sql_bloc/2, escape/1, escape_like/1, + escape_like_arg/1, to_bool/1, sqlite_db/1, sqlite_file/1, - encode_term/1, - decode_term/1, + encode_term/1, + decode_term/1, + odbc_config/0, + freetds_config/0, + odbcinst_config/0, + init_mssql/1, keep_alive/1]). %% gen_fsm callbacks @@ -51,20 +58,22 @@ handle_info/3, terminate/3, print_state/1, code_change/4]). -%% gen_fsm states -export([connecting/2, connecting/3, - session_established/2, session_established/3]). + session_established/2, session_established/3, + opt_type/1]). -include("ejabberd.hrl"). -include("logger.hrl"). +-include("ejabberd_sql_pt.hrl"). -record(state, {db_ref = self() :: pid(), - db_type = odbc :: pgsql | mysql | sqlite | odbc, - start_interval = 0 :: non_neg_integer(), - host = <<"">> :: binary(), + db_type = odbc :: pgsql | mysql | sqlite | odbc | mssql, + db_version = undefined :: undefined | non_neg_integer(), + start_interval = 0 :: non_neg_integer(), + host = <<"">> :: binary(), max_pending_requests_len :: non_neg_integer(), - pending_requests = {0, queue:new()} :: {non_neg_integer(), ?TQUEUE}}). + pending_requests = {0, queue:new()} :: {non_neg_integer(), ?TQUEUE}}). -define(STATE_KEY, ejabberd_odbc_state). @@ -76,13 +85,17 @@ -define(MYSQL_PORT, 3306). +-define(MSSQL_PORT, 1433). + -define(MAX_TRANSACTION_RESTARTS, 10). -define(TRANSACTION_TIMEOUT, 60000). -define(KEEPALIVE_TIMEOUT, 60000). --define(KEEPALIVE_QUERY, <<"SELECT 1;">>). +-define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]). + +-define(PREPARE_KEY, ejabberd_odbc_prepare). %%-define(DBGFSM, true). @@ -108,16 +121,18 @@ start_link(Host, StartInterval) -> [Host, StartInterval], fsm_limit_opts() ++ (?FSMOPTS)). --type sql_query() :: [sql_query() | binary()]. +-type sql_query() :: [sql_query() | binary()] | #sql_query{} | + fun(() -> any()) | fun((atom(), _) -> any()). -type sql_query_result() :: {updated, non_neg_integer()} | {error, binary()} | {selected, [binary()], - [[binary()]]}. + [[binary()]]} | + {selected, [any()]}. -spec sql_query(binary(), sql_query()) -> sql_query_result(). sql_query(Host, Query) -> - sql_call(Host, {sql_query, Query}). + check_error(sql_call(Host, {sql_query, Query}), Query). %% SQL transaction based on a list of queries %% This function automatically @@ -145,7 +160,8 @@ sql_call(Host, Msg) -> case ejabberd_odbc_sup:get_random_pid(Host) of none -> {error, <<"Unknown Host">>}; Pid -> - (?GEN_FSM):sync_send_event(Pid,{sql_cmd, Msg, now()}, + (?GEN_FSM):sync_send_event(Pid,{sql_cmd, Msg, + p1_time_compat:monotonic_time(milli_seconds)}, ?TRANSACTION_TIMEOUT) end; _State -> nested_op(Msg) @@ -153,7 +169,8 @@ sql_call(Host, Msg) -> keep_alive(PID) -> (?GEN_FSM):sync_send_event(PID, - {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}, now()}, + {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}, + p1_time_compat:monotonic_time(milli_seconds)}, ?KEEPALIVE_TIMEOUT). -spec sql_query_t(sql_query()) -> sql_query_result(). @@ -184,6 +201,13 @@ escape_like($%) -> <<"\\%">>; escape_like($_) -> <<"\\_">>; escape_like(C) when is_integer(C), C >= 0, C =< 255 -> odbc_queries:escape(C). +escape_like_arg(S) when is_binary(S) -> + << <<(escape_like_arg(C))/binary>> || <<C>> <= S >>; +escape_like_arg($%) -> <<"\\%">>; +escape_like_arg($_) -> <<"\\_">>; +escape_like_arg($\\) -> <<"\\\\">>; +escape_like_arg(C) when is_integer(C), C >= 0, C =< 255 -> <<C>>. + to_bool(<<"t">>) -> true; to_bool(<<"true">>) -> true; to_bool(<<"1">>) -> true; @@ -193,7 +217,8 @@ to_bool(_) -> false. encode_term(Term) -> escape(list_to_binary( - erl_prettypr:format(erl_syntax:abstract(Term)))). + erl_prettypr:format(erl_syntax:abstract(Term), + [{paper, 65535}, {ribbon, 65535}]))). decode_term(Bin) -> Str = binary_to_list(<<Bin/binary, ".">>), @@ -248,15 +273,16 @@ connecting(connect, #state{host = Host} = State) -> end, {_, PendingRequests} = State#state.pending_requests, case ConnectRes of - {ok, Ref} -> - erlang:monitor(process, Ref), - lists:foreach(fun (Req) -> - (?GEN_FSM):send_event(self(), Req) - end, - queue:to_list(PendingRequests)), - {next_state, session_established, - State#state{db_ref = Ref, - pending_requests = {0, queue:new()}}}; + {ok, Ref} -> + erlang:monitor(process, Ref), + lists:foreach(fun (Req) -> + (?GEN_FSM):send_event(self(), Req) + end, + queue:to_list(PendingRequests)), + State1 = State#state{db_ref = Ref, + pending_requests = {0, queue:new()}}, + State2 = get_db_version(State1), + {next_state, session_established, State2}; {error, Reason} -> ?INFO_MSG("~p connection failed:~n** Reason: ~p~n** " "Retry after: ~p seconds", @@ -364,7 +390,7 @@ print_state(State) -> State. %%%---------------------------------------------------------------------- run_sql_cmd(Command, From, State, Timestamp) -> - case timer:now_diff(now(), Timestamp) div 1000 of + case p1_time_compat:monotonic_time(milli_seconds) - Timestamp of Age when Age < (?TRANSACTION_TIMEOUT) -> put(?NESTING_KEY, ?TOP_LEVEL_TXN), put(?STATE_KEY, State), @@ -429,13 +455,13 @@ outer_transaction(F, NRestarts, _Reason) -> [T]), erlang:exit(implementation_faulty) end, - sql_query_internal(<<"begin;">>), + sql_query_internal([<<"begin;">>]), put(?NESTING_KEY, PreviousNestingLevel + 1), Result = (catch F()), put(?NESTING_KEY, PreviousNestingLevel), case Result of {aborted, Reason} when NRestarts > 0 -> - sql_query_internal(<<"rollback;">>), + sql_query_internal([<<"rollback;">>]), outer_transaction(F, NRestarts - 1, Reason); {aborted, Reason} when NRestarts =:= 0 -> ?ERROR_MSG("SQL transaction restarts exceeded~n** " @@ -444,11 +470,11 @@ outer_transaction(F, NRestarts, _Reason) -> "== ~p", [?MAX_TRANSACTION_RESTARTS, Reason, erlang:get_stacktrace(), get(?STATE_KEY)]), - sql_query_internal(<<"rollback;">>), + sql_query_internal([<<"rollback;">>]), {aborted, Reason}; {'EXIT', Reason} -> - sql_query_internal(<<"rollback;">>), {aborted, Reason}; - Res -> sql_query_internal(<<"commit;">>), {atomic, Res} + sql_query_internal([<<"rollback;">>]), {aborted, Reason}; + Res -> sql_query_internal([<<"commit;">>]), {atomic, Res} end. execute_bloc(F) -> @@ -458,8 +484,74 @@ execute_bloc(F) -> Res -> {atomic, Res} end. +execute_fun(F) when is_function(F, 0) -> + F(); +execute_fun(F) when is_function(F, 2) -> + State = get(?STATE_KEY), + F(State#state.db_type, State#state.db_version). + +sql_query_internal([{_, _} | _] = Queries) -> + State = get(?STATE_KEY), + case select_sql_query(Queries, State) of + undefined -> + {error, <<"no matching query for the current DBMS found">>}; + Query -> + sql_query_internal(Query) + end; +sql_query_internal(#sql_query{} = Query) -> + State = get(?STATE_KEY), + Res = + try + case State#state.db_type of + odbc -> + generic_sql_query(Query); + pgsql -> + Key = {?PREPARE_KEY, Query#sql_query.hash}, + case get(Key) of + undefined -> + case pgsql_prepare(Query, State) of + {ok, _, _, _} -> + put(Key, prepared); + {error, Error} -> + ?ERROR_MSG("PREPARE failed for SQL query " + "at ~p: ~p", + [Query#sql_query.loc, Error]), + put(Key, ignore) + end; + _ -> + ok + end, + case get(Key) of + prepared -> + pgsql_execute_sql_query(Query, State); + _ -> + generic_sql_query(Query) + end; + mysql -> + generic_sql_query(Query); + sqlite -> + generic_sql_query(Query) + end + catch + Class:Reason -> + ST = erlang:get_stacktrace(), + ?ERROR_MSG("Internal error while processing SQL query: ~p", + [{Class, Reason, ST}]), + {error, <<"internal error">>} + end, + case Res of + {error, <<"No SQL-driver information available.">>} -> + {updated, 0}; + _Else -> Res + end; +sql_query_internal(F) when is_function(F) -> + case catch execute_fun(F) of + {'EXIT', Reason} -> {error, Reason}; + Res -> Res + end; sql_query_internal(Query) -> State = get(?STATE_KEY), + ?DEBUG("SQL: \"~s\"", [Query]), Res = case State#state.db_type of odbc -> to_odbc(odbc:sql_query(State#state.db_ref, Query, @@ -467,10 +559,6 @@ sql_query_internal(Query) -> pgsql -> pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query)); mysql -> - ?DEBUG("MySQL, Send query~n~p~n", [Query]), - %%squery to be able to specify result_type = binary - %%[Query] because p1_mysql_conn expect query to be a list (elements can be binaries, or iolist) - %% but doesn't accept just a binary R = mysql_to_odbc(p1_mysql_conn:squery(State#state.db_ref, [Query], self(), [{timeout, (?TRANSACTION_TIMEOUT) - 1000}, @@ -487,6 +575,93 @@ sql_query_internal(Query) -> _Else -> Res end. +select_sql_query(Queries, State) -> + select_sql_query( + Queries, State#state.db_type, State#state.db_version, undefined). + +select_sql_query([], _Type, _Version, undefined) -> + undefined; +select_sql_query([], _Type, _Version, Query) -> + Query; +select_sql_query([{any, Query} | _], _Type, _Version, _) -> + Query; +select_sql_query([{Type, Query} | _], Type, _Version, _) -> + Query; +select_sql_query([{{Type, _Version1}, Query1} | Rest], Type, undefined, _) -> + select_sql_query(Rest, Type, undefined, Query1); +select_sql_query([{{Type, Version1}, Query1} | Rest], Type, Version, Query) -> + if + Version >= Version1 -> + Query1; + true -> + select_sql_query(Rest, Type, Version, Query) + end; +select_sql_query([{_, _} | Rest], Type, Version, Query) -> + select_sql_query(Rest, Type, Version, Query). + +generic_sql_query(SQLQuery) -> + sql_query_format_res( + sql_query_internal(generic_sql_query_format(SQLQuery)), + SQLQuery). + +generic_sql_query_format(SQLQuery) -> + Args = (SQLQuery#sql_query.args)(generic_escape()), + (SQLQuery#sql_query.format_query)(Args). + +generic_escape() -> + #sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end, + integer = fun(X) -> integer_to_binary(X) end, + boolean = fun(true) -> <<"1">>; + (false) -> <<"0">> + end + }. + +pgsql_prepare(SQLQuery, State) -> + Escape = #sql_escape{_ = fun(X) -> X end}, + N = length((SQLQuery#sql_query.args)(Escape)), + Args = [<<$$, (integer_to_binary(I))/binary>> || I <- lists:seq(1, N)], + Query = (SQLQuery#sql_query.format_query)(Args), + pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query). + +pgsql_execute_escape() -> + #sql_escape{string = fun(X) -> X end, + integer = fun(X) -> [integer_to_binary(X)] end, + boolean = fun(true) -> "1"; + (false) -> "0" + end + }. + +pgsql_execute_sql_query(SQLQuery, State) -> + Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()), + ExecuteRes = + pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Args), +% {T, ExecuteRes} = +% timer:tc(pgsql, execute, [State#state.db_ref, SQLQuery#sql_query.hash, Args]), +% io:format("T ~s ~p~n", [SQLQuery#sql_query.hash, T]), + Res = pgsql_execute_to_odbc(ExecuteRes), + sql_query_format_res(Res, SQLQuery). + + +sql_query_format_res({selected, _, Rows}, SQLQuery) -> + Res = + lists:flatmap( + fun(Row) -> + try + [(SQLQuery#sql_query.format_res)(Row)] + catch + Class:Reason -> + ST = erlang:get_stacktrace(), + ?ERROR_MSG("Error while processing " + "SQL query result: ~p~n" + "row: ~p", + [{Class, Reason, ST}, Row]), + [] + end + end, Rows), + {selected, Res}; +sql_query_format_res(Res, _SQLQuery) -> + Res. + %% Generate the OTP callback return tuple depending on the driver result. abort_on_driver_error({error, <<"query timed out">>} = Reply, @@ -509,7 +684,10 @@ abort_on_driver_error(Reply, From) -> %% Open an ODBC database connection odbc_connect(SQLServer) -> ejabberd:start_app(odbc), - odbc:connect(binary_to_list(SQLServer), [{scrollable_cursors, off}]). + odbc:connect(binary_to_list(SQLServer), + [{scrollable_cursors, off}, + {tuple_row, off}, + {binary_strings, on}]). %% == Native SQLite code @@ -595,20 +773,33 @@ pgsql_item_to_odbc(<<"UPDATE ", N/binary>>) -> pgsql_item_to_odbc({error, Error}) -> {error, Error}; pgsql_item_to_odbc(_) -> {updated, undefined}. +pgsql_execute_to_odbc({ok, {<<"SELECT", _/binary>>, Rows}}) -> + {selected, [], [[Field || {_, Field} <- Row] || Row <- Rows]}; +pgsql_execute_to_odbc({ok, {'INSERT', N}}) -> + {updated, N}; +pgsql_execute_to_odbc({ok, {'DELETE', N}}) -> + {updated, N}; +pgsql_execute_to_odbc({ok, {'UPDATE', N}}) -> + {updated, N}; +pgsql_execute_to_odbc({error, Error}) -> {error, Error}; +pgsql_execute_to_odbc(_) -> {updated, undefined}. + + %% == Native MySQL code %% part of init/1 %% Open a database connection to MySQL mysql_connect(Server, Port, DB, Username, Password) -> case p1_mysql_conn:start(binary_to_list(Server), Port, - binary_to_list(Username), binary_to_list(Password), - binary_to_list(DB), fun log/3) + binary_to_list(Username), + binary_to_list(Password), + binary_to_list(DB), fun log/3) of - {ok, Ref} -> - p1_mysql_conn:fetch(Ref, [<<"set names 'utf8';">>], - self()), - {ok, Ref}; - Err -> Err + {ok, Ref} -> + p1_mysql_conn:fetch( + Ref, [<<"set names 'utf8mb4' collate 'utf8mb4_bin';">>], self()), + {ok, Ref}; + Err -> Err end. %% Convert MySQL query result to Erlang ODBC result formalism @@ -634,10 +825,36 @@ mysql_item_to_odbc(Columns, Recs) -> {selected, [element(2, Column) || Column <- Columns], Recs}. to_odbc({selected, Columns, Recs}) -> - {selected, Columns, [tuple_to_list(Rec) || Rec <- Recs]}; + Rows = [lists:map( + fun(I) when is_integer(I) -> + jlib:integer_to_binary(I); + (B) -> + B + end, Row) || Row <- Recs], + {selected, [list_to_binary(C) || C <- Columns], Rows}; +to_odbc({error, Reason}) when is_list(Reason) -> + {error, list_to_binary(Reason)}; to_odbc(Res) -> Res. +get_db_version(#state{db_type = pgsql} = State) -> + case pgsql:squery(State#state.db_ref, + <<"select current_setting('server_version_num')">>) of + {ok, [{_, _, [[SVersion]]}]} -> + case catch binary_to_integer(SVersion) of + Version when is_integer(Version) -> + State#state{db_version = Version}; + Error -> + ?WARNING_MSG("error getting pgsql version: ~p", [Error]), + State + end; + Res -> + ?WARNING_MSG("error getting pgsql version: ~p", [Res]), + State + end; +get_db_version(State) -> + State. + log(Level, Format, Args) -> case Level of debug -> ?DEBUG(Format, Args); @@ -650,6 +867,7 @@ db_opts(Host) -> fun(mysql) -> mysql; (pgsql) -> pgsql; (sqlite) -> sqlite; + (mssql) -> mssql; (odbc) -> odbc end, odbc), Server = ejabberd_config:get_option({odbc_server, Host}, @@ -665,6 +883,7 @@ db_opts(Host) -> {odbc_port, Host}, fun(P) when is_integer(P), P > 0, P < 65536 -> P end, case Type of + mssql -> ?MSSQL_PORT; mysql -> ?MYSQL_PORT; pgsql -> ?PGSQL_PORT end), @@ -677,9 +896,80 @@ db_opts(Host) -> Pass = ejabberd_config:get_option({odbc_password, Host}, fun iolist_to_binary/1, <<"">>), - [Type, Server, Port, DB, User, Pass] + case Type of + mssql -> + [odbc, <<"DSN=", Host/binary, ";UID=", User/binary, + ";PWD=", Pass/binary>>]; + _ -> + [Type, Server, Port, DB, User, Pass] + end end. +init_mssql(Host) -> + Server = ejabberd_config:get_option({odbc_server, Host}, + fun iolist_to_binary/1, + <<"localhost">>), + Port = ejabberd_config:get_option( + {odbc_port, Host}, + fun(P) when is_integer(P), P > 0, P < 65536 -> P end, + ?MSSQL_PORT), + DB = ejabberd_config:get_option({odbc_database, Host}, + fun iolist_to_binary/1, + <<"ejabberd">>), + FreeTDS = io_lib:fwrite("[~s]~n" + "\thost = ~s~n" + "\tport = ~p~n" + "\ttds version = 7.1~n", + [Host, Server, Port]), + ODBCINST = io_lib:fwrite("[freetds]~n" + "Description = MSSQL connection~n" + "Driver = libtdsodbc.so~n" + "Setup = libtdsS.so~n" + "UsageCount = 1~n" + "FileUsage = 1~n", []), + ODBCINI = io_lib:fwrite("[~s]~n" + "Description = MS SQL~n" + "Driver = freetds~n" + "Servername = ~s~n" + "Database = ~s~n" + "Port = ~p~n", + [Host, Host, DB, Port]), + ?DEBUG("~s:~n~s", [freetds_config(), FreeTDS]), + ?DEBUG("~s:~n~s", [odbcinst_config(), ODBCINST]), + ?DEBUG("~s:~n~s", [odbc_config(), ODBCINI]), + case filelib:ensure_dir(freetds_config()) of + ok -> + try + ok = file:write_file(freetds_config(), FreeTDS, [append]), + ok = file:write_file(odbcinst_config(), ODBCINST), + ok = file:write_file(odbc_config(), ODBCINI, [append]), + os:putenv("ODBCSYSINI", tmp_dir()), + os:putenv("FREETDS", freetds_config()), + os:putenv("FREETDSCONF", freetds_config()), + ok + catch error:{badmatch, {error, Reason} = Err} -> + ?ERROR_MSG("failed to create temporary files in ~s: ~s", + [tmp_dir(), file:format_error(Reason)]), + Err + end; + {error, Reason} = Err -> + ?ERROR_MSG("failed to create temporary directory ~s: ~s", + [tmp_dir(), file:format_error(Reason)]), + Err + end. + +tmp_dir() -> + filename:join(["/tmp", "ejabberd"]). + +odbc_config() -> + filename:join(tmp_dir(), "odbc.ini"). + +freetds_config() -> + filename:join(tmp_dir(), "freetds.conf"). + +odbcinst_config() -> + filename:join(tmp_dir(), "odbcinst.ini"). + max_fsm_queue() -> ejabberd_config:get_option( max_fsm_queue, @@ -690,3 +980,40 @@ fsm_limit_opts() -> N when is_integer(N) -> [{max_queue, N}]; _ -> [] end. + +check_error({error, Why} = Err, #sql_query{} = Query) -> + ?ERROR_MSG("SQL query '~s' at ~p failed: ~p", + [Query#sql_query.hash, Query#sql_query.loc, Why]), + Err; +check_error({error, Why} = Err, Query) -> + case catch iolist_to_binary(Query) of + SQuery when is_binary(SQuery) -> + ?ERROR_MSG("SQL query '~s' failed: ~p", [SQuery, Why]); + _ -> + ?ERROR_MSG("SQL query ~p failed: ~p", [Query, Why]) + end, + Err; +check_error(Result, _Query) -> + Result. + +opt_type(max_fsm_queue) -> + fun (N) when is_integer(N), N > 0 -> N end; +opt_type(odbc_database) -> fun iolist_to_binary/1; +opt_type(odbc_keepalive_interval) -> + fun (I) when is_integer(I), I > 0 -> I end; +opt_type(odbc_password) -> fun iolist_to_binary/1; +opt_type(odbc_port) -> + fun (P) when is_integer(P), P > 0, P < 65536 -> P end; +opt_type(odbc_server) -> fun iolist_to_binary/1; +opt_type(odbc_type) -> + fun (mysql) -> mysql; + (pgsql) -> pgsql; + (sqlite) -> sqlite; + (mssql) -> mssql; + (odbc) -> odbc + end; +opt_type(odbc_username) -> fun iolist_to_binary/1; +opt_type(_) -> + [max_fsm_queue, odbc_database, odbc_keepalive_interval, + odbc_password, odbc_port, odbc_server, odbc_type, + odbc_username]. |