aboutsummaryrefslogtreecommitdiff
path: root/src/ejabberd_sql_pt.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/ejabberd_sql_pt.erl')
-rw-r--r--src/ejabberd_sql_pt.erl405
1 files changed, 344 insertions, 61 deletions
diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl
index e51b7f928..0896b4b1a 100644
--- a/src/ejabberd_sql_pt.erl
+++ b/src/ejabberd_sql_pt.erl
@@ -2,17 +2,33 @@
%%% File : ejabberd_sql_pt.erl
%%% Author : Alexey Shchepin <alexey@process-one.net>
%%% Description : Parse transform for SQL queries
-%%%
%%% Created : 20 Jan 2016 by Alexey Shchepin <alexey@process-one.net>
-%%%-------------------------------------------------------------------
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2019 ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
+%%%----------------------------------------------------------------------
+
-module(ejabberd_sql_pt).
%% API
--export([parse_transform/2]).
-
--export([parse/2]).
+-export([parse_transform/2, format_error/1]).
--include("ejabberd_sql_pt.hrl").
+-include("ejabberd_sql.hrl").
-record(state, {loc,
'query' = [],
@@ -21,7 +37,11 @@
args = [],
res = [],
res_vars = [],
- res_pos = 0}).
+ res_pos = 0,
+ server_host_used = false,
+ used_vars = [],
+ use_new_schema,
+ need_array_pass = false}).
-define(QUERY_RECORD, "sql_query").
@@ -30,6 +50,12 @@
-define(MOD, sql__module_).
+-ifdef(NEW_SQL_SCHEMA).
+-define(USE_NEW_SCHEMA, true).
+-else.
+-define(USE_NEW_SCHEMA, false).
+-endif.
+
%%====================================================================
%% API
%%====================================================================
@@ -38,12 +64,13 @@
%% Description:
%%--------------------------------------------------------------------
parse_transform(AST, _Options) ->
- %io:format("PT: ~p~nOpts: ~p~n", [AST, Options]),
+ put(warnings, []),
NewAST = top_transform(AST),
- %io:format("NewPT: ~p~n", [NewAST]),
- NewAST.
+ NewAST ++ get(warnings).
+format_error(no_server_host) ->
+ "server_host field is not used".
%%====================================================================
%% Internal functions
@@ -59,10 +86,7 @@ transform(Form) ->
[Arg] ->
case erl_syntax:type(Arg) of
string ->
- S = erl_syntax:string_value(Arg),
- Pos = erl_syntax:get_pos(Arg),
- ParseRes = parse(S, Pos),
- set_pos(make_sql_query(ParseRes), Pos);
+ transform_sql(Arg);
_ ->
throw({error, erl_syntax:get_pos(Form),
"?SQL argument must be "
@@ -78,14 +102,7 @@ transform(Form) ->
case {erl_syntax:type(TableArg),
erl_syntax:is_proper_list(FieldsArg)}of
{string, true} ->
- Table = erl_syntax:string_value(TableArg),
- ParseRes =
- parse_upsert(
- erl_syntax:list_elements(FieldsArg)),
- Pos = erl_syntax:get_pos(Form),
- set_pos(
- make_sql_upsert(Table, ParseRes, Pos),
- Pos);
+ transform_upsert(Form, TableArg, FieldsArg);
_ ->
throw({error, erl_syntax:get_pos(Form),
"?SQL_UPSERT arguments must be "
@@ -95,6 +112,22 @@ transform(Form) ->
throw({error, erl_syntax:get_pos(Form),
"wrong number of ?SQL_UPSERT args"})
end;
+ {?SQL_INSERT_MARK, 2} ->
+ case erl_syntax:application_arguments(Form) of
+ [TableArg, FieldsArg] ->
+ case {erl_syntax:type(TableArg),
+ erl_syntax:is_proper_list(FieldsArg)}of
+ {string, true} ->
+ transform_insert(Form, TableArg, FieldsArg);
+ _ ->
+ throw({error, erl_syntax:get_pos(Form),
+ "?SQL_INSERT arguments must be "
+ "a constant string and a list"})
+ end;
+ _ ->
+ throw({error, erl_syntax:get_pos(Form),
+ "wrong number of ?SQL_INSERT args"})
+ end;
_ ->
Form
end;
@@ -104,7 +137,6 @@ transform(Form) ->
case erl_syntax:attribute_arguments(Form) of
[M | _] ->
Module = erl_syntax:atom_value(M),
- %io:format("module ~p~n", [Module]),
put(?MOD, Module),
Form;
_ ->
@@ -121,11 +153,7 @@ top_transform(Forms) when is_list(Forms) ->
lists:map(
fun(Form) ->
try
- Form2 = erl_syntax_lib:map(
- fun(Node) ->
- %io:format("asd ~p~n", [Node]),
- transform(Node)
- end, Form),
+ Form2 = erl_syntax_lib:map(fun transform/1, Form),
Form3 = erl_syntax:revert(Form2),
Form3
catch
@@ -134,11 +162,93 @@ top_transform(Forms) when is_list(Forms) ->
end
end, Forms).
-parse(S, Loc) ->
- parse1(S, [], #state{loc = Loc}).
+transform_sql(Arg) ->
+ S = erl_syntax:string_value(Arg),
+ Pos = erl_syntax:get_pos(Arg),
+ ParseRes = parse(S, Pos, true),
+ ParseResOld = parse(S, Pos, false),
+ case ParseRes#state.server_host_used of
+ {true, _SHVar} ->
+ ok;
+ false ->
+ add_warning(
+ Pos, no_server_host),
+ []
+ end,
+ case ParseRes#state.need_array_pass of
+ true ->
+ {PR1, PR2} = perform_array_pass(ParseRes),
+ {PRO1, PRO2} = perform_array_pass(ParseResOld),
+ set_pos(make_schema_check(
+ erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PR2)]),
+ erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PR1)])]),
+ erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PRO2)]),
+ erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PRO1)])])),
+ Pos);
+ false ->
+ set_pos(
+ make_schema_check(
+ make_sql_query(ParseRes),
+ make_sql_query(ParseResOld)
+ ),
+ Pos)
+ end.
-parse(S, ParamPos, Loc) ->
- parse1(S, [], #state{loc = Loc, param_pos = ParamPos}).
+transform_upsert(Form, TableArg, FieldsArg) ->
+ Table = erl_syntax:string_value(TableArg),
+ ParseRes =
+ parse_upsert(
+ erl_syntax:list_elements(FieldsArg)),
+ Pos = erl_syntax:get_pos(Form),
+ case lists:keymember(
+ "server_host", 1, ParseRes) of
+ true ->
+ ok;
+ false ->
+ add_warning(Pos, no_server_host)
+ end,
+ ParseResOld =
+ filter_upsert_sh(Table, ParseRes),
+ set_pos(
+ make_schema_check(
+ make_sql_upsert(Table, ParseRes, Pos),
+ make_sql_upsert(Table, ParseResOld, Pos)
+ ),
+ Pos).
+
+transform_insert(Form, TableArg, FieldsArg) ->
+ Table = erl_syntax:string_value(TableArg),
+ ParseRes =
+ parse_insert(
+ erl_syntax:list_elements(FieldsArg)),
+ Pos = erl_syntax:get_pos(Form),
+ case lists:keymember(
+ "server_host", 1, ParseRes) of
+ true ->
+ ok;
+ false ->
+ add_warning(Pos, no_server_host)
+ end,
+ ParseResOld =
+ filter_upsert_sh(Table, ParseRes),
+ set_pos(
+ make_schema_check(
+ make_sql_insert(Table, ParseRes),
+ make_sql_insert(Table, ParseResOld)
+ ),
+ Pos).
+
+
+parse(S, Loc, UseNewSchema) ->
+ parse1(S, [],
+ #state{loc = Loc,
+ use_new_schema = UseNewSchema}).
+
+parse(S, ParamPos, Loc, UseNewSchema) ->
+ parse1(S, [],
+ #state{loc = Loc,
+ param_pos = ParamPos,
+ use_new_schema = UseNewSchema}).
parse1([], Acc, State) ->
State1 = append_string(lists:reverse(Acc), State),
@@ -150,7 +260,7 @@ parse1([], Acc, State) ->
};
parse1([$@, $( | S], Acc, State) ->
State1 = append_string(lists:reverse(Acc), State),
- {Name, Type, S1, State2} = parse_name(S, State1),
+ {Name, Type, S1, State2} = parse_name(S, false, State1),
Var = "__V" ++ integer_to_list(State2#state.res_pos),
EVar = erl_syntax:variable(Var),
Convert =
@@ -174,21 +284,75 @@ parse1([$@, $( | S], Acc, State) ->
parse1(S1, [], State4);
parse1([$%, $( | S], Acc, State) ->
State1 = append_string(lists:reverse(Acc), State),
- {Name, Type, S1, State2} = parse_name(S, State1),
+ {Name, Type, S1, State2} = parse_name(S, true, State1),
Var = State2#state.param_pos,
- Convert =
- erl_syntax:application(
- erl_syntax:record_access(
- erl_syntax:variable(?ESCAPE_VAR),
- erl_syntax:atom(?ESCAPE_RECORD),
- erl_syntax:atom(Type)),
- [erl_syntax:variable(Name)]),
- State3 = State2,
State4 =
- State3#state{'query' = [{var, Var} | State3#state.'query'],
- args = [Convert | State3#state.args],
- params = [Var | State3#state.params],
- param_pos = State3#state.param_pos + 1},
+ case Type of
+ host ->
+ State3 =
+ State2#state{server_host_used = {true, Name},
+ used_vars = [Name | State2#state.used_vars]},
+ case State#state.use_new_schema of
+ true ->
+ Convert =
+ erl_syntax:application(
+ erl_syntax:record_access(
+ erl_syntax:variable(?ESCAPE_VAR),
+ erl_syntax:atom(?ESCAPE_RECORD),
+ erl_syntax:atom(string)),
+ [erl_syntax:variable(Name)]),
+ State3#state{'query' = [{var, Var},
+ {str, "server_host="} |
+ State3#state.'query'],
+ args = [Convert | State3#state.args],
+ params = [Var | State3#state.params],
+ param_pos = State3#state.param_pos + 1};
+ false ->
+ append_string("0=0", State3)
+ end;
+ {list, InternalType} ->
+ Convert = erl_syntax:application(
+ erl_syntax:atom(ejabberd_sql),
+ erl_syntax:atom(to_list),
+ [erl_syntax:record_access(
+ erl_syntax:variable(?ESCAPE_VAR),
+ erl_syntax:atom(?ESCAPE_RECORD),
+ erl_syntax:atom(InternalType)),
+ erl_syntax:variable(Name)]),
+ IT2 = case InternalType of
+ string ->
+ in_array_string;
+ _ ->
+ InternalType
+ end,
+ ConvertArr = erl_syntax:application(
+ erl_syntax:atom(ejabberd_sql),
+ erl_syntax:atom(to_array),
+ [erl_syntax:record_access(
+ erl_syntax:variable(?ESCAPE_VAR),
+ erl_syntax:atom(?ESCAPE_RECORD),
+ erl_syntax:atom(IT2)),
+ erl_syntax:variable(Name)]),
+ State2#state{'query' = [[{var, Var}] | State2#state.'query'],
+ need_array_pass = true,
+ args = [[Convert, ConvertArr] | State2#state.args],
+ params = [Var | State2#state.params],
+ param_pos = State2#state.param_pos + 1,
+ used_vars = [Name | State2#state.used_vars]};
+ _ ->
+ Convert =
+ erl_syntax:application(
+ erl_syntax:record_access(
+ erl_syntax:variable(?ESCAPE_VAR),
+ erl_syntax:atom(?ESCAPE_RECORD),
+ erl_syntax:atom(Type)),
+ [erl_syntax:variable(Name)]),
+ State2#state{'query' = [{var, Var} | State2#state.'query'],
+ args = [Convert | State2#state.args],
+ params = [Var | State2#state.params],
+ param_pos = State2#state.param_pos + 1,
+ used_vars = [Name | State2#state.used_vars]}
+ end,
parse1(S1, [], State4);
parse1([C | S], Acc, State) ->
parse1(S, [C | Acc], State).
@@ -198,41 +362,80 @@ append_string([], State) ->
append_string(S, State) ->
State#state{query = [{str, S} | State#state.query]}.
-parse_name(S, State) ->
- parse_name(S, [], 0, State).
+parse_name(S, IsArg, State) ->
+ parse_name(S, [], 0, IsArg, State).
-parse_name([], _Acc, _Depth, State) ->
+parse_name([], _Acc, _Depth, _IsArg, State) ->
throw({error, State#state.loc,
"expected ')', found end of string"});
-parse_name([$), T | S], Acc, 0, State) ->
+parse_name([$), $l, T | S], Acc, 0, true, State) ->
+ Type = case T of
+ $d -> {list, integer};
+ $s -> {list, string};
+ $b -> {list, boolean};
+ _ ->
+ throw({error, State#state.loc,
+ ["unknown type specifier 'l", T, "'"]})
+ end,
+ {lists:reverse(Acc), Type, S, State};
+parse_name([$), $l, T | _], _Acc, 0, false, State) ->
+ throw({error, State#state.loc,
+ ["list type 'l", T, "' is not allowed for outputs"]});
+parse_name([$), T | S], Acc, 0, IsArg, State) ->
Type =
case T of
$d -> integer;
$s -> string;
$b -> boolean;
+ $H when IsArg -> host;
_ ->
throw({error, State#state.loc,
["unknown type specifier '", T, "'"]})
end,
{lists:reverse(Acc), Type, S, State};
-parse_name([$)], _Acc, 0, State) ->
+parse_name([$)], _Acc, 0, _IsArg, State) ->
throw({error, State#state.loc,
"expected type specifier, found end of string"});
-parse_name([$( = C | S], Acc, Depth, State) ->
- parse_name(S, [C | Acc], Depth + 1, State);
-parse_name([$) = C | S], Acc, Depth, State) ->
- parse_name(S, [C | Acc], Depth - 1, State);
-parse_name([C | S], Acc, Depth, State) ->
- parse_name(S, [C | Acc], Depth, State).
+parse_name([$( = C | S], Acc, Depth, IsArg, State) ->
+ parse_name(S, [C | Acc], Depth + 1, IsArg, State);
+parse_name([$) = C | S], Acc, Depth, IsArg, State) ->
+ parse_name(S, [C | Acc], Depth - 1, IsArg, State);
+parse_name([C | S], Acc, Depth, IsArg, State) ->
+ parse_name(S, [C | Acc], Depth, IsArg, State).
make_var(V) ->
Var = "__V" ++ integer_to_list(V),
erl_syntax:variable(Var).
+perform_array_pass(State) ->
+ {NQ, PQ, Rest} = lists:foldl(
+ fun([{var, _} = Var], {N, P, {str, Str} = Prev}) ->
+ Str2 = re:replace(Str, "(^|\s+)in\s*$", " = any(", [{return, list}]),
+ {[Var, Prev | N], [{str, ")"}, Var, {str, Str2} | P], none};
+ ([{var, _}], _) ->
+ throw({error, State#state.loc, ["List variable not following 'in' operator"]});
+ (Other, {N, P, none}) ->
+ {N, P, Other};
+ (Other, {N, P, Prev}) ->
+ {[Prev | N], [Prev | P], Other}
+ end, {[], [], none}, State#state.query),
+ {NQ2, PQ2} = case Rest of
+ none ->
+ {NQ, PQ};
+ _ -> {[Rest | NQ], [Rest | PQ]}
+ end,
+ {NA, PA} = lists:foldl(
+ fun([V1, V2], {N, P}) ->
+ {[V1 | N], [V2 | P]};
+ (Other, {N, P}) ->
+ {[Other | N], [Other | P]}
+ end, {[], []}, State#state.args),
+ {State#state{query = lists:reverse(NQ2), args = lists:reverse(NA), need_array_pass = false},
+ State#state{query = lists:reverse(PQ2), args = lists:reverse(PA), need_array_pass = false}}.
make_sql_query(State) ->
- Hash = erlang:phash2(State#state{loc = undefined}),
+ Hash = erlang:phash2(State#state{loc = undefined, use_new_schema = true}),
SHash = <<"Q", (integer_to_binary(Hash))/binary>>,
Query = pack_query(State#state.'query'),
EQuery =
@@ -305,7 +508,6 @@ parse_upsert(Fields) ->
"a constant string"})
end
end, {[], 0}, Fields),
- %io:format("upsert ~p~n", [{Fields, Fs}]),
Fs.
%% key | {Update}
@@ -324,7 +526,7 @@ parse_upsert_field1([], _Acc, _ParamPos, Loc) ->
"?SQL_UPSERT fields must have the "
"following form: \"[!-]name=value\""});
parse_upsert_field1([$= | S], Acc, ParamPos, Loc) ->
- {lists:reverse(Acc), parse(S, ParamPos, Loc)};
+ {lists:reverse(Acc), parse(S, ParamPos, Loc, true)};
parse_upsert_field1([C | S], Acc, ParamPos, Loc) ->
parse_upsert_field1(S, [C | Acc], ParamPos, Loc).
@@ -426,7 +628,7 @@ make_sql_upsert_insert(Table, ParseRes) ->
join_states(Fields, ", "),
#state{'query' = [{str, ") VALUES ("}]},
join_states(Vals, ", "),
- #state{'query' = [{str, ")"}]}
+ #state{'query' = [{str, ");"}]}
]),
State.
@@ -480,6 +682,66 @@ check_upsert(ParseRes, Pos) ->
ok.
+parse_insert(Fields) ->
+ {Fs, _} =
+ lists:foldr(
+ fun(F, {Acc, Param}) ->
+ case erl_syntax:type(F) of
+ string ->
+ V = erl_syntax:string_value(F),
+ {_, _, State} = Res =
+ parse_insert_field(
+ V, Param, erl_syntax:get_pos(F)),
+ {[Res | Acc], State#state.param_pos};
+ _ ->
+ throw({error, erl_syntax:get_pos(F),
+ "?SQL_INSERT field must be "
+ "a constant string"})
+ end
+ end, {[], 0}, Fields),
+ Fs.
+
+parse_insert_field([$! | _S], _ParamPos, Loc) ->
+ throw({error, Loc,
+ "?SQL_INSERT fields must not start with \"!\""});
+parse_insert_field([$- | _S], _ParamPos, Loc) ->
+ throw({error, Loc,
+ "?SQL_INSERT fields must not start with \"-\""});
+parse_insert_field(S, ParamPos, Loc) ->
+ {Name, ParseState} = parse_insert_field1(S, [], ParamPos, Loc),
+ {Name, {true}, ParseState}.
+
+parse_insert_field1([], _Acc, _ParamPos, Loc) ->
+ throw({error, Loc,
+ "?SQL_INSERT fields must have the "
+ "following form: \"name=value\""});
+parse_insert_field1([$= | S], Acc, ParamPos, Loc) ->
+ {lists:reverse(Acc), parse(S, ParamPos, Loc, true)};
+parse_insert_field1([C | S], Acc, ParamPos, Loc) ->
+ parse_insert_field1(S, [C | Acc], ParamPos, Loc).
+
+
+make_sql_insert(Table, ParseRes) ->
+ make_sql_query(make_sql_upsert_insert(Table, ParseRes)).
+
+make_schema_check(Tree, Tree) ->
+ Tree;
+make_schema_check(New, Old) ->
+ erl_syntax:case_expr(
+ erl_syntax:application(
+ erl_syntax:atom(ejabberd_sql),
+ erl_syntax:atom(use_new_schema),
+ []),
+ [erl_syntax:clause(
+ [erl_syntax:abstract(true)],
+ none,
+ [New]),
+ erl_syntax:clause(
+ [erl_syntax:abstract(false)],
+ none,
+ [Old])]).
+
+
concat_states(States) ->
lists:foldr(
fun(ST11, ST2) ->
@@ -548,3 +810,24 @@ set_pos(Tree, Pos) ->
_ -> Node
end
end, Tree).
+
+filter_upsert_sh(Table, ParseRes) ->
+ lists:filter(
+ fun({Field, _Match, _ST}) ->
+ Field /= "server_host" orelse Table == "route"
+ end, ParseRes).
+
+-ifdef(ENABLE_PT_WARNINGS).
+
+add_warning(Pos, Warning) ->
+ Marker = erl_syntax:revert(
+ erl_syntax:warning_marker({Pos, ?MODULE, Warning})),
+ put(warnings, [Marker | get(warnings)]),
+ ok.
+
+-else.
+
+add_warning(_Pos, _Warning) ->
+ ok.
+
+-endif.