diff options
author | Alexey Shchepin <alexey@process-one.net> | 2016-02-09 19:23:15 +0300 |
---|---|---|
committer | Alexey Shchepin <alexey@process-one.net> | 2016-03-01 22:48:30 +0300 |
commit | 6374ef48669283933931946f9fbe9a6fccd811ed (patch) | |
tree | ff479e6bd02be3e5abbd8c535d63cdc7296c5faf /src/ejabberd_sql_pt.erl | |
parent | Update ejabberd version for hex.pm release (diff) |
New parse transform for SQL queries, use prepare/execute calls with Postgres
Diffstat (limited to 'src/ejabberd_sql_pt.erl')
-rw-r--r-- | src/ejabberd_sql_pt.erl | 255 |
1 files changed, 255 insertions, 0 deletions
diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl new file mode 100644 index 000000000..f9701a0be --- /dev/null +++ b/src/ejabberd_sql_pt.erl @@ -0,0 +1,255 @@ +%%%------------------------------------------------------------------- +%%% File : ejabberd_sql_pt.erl +%%% Author : Alexey Shchepin <alexey@process-one.net> +%%% Description : Parse transform for SQL queries +%%% +%%% Created : 20 Jan 2016 by Alexey Shchepin <alexey@process-one.net> +%%%------------------------------------------------------------------- +-module(ejabberd_sql_pt). + +%% API +-export([parse_transform/2]). + +-export([parse/2]). + +-include("ejabberd_sql_pt.hrl"). + +-record(state, {loc, + 'query' = [], + params = [], + param_pos = 0, + args = [], + res = [], + res_vars = [], + res_pos = 0}). + +-define(QUERY_RECORD, "sql_query"). + +-define(ESCAPE_RECORD, "sql_escape"). +-define(ESCAPE_VAR, "__SQLEscape"). + +-define(MOD, sql__module_). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: +%% Description: +%%-------------------------------------------------------------------- +parse_transform(AST, _Options) -> + %io:format("PT: ~p~nOpts: ~p~n", [AST, Options]), + NewAST = top_transform(AST), + %io:format("NewPT: ~p~n", [NewAST]), + NewAST. + + + +%%==================================================================== +%% Internal functions +%%==================================================================== + + +transform(Form) -> + case erl_syntax:type(Form) of + application -> + case erl_syntax_lib:analyze_application(Form) of + {?SQL_MARK, 1} -> + case erl_syntax:application_arguments(Form) of + [Arg] -> + case erl_syntax:type(Arg) of + string -> + S = erl_syntax:string_value(Arg), + ParseRes = + parse(S, erl_syntax:get_pos(Arg)), + make_sql_query(ParseRes); + _ -> + throw({error, erl_syntax:get_pos(Form), + "?SQL argument must be " + "a constant string"}) + end; + _ -> + throw({error, erl_syntax:get_pos(Form), + "wrong number of ?SQL args"}) + end; + _ -> + Form + end; + attribute -> + case erl_syntax:atom_value(erl_syntax:attribute_name(Form)) of + module -> + case erl_syntax:attribute_arguments(Form) of + [M | _] -> + Module = erl_syntax:atom_value(M), + %io:format("module ~p~n", [Module]), + put(?MOD, Module), + Form; + _ -> + Form + end; + _ -> + Form + end; + _ -> + Form + end. + +top_transform(Forms) when is_list(Forms) -> + lists:map( + fun(Form) -> + try + Form2 = erl_syntax_lib:map( + fun(Node) -> + %io:format("asd ~p~n", [Node]), + transform(Node) + end, Form), + Form3 = erl_syntax:revert(Form2), + Form3 + catch + throw:{error, Line, Error} -> + {error, {Line, erl_parse, Error}} + end + end, Forms). + +parse(S, Loc) -> + parse1(S, [], #state{loc = Loc}). + +parse1([], Acc, State) -> + State1 = append_string(lists:reverse(Acc), State), + State1#state{'query' = lists:reverse(State1#state.'query'), + params = lists:reverse(State1#state.params), + args = lists:reverse(State1#state.args), + res = lists:reverse(State1#state.res), + res_vars = lists:reverse(State1#state.res_vars) + }; +parse1([$@, $( | S], Acc, State) -> + State1 = append_string(lists:reverse(Acc), State), + {Name, Type, S1, State2} = parse_name(S, State1), + Var = "__V" ++ integer_to_list(State2#state.res_pos), + EVar = erl_syntax:variable(Var), + Convert = + case Type of + integer -> + erl_syntax:application( + erl_syntax:atom(binary_to_integer), + [EVar]); + string -> + EVar; + boolean -> + erl_syntax:application( + erl_syntax:atom(ejabberd_odbc), + erl_syntax:atom(to_bool), + [EVar]) + end, + State3 = append_string(Name, State2), + State4 = State3#state{res_pos = State3#state.res_pos + 1, + res = [Convert | State3#state.res], + res_vars = [EVar | State3#state.res_vars]}, + parse1(S1, [], State4); +parse1([$%, $( | S], Acc, State) -> + State1 = append_string(lists:reverse(Acc), State), + {Name, Type, S1, State2} = parse_name(S, State1), + Var = "__V" ++ integer_to_list(State2#state.param_pos), + EVar = erl_syntax:variable(Var), + Convert = + erl_syntax:application( + erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(Type)), + [erl_syntax:variable(Name)]), + State3 = State2, + State4 = + State3#state{'query' = [{var, EVar} | State3#state.'query'], + args = [Convert | State3#state.args], + params = [EVar | State3#state.params], + param_pos = State3#state.param_pos + 1}, + parse1(S1, [], State4); +parse1([C | S], Acc, State) -> + parse1(S, [C | Acc], State). + +append_string([], State) -> + State; +append_string(S, State) -> + State#state{query = [{str, S} | State#state.query]}. + +parse_name(S, State) -> + parse_name(S, [], State). + +parse_name([], Acc, State) -> + % todo + error; +parse_name([$), T | S], Acc, State) -> + Type = + case T of + $d -> integer; + $s -> string; + $b -> boolean; + _ -> + % todo + error + end, + {lists:reverse(Acc), Type, S, State}; +parse_name([$) | _], Acc, State) -> + % todo + error; +parse_name([C | S], Acc, State) -> + parse_name(S, [C | Acc], State). + + +make_sql_query(State) -> + Hash = erlang:phash2(State#state{loc = undefined}), + SHash = <<"Q", (integer_to_binary(Hash))/binary>>, + Query = pack_query(State#state.'query'), + EQuery = + lists:map( + fun({str, S}) -> + erl_syntax:binary( + [erl_syntax:binary_field( + erl_syntax:string(S))]); + ({var, V}) -> V + end, Query), + erl_syntax:record_expr( + erl_syntax:atom(?QUERY_RECORD), + [erl_syntax:record_field( + erl_syntax:atom(hash), + %erl_syntax:abstract(SHash) + erl_syntax:binary( + [erl_syntax:binary_field( + erl_syntax:string(binary_to_list(SHash)))])), + erl_syntax:record_field( + erl_syntax:atom(args), + erl_syntax:fun_expr( + [erl_syntax:clause( + [erl_syntax:variable(?ESCAPE_VAR)], + none, + [erl_syntax:list(State#state.args)] + )])), + erl_syntax:record_field( + erl_syntax:atom(format_query), + erl_syntax:fun_expr( + [erl_syntax:clause( + [erl_syntax:list(State#state.params)], + none, + [erl_syntax:list(EQuery)] + )])), + erl_syntax:record_field( + erl_syntax:atom(format_res), + erl_syntax:fun_expr( + [erl_syntax:clause( + [erl_syntax:list(State#state.res_vars)], + none, + [erl_syntax:tuple(State#state.res)] + )])), + erl_syntax:record_field( + erl_syntax:atom(loc), + erl_syntax:abstract({get(?MOD), State#state.loc})) + ]). + +pack_query([]) -> + []; +pack_query([{str, S1}, {str, S2} | Rest]) -> + pack_query([{str, S1 ++ S2} | Rest]); +pack_query([X | Rest]) -> + [X | pack_query(Rest)]. + |