diff options
Diffstat (limited to 'src/ejabberd_sql_pt.erl')
-rw-r--r-- | src/ejabberd_sql_pt.erl | 200 |
1 files changed, 170 insertions, 30 deletions
diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl index e90947a5f..197e5ee6d 100644 --- a/src/ejabberd_sql_pt.erl +++ b/src/ejabberd_sql_pt.erl @@ -26,7 +26,7 @@ -module(ejabberd_sql_pt). %% API --export([parse_transform/2]). +-export([parse_transform/2, format_error/1]). -export([parse/2]). @@ -39,7 +39,8 @@ args = [], res = [], res_vars = [], - res_pos = 0}). + res_pos = 0, + server_host_used = false}). -define(QUERY_RECORD, "sql_query"). @@ -48,6 +49,12 @@ -define(MOD, sql__module_). +-ifdef(NEW_SQL_SCHEMA). +-define(USE_NEW_SCHEMA, true). +-else. +-define(USE_NEW_SCHEMA, false). +-endif. + %%==================================================================== %% API %%==================================================================== @@ -57,11 +64,14 @@ %%-------------------------------------------------------------------- parse_transform(AST, _Options) -> %io:format("PT: ~p~nOpts: ~p~n", [AST, Options]), + put(warnings, []), NewAST = top_transform(AST), %io:format("NewPT: ~p~n", [NewAST]), - NewAST. + NewAST ++ get(warnings). +format_error(no_server_host) -> + "server_host field is not used". %%==================================================================== %% Internal functions @@ -80,6 +90,12 @@ transform(Form) -> S = erl_syntax:string_value(Arg), Pos = erl_syntax:get_pos(Arg), ParseRes = parse(S, Pos), + if + ParseRes#state.server_host_used -> + ok; + true -> + add_warning(Pos, no_server_host) + end, set_pos(make_sql_query(ParseRes), Pos); _ -> throw({error, erl_syntax:get_pos(Form), @@ -101,8 +117,17 @@ transform(Form) -> parse_upsert( erl_syntax:list_elements(FieldsArg)), Pos = erl_syntax:get_pos(Form), + case lists:keymember( + "server_host", 1, ParseRes) of + true -> + ok; + false -> + add_warning(Pos, no_server_host) + end, + ParseRes2 = + filter_upsert_sh(Table, ParseRes), set_pos( - make_sql_upsert(Table, ParseRes, Pos), + make_sql_upsert(Table, ParseRes2, Pos), Pos); _ -> throw({error, erl_syntax:get_pos(Form), @@ -113,6 +138,38 @@ transform(Form) -> throw({error, erl_syntax:get_pos(Form), "wrong number of ?SQL_UPSERT args"}) end; + {?SQL_INSERT_MARK, 2} -> + case erl_syntax:application_arguments(Form) of + [TableArg, FieldsArg] -> + case {erl_syntax:type(TableArg), + erl_syntax:is_proper_list(FieldsArg)}of + {string, true} -> + Table = erl_syntax:string_value(TableArg), + ParseRes = + parse_insert( + erl_syntax:list_elements(FieldsArg)), + Pos = erl_syntax:get_pos(Form), + case lists:keymember( + "server_host", 1, ParseRes) of + true -> + ok; + false -> + add_warning(Pos, no_server_host) + end, + ParseRes2 = + filter_upsert_sh(Table, ParseRes), + set_pos( + make_sql_insert(Table, ParseRes2), + Pos); + _ -> + throw({error, erl_syntax:get_pos(Form), + "?SQL_INSERT arguments must be " + "a constant string and a list"}) + end; + _ -> + throw({error, erl_syntax:get_pos(Form), + "wrong number of ?SQL_INSERT args"}) + end; _ -> Form end; @@ -168,7 +225,7 @@ parse1([], Acc, State) -> }; parse1([$@, $( | S], Acc, State) -> State1 = append_string(lists:reverse(Acc), State), - {Name, Type, S1, State2} = parse_name(S, State1), + {Name, Type, S1, State2} = parse_name(S, false, State1), Var = "__V" ++ integer_to_list(State2#state.res_pos), EVar = erl_syntax:variable(Var), Convert = @@ -192,21 +249,43 @@ parse1([$@, $( | S], Acc, State) -> parse1(S1, [], State4); parse1([$%, $( | S], Acc, State) -> State1 = append_string(lists:reverse(Acc), State), - {Name, Type, S1, State2} = parse_name(S, State1), + {Name, Type, S1, State2} = parse_name(S, true, State1), Var = State2#state.param_pos, - Convert = - erl_syntax:application( - erl_syntax:record_access( - erl_syntax:variable(?ESCAPE_VAR), - erl_syntax:atom(?ESCAPE_RECORD), - erl_syntax:atom(Type)), - [erl_syntax:variable(Name)]), - State3 = State2, State4 = - State3#state{'query' = [{var, Var} | State3#state.'query'], - args = [Convert | State3#state.args], - params = [Var | State3#state.params], - param_pos = State3#state.param_pos + 1}, + case Type of + host -> + State3 = State2#state{server_host_used = true}, + case ?USE_NEW_SCHEMA of + true -> + Convert = + erl_syntax:application( + erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(string)), + [erl_syntax:variable(Name)]), + State3#state{'query' = [{var, Var}, + {str, "server_host="} | + State3#state.'query'], + args = [Convert | State3#state.args], + params = [Var | State3#state.params], + param_pos = State3#state.param_pos + 1}; + false -> + append_string("0=0", State3) + end; + _ -> + Convert = + erl_syntax:application( + erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(Type)), + [erl_syntax:variable(Name)]), + State2#state{'query' = [{var, Var} | State2#state.'query'], + args = [Convert | State2#state.args], + params = [Var | State2#state.params], + param_pos = State2#state.param_pos + 1} + end, parse1(S1, [], State4); parse1([C | S], Acc, State) -> parse1(S, [C | Acc], State). @@ -216,32 +295,33 @@ append_string([], State) -> append_string(S, State) -> State#state{query = [{str, S} | State#state.query]}. -parse_name(S, State) -> - parse_name(S, [], 0, State). +parse_name(S, IsArg, State) -> + parse_name(S, [], 0, IsArg, State). -parse_name([], _Acc, _Depth, State) -> +parse_name([], _Acc, _Depth, _IsArg, State) -> throw({error, State#state.loc, "expected ')', found end of string"}); -parse_name([$), T | S], Acc, 0, State) -> +parse_name([$), T | S], Acc, 0, IsArg, State) -> Type = case T of $d -> integer; $s -> string; $b -> boolean; + $H when IsArg -> host; _ -> throw({error, State#state.loc, ["unknown type specifier '", T, "'"]}) end, {lists:reverse(Acc), Type, S, State}; -parse_name([$)], _Acc, 0, State) -> +parse_name([$)], _Acc, 0, _IsArg, State) -> throw({error, State#state.loc, "expected type specifier, found end of string"}); -parse_name([$( = C | S], Acc, Depth, State) -> - parse_name(S, [C | Acc], Depth + 1, State); -parse_name([$) = C | S], Acc, Depth, State) -> - parse_name(S, [C | Acc], Depth - 1, State); -parse_name([C | S], Acc, Depth, State) -> - parse_name(S, [C | Acc], Depth, State). +parse_name([$( = C | S], Acc, Depth, IsArg, State) -> + parse_name(S, [C | Acc], Depth + 1, IsArg, State); +parse_name([$) = C | S], Acc, Depth, IsArg, State) -> + parse_name(S, [C | Acc], Depth - 1, IsArg, State); +parse_name([C | S], Acc, Depth, IsArg, State) -> + parse_name(S, [C | Acc], Depth, IsArg, State). make_var(V) -> @@ -444,7 +524,7 @@ make_sql_upsert_insert(Table, ParseRes) -> join_states(Fields, ", "), #state{'query' = [{str, ") VALUES ("}]}, join_states(Vals, ", "), - #state{'query' = [{str, ")"}]} + #state{'query' = [{str, ");"}]} ]), State. @@ -498,6 +578,49 @@ check_upsert(ParseRes, Pos) -> ok. +parse_insert(Fields) -> + {Fs, _} = + lists:foldr( + fun(F, {Acc, Param}) -> + case erl_syntax:type(F) of + string -> + V = erl_syntax:string_value(F), + {_, _, State} = Res = + parse_insert_field( + V, Param, erl_syntax:get_pos(F)), + {[Res | Acc], State#state.param_pos}; + _ -> + throw({error, erl_syntax:get_pos(F), + "?SQL_INSERT field must be " + "a constant string"}) + end + end, {[], 0}, Fields), + Fs. + +parse_insert_field([$! | _S], _ParamPos, Loc) -> + throw({error, Loc, + "?SQL_INSERT fields must not start with \"!\""}); +parse_insert_field([$- | _S], _ParamPos, Loc) -> + throw({error, Loc, + "?SQL_INSERT fields must not start with \"-\""}); +parse_insert_field(S, ParamPos, Loc) -> + {Name, ParseState} = parse_insert_field1(S, [], ParamPos, Loc), + {Name, {true}, ParseState}. + +parse_insert_field1([], _Acc, _ParamPos, Loc) -> + throw({error, Loc, + "?SQL_INSERT fields must have the " + "following form: \"name=value\""}); +parse_insert_field1([$= | S], Acc, ParamPos, Loc) -> + {lists:reverse(Acc), parse(S, ParamPos, Loc)}; +parse_insert_field1([C | S], Acc, ParamPos, Loc) -> + parse_insert_field1(S, [C | Acc], ParamPos, Loc). + + +make_sql_insert(Table, ParseRes) -> + make_sql_query(make_sql_upsert_insert(Table, ParseRes)). + + concat_states(States) -> lists:foldr( fun(ST11, ST2) -> @@ -566,3 +689,20 @@ set_pos(Tree, Pos) -> _ -> Node end end, Tree). + +filter_upsert_sh(Table, ParseRes) -> + case ?USE_NEW_SCHEMA of + true -> + ParseRes; + false -> + lists:filter( + fun({Field, _Match, _ST}) -> + Field /= "server_host" orelse Table == "route" + end, ParseRes) + end. + +add_warning(Pos, Warning) -> + Marker = erl_syntax:revert( + erl_syntax:warning_marker({Pos, ?MODULE, Warning})), + put(warnings, [Marker | get(warnings)]), + ok. |