diff options
Diffstat (limited to 'src/ejabberd_sql_pt.erl')
-rw-r--r-- | src/ejabberd_sql_pt.erl | 405 |
1 files changed, 344 insertions, 61 deletions
diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl index e51b7f928..0896b4b1a 100644 --- a/src/ejabberd_sql_pt.erl +++ b/src/ejabberd_sql_pt.erl @@ -2,17 +2,33 @@ %%% File : ejabberd_sql_pt.erl %%% Author : Alexey Shchepin <alexey@process-one.net> %%% Description : Parse transform for SQL queries -%%% %%% Created : 20 Jan 2016 by Alexey Shchepin <alexey@process-one.net> -%%%------------------------------------------------------------------- +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2019 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%---------------------------------------------------------------------- + -module(ejabberd_sql_pt). %% API --export([parse_transform/2]). - --export([parse/2]). +-export([parse_transform/2, format_error/1]). --include("ejabberd_sql_pt.hrl"). +-include("ejabberd_sql.hrl"). -record(state, {loc, 'query' = [], @@ -21,7 +37,11 @@ args = [], res = [], res_vars = [], - res_pos = 0}). + res_pos = 0, + server_host_used = false, + used_vars = [], + use_new_schema, + need_array_pass = false}). -define(QUERY_RECORD, "sql_query"). @@ -30,6 +50,12 @@ -define(MOD, sql__module_). +-ifdef(NEW_SQL_SCHEMA). +-define(USE_NEW_SCHEMA, true). +-else. +-define(USE_NEW_SCHEMA, false). +-endif. + %%==================================================================== %% API %%==================================================================== @@ -38,12 +64,13 @@ %% Description: %%-------------------------------------------------------------------- parse_transform(AST, _Options) -> - %io:format("PT: ~p~nOpts: ~p~n", [AST, Options]), + put(warnings, []), NewAST = top_transform(AST), - %io:format("NewPT: ~p~n", [NewAST]), - NewAST. + NewAST ++ get(warnings). +format_error(no_server_host) -> + "server_host field is not used". %%==================================================================== %% Internal functions @@ -59,10 +86,7 @@ transform(Form) -> [Arg] -> case erl_syntax:type(Arg) of string -> - S = erl_syntax:string_value(Arg), - Pos = erl_syntax:get_pos(Arg), - ParseRes = parse(S, Pos), - set_pos(make_sql_query(ParseRes), Pos); + transform_sql(Arg); _ -> throw({error, erl_syntax:get_pos(Form), "?SQL argument must be " @@ -78,14 +102,7 @@ transform(Form) -> case {erl_syntax:type(TableArg), erl_syntax:is_proper_list(FieldsArg)}of {string, true} -> - Table = erl_syntax:string_value(TableArg), - ParseRes = - parse_upsert( - erl_syntax:list_elements(FieldsArg)), - Pos = erl_syntax:get_pos(Form), - set_pos( - make_sql_upsert(Table, ParseRes, Pos), - Pos); + transform_upsert(Form, TableArg, FieldsArg); _ -> throw({error, erl_syntax:get_pos(Form), "?SQL_UPSERT arguments must be " @@ -95,6 +112,22 @@ transform(Form) -> throw({error, erl_syntax:get_pos(Form), "wrong number of ?SQL_UPSERT args"}) end; + {?SQL_INSERT_MARK, 2} -> + case erl_syntax:application_arguments(Form) of + [TableArg, FieldsArg] -> + case {erl_syntax:type(TableArg), + erl_syntax:is_proper_list(FieldsArg)}of + {string, true} -> + transform_insert(Form, TableArg, FieldsArg); + _ -> + throw({error, erl_syntax:get_pos(Form), + "?SQL_INSERT arguments must be " + "a constant string and a list"}) + end; + _ -> + throw({error, erl_syntax:get_pos(Form), + "wrong number of ?SQL_INSERT args"}) + end; _ -> Form end; @@ -104,7 +137,6 @@ transform(Form) -> case erl_syntax:attribute_arguments(Form) of [M | _] -> Module = erl_syntax:atom_value(M), - %io:format("module ~p~n", [Module]), put(?MOD, Module), Form; _ -> @@ -121,11 +153,7 @@ top_transform(Forms) when is_list(Forms) -> lists:map( fun(Form) -> try - Form2 = erl_syntax_lib:map( - fun(Node) -> - %io:format("asd ~p~n", [Node]), - transform(Node) - end, Form), + Form2 = erl_syntax_lib:map(fun transform/1, Form), Form3 = erl_syntax:revert(Form2), Form3 catch @@ -134,11 +162,93 @@ top_transform(Forms) when is_list(Forms) -> end end, Forms). -parse(S, Loc) -> - parse1(S, [], #state{loc = Loc}). +transform_sql(Arg) -> + S = erl_syntax:string_value(Arg), + Pos = erl_syntax:get_pos(Arg), + ParseRes = parse(S, Pos, true), + ParseResOld = parse(S, Pos, false), + case ParseRes#state.server_host_used of + {true, _SHVar} -> + ok; + false -> + add_warning( + Pos, no_server_host), + [] + end, + case ParseRes#state.need_array_pass of + true -> + {PR1, PR2} = perform_array_pass(ParseRes), + {PRO1, PRO2} = perform_array_pass(ParseResOld), + set_pos(make_schema_check( + erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PR2)]), + erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PR1)])]), + erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PRO2)]), + erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PRO1)])])), + Pos); + false -> + set_pos( + make_schema_check( + make_sql_query(ParseRes), + make_sql_query(ParseResOld) + ), + Pos) + end. -parse(S, ParamPos, Loc) -> - parse1(S, [], #state{loc = Loc, param_pos = ParamPos}). +transform_upsert(Form, TableArg, FieldsArg) -> + Table = erl_syntax:string_value(TableArg), + ParseRes = + parse_upsert( + erl_syntax:list_elements(FieldsArg)), + Pos = erl_syntax:get_pos(Form), + case lists:keymember( + "server_host", 1, ParseRes) of + true -> + ok; + false -> + add_warning(Pos, no_server_host) + end, + ParseResOld = + filter_upsert_sh(Table, ParseRes), + set_pos( + make_schema_check( + make_sql_upsert(Table, ParseRes, Pos), + make_sql_upsert(Table, ParseResOld, Pos) + ), + Pos). + +transform_insert(Form, TableArg, FieldsArg) -> + Table = erl_syntax:string_value(TableArg), + ParseRes = + parse_insert( + erl_syntax:list_elements(FieldsArg)), + Pos = erl_syntax:get_pos(Form), + case lists:keymember( + "server_host", 1, ParseRes) of + true -> + ok; + false -> + add_warning(Pos, no_server_host) + end, + ParseResOld = + filter_upsert_sh(Table, ParseRes), + set_pos( + make_schema_check( + make_sql_insert(Table, ParseRes), + make_sql_insert(Table, ParseResOld) + ), + Pos). + + +parse(S, Loc, UseNewSchema) -> + parse1(S, [], + #state{loc = Loc, + use_new_schema = UseNewSchema}). + +parse(S, ParamPos, Loc, UseNewSchema) -> + parse1(S, [], + #state{loc = Loc, + param_pos = ParamPos, + use_new_schema = UseNewSchema}). parse1([], Acc, State) -> State1 = append_string(lists:reverse(Acc), State), @@ -150,7 +260,7 @@ parse1([], Acc, State) -> }; parse1([$@, $( | S], Acc, State) -> State1 = append_string(lists:reverse(Acc), State), - {Name, Type, S1, State2} = parse_name(S, State1), + {Name, Type, S1, State2} = parse_name(S, false, State1), Var = "__V" ++ integer_to_list(State2#state.res_pos), EVar = erl_syntax:variable(Var), Convert = @@ -174,21 +284,75 @@ parse1([$@, $( | S], Acc, State) -> parse1(S1, [], State4); parse1([$%, $( | S], Acc, State) -> State1 = append_string(lists:reverse(Acc), State), - {Name, Type, S1, State2} = parse_name(S, State1), + {Name, Type, S1, State2} = parse_name(S, true, State1), Var = State2#state.param_pos, - Convert = - erl_syntax:application( - erl_syntax:record_access( - erl_syntax:variable(?ESCAPE_VAR), - erl_syntax:atom(?ESCAPE_RECORD), - erl_syntax:atom(Type)), - [erl_syntax:variable(Name)]), - State3 = State2, State4 = - State3#state{'query' = [{var, Var} | State3#state.'query'], - args = [Convert | State3#state.args], - params = [Var | State3#state.params], - param_pos = State3#state.param_pos + 1}, + case Type of + host -> + State3 = + State2#state{server_host_used = {true, Name}, + used_vars = [Name | State2#state.used_vars]}, + case State#state.use_new_schema of + true -> + Convert = + erl_syntax:application( + erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(string)), + [erl_syntax:variable(Name)]), + State3#state{'query' = [{var, Var}, + {str, "server_host="} | + State3#state.'query'], + args = [Convert | State3#state.args], + params = [Var | State3#state.params], + param_pos = State3#state.param_pos + 1}; + false -> + append_string("0=0", State3) + end; + {list, InternalType} -> + Convert = erl_syntax:application( + erl_syntax:atom(ejabberd_sql), + erl_syntax:atom(to_list), + [erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(InternalType)), + erl_syntax:variable(Name)]), + IT2 = case InternalType of + string -> + in_array_string; + _ -> + InternalType + end, + ConvertArr = erl_syntax:application( + erl_syntax:atom(ejabberd_sql), + erl_syntax:atom(to_array), + [erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(IT2)), + erl_syntax:variable(Name)]), + State2#state{'query' = [[{var, Var}] | State2#state.'query'], + need_array_pass = true, + args = [[Convert, ConvertArr] | State2#state.args], + params = [Var | State2#state.params], + param_pos = State2#state.param_pos + 1, + used_vars = [Name | State2#state.used_vars]}; + _ -> + Convert = + erl_syntax:application( + erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(Type)), + [erl_syntax:variable(Name)]), + State2#state{'query' = [{var, Var} | State2#state.'query'], + args = [Convert | State2#state.args], + params = [Var | State2#state.params], + param_pos = State2#state.param_pos + 1, + used_vars = [Name | State2#state.used_vars]} + end, parse1(S1, [], State4); parse1([C | S], Acc, State) -> parse1(S, [C | Acc], State). @@ -198,41 +362,80 @@ append_string([], State) -> append_string(S, State) -> State#state{query = [{str, S} | State#state.query]}. -parse_name(S, State) -> - parse_name(S, [], 0, State). +parse_name(S, IsArg, State) -> + parse_name(S, [], 0, IsArg, State). -parse_name([], _Acc, _Depth, State) -> +parse_name([], _Acc, _Depth, _IsArg, State) -> throw({error, State#state.loc, "expected ')', found end of string"}); -parse_name([$), T | S], Acc, 0, State) -> +parse_name([$), $l, T | S], Acc, 0, true, State) -> + Type = case T of + $d -> {list, integer}; + $s -> {list, string}; + $b -> {list, boolean}; + _ -> + throw({error, State#state.loc, + ["unknown type specifier 'l", T, "'"]}) + end, + {lists:reverse(Acc), Type, S, State}; +parse_name([$), $l, T | _], _Acc, 0, false, State) -> + throw({error, State#state.loc, + ["list type 'l", T, "' is not allowed for outputs"]}); +parse_name([$), T | S], Acc, 0, IsArg, State) -> Type = case T of $d -> integer; $s -> string; $b -> boolean; + $H when IsArg -> host; _ -> throw({error, State#state.loc, ["unknown type specifier '", T, "'"]}) end, {lists:reverse(Acc), Type, S, State}; -parse_name([$)], _Acc, 0, State) -> +parse_name([$)], _Acc, 0, _IsArg, State) -> throw({error, State#state.loc, "expected type specifier, found end of string"}); -parse_name([$( = C | S], Acc, Depth, State) -> - parse_name(S, [C | Acc], Depth + 1, State); -parse_name([$) = C | S], Acc, Depth, State) -> - parse_name(S, [C | Acc], Depth - 1, State); -parse_name([C | S], Acc, Depth, State) -> - parse_name(S, [C | Acc], Depth, State). +parse_name([$( = C | S], Acc, Depth, IsArg, State) -> + parse_name(S, [C | Acc], Depth + 1, IsArg, State); +parse_name([$) = C | S], Acc, Depth, IsArg, State) -> + parse_name(S, [C | Acc], Depth - 1, IsArg, State); +parse_name([C | S], Acc, Depth, IsArg, State) -> + parse_name(S, [C | Acc], Depth, IsArg, State). make_var(V) -> Var = "__V" ++ integer_to_list(V), erl_syntax:variable(Var). +perform_array_pass(State) -> + {NQ, PQ, Rest} = lists:foldl( + fun([{var, _} = Var], {N, P, {str, Str} = Prev}) -> + Str2 = re:replace(Str, "(^|\s+)in\s*$", " = any(", [{return, list}]), + {[Var, Prev | N], [{str, ")"}, Var, {str, Str2} | P], none}; + ([{var, _}], _) -> + throw({error, State#state.loc, ["List variable not following 'in' operator"]}); + (Other, {N, P, none}) -> + {N, P, Other}; + (Other, {N, P, Prev}) -> + {[Prev | N], [Prev | P], Other} + end, {[], [], none}, State#state.query), + {NQ2, PQ2} = case Rest of + none -> + {NQ, PQ}; + _ -> {[Rest | NQ], [Rest | PQ]} + end, + {NA, PA} = lists:foldl( + fun([V1, V2], {N, P}) -> + {[V1 | N], [V2 | P]}; + (Other, {N, P}) -> + {[Other | N], [Other | P]} + end, {[], []}, State#state.args), + {State#state{query = lists:reverse(NQ2), args = lists:reverse(NA), need_array_pass = false}, + State#state{query = lists:reverse(PQ2), args = lists:reverse(PA), need_array_pass = false}}. make_sql_query(State) -> - Hash = erlang:phash2(State#state{loc = undefined}), + Hash = erlang:phash2(State#state{loc = undefined, use_new_schema = true}), SHash = <<"Q", (integer_to_binary(Hash))/binary>>, Query = pack_query(State#state.'query'), EQuery = @@ -305,7 +508,6 @@ parse_upsert(Fields) -> "a constant string"}) end end, {[], 0}, Fields), - %io:format("upsert ~p~n", [{Fields, Fs}]), Fs. %% key | {Update} @@ -324,7 +526,7 @@ parse_upsert_field1([], _Acc, _ParamPos, Loc) -> "?SQL_UPSERT fields must have the " "following form: \"[!-]name=value\""}); parse_upsert_field1([$= | S], Acc, ParamPos, Loc) -> - {lists:reverse(Acc), parse(S, ParamPos, Loc)}; + {lists:reverse(Acc), parse(S, ParamPos, Loc, true)}; parse_upsert_field1([C | S], Acc, ParamPos, Loc) -> parse_upsert_field1(S, [C | Acc], ParamPos, Loc). @@ -426,7 +628,7 @@ make_sql_upsert_insert(Table, ParseRes) -> join_states(Fields, ", "), #state{'query' = [{str, ") VALUES ("}]}, join_states(Vals, ", "), - #state{'query' = [{str, ")"}]} + #state{'query' = [{str, ");"}]} ]), State. @@ -480,6 +682,66 @@ check_upsert(ParseRes, Pos) -> ok. +parse_insert(Fields) -> + {Fs, _} = + lists:foldr( + fun(F, {Acc, Param}) -> + case erl_syntax:type(F) of + string -> + V = erl_syntax:string_value(F), + {_, _, State} = Res = + parse_insert_field( + V, Param, erl_syntax:get_pos(F)), + {[Res | Acc], State#state.param_pos}; + _ -> + throw({error, erl_syntax:get_pos(F), + "?SQL_INSERT field must be " + "a constant string"}) + end + end, {[], 0}, Fields), + Fs. + +parse_insert_field([$! | _S], _ParamPos, Loc) -> + throw({error, Loc, + "?SQL_INSERT fields must not start with \"!\""}); +parse_insert_field([$- | _S], _ParamPos, Loc) -> + throw({error, Loc, + "?SQL_INSERT fields must not start with \"-\""}); +parse_insert_field(S, ParamPos, Loc) -> + {Name, ParseState} = parse_insert_field1(S, [], ParamPos, Loc), + {Name, {true}, ParseState}. + +parse_insert_field1([], _Acc, _ParamPos, Loc) -> + throw({error, Loc, + "?SQL_INSERT fields must have the " + "following form: \"name=value\""}); +parse_insert_field1([$= | S], Acc, ParamPos, Loc) -> + {lists:reverse(Acc), parse(S, ParamPos, Loc, true)}; +parse_insert_field1([C | S], Acc, ParamPos, Loc) -> + parse_insert_field1(S, [C | Acc], ParamPos, Loc). + + +make_sql_insert(Table, ParseRes) -> + make_sql_query(make_sql_upsert_insert(Table, ParseRes)). + +make_schema_check(Tree, Tree) -> + Tree; +make_schema_check(New, Old) -> + erl_syntax:case_expr( + erl_syntax:application( + erl_syntax:atom(ejabberd_sql), + erl_syntax:atom(use_new_schema), + []), + [erl_syntax:clause( + [erl_syntax:abstract(true)], + none, + [New]), + erl_syntax:clause( + [erl_syntax:abstract(false)], + none, + [Old])]). + + concat_states(States) -> lists:foldr( fun(ST11, ST2) -> @@ -548,3 +810,24 @@ set_pos(Tree, Pos) -> _ -> Node end end, Tree). + +filter_upsert_sh(Table, ParseRes) -> + lists:filter( + fun({Field, _Match, _ST}) -> + Field /= "server_host" orelse Table == "route" + end, ParseRes). + +-ifdef(ENABLE_PT_WARNINGS). + +add_warning(Pos, Warning) -> + Marker = erl_syntax:revert( + erl_syntax:warning_marker({Pos, ?MODULE, Warning})), + put(warnings, [Marker | get(warnings)]), + ok. + +-else. + +add_warning(_Pos, _Warning) -> + ok. + +-endif. |