aboutsummaryrefslogtreecommitdiff
path: root/src/ejabberd_sql_pt.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/ejabberd_sql_pt.erl')
-rw-r--r--src/ejabberd_sql_pt.erl200
1 files changed, 170 insertions, 30 deletions
diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl
index e90947a5f..197e5ee6d 100644
--- a/src/ejabberd_sql_pt.erl
+++ b/src/ejabberd_sql_pt.erl
@@ -26,7 +26,7 @@
-module(ejabberd_sql_pt).
%% API
--export([parse_transform/2]).
+-export([parse_transform/2, format_error/1]).
-export([parse/2]).
@@ -39,7 +39,8 @@
args = [],
res = [],
res_vars = [],
- res_pos = 0}).
+ res_pos = 0,
+ server_host_used = false}).
-define(QUERY_RECORD, "sql_query").
@@ -48,6 +49,12 @@
-define(MOD, sql__module_).
+-ifdef(NEW_SQL_SCHEMA).
+-define(USE_NEW_SCHEMA, true).
+-else.
+-define(USE_NEW_SCHEMA, false).
+-endif.
+
%%====================================================================
%% API
%%====================================================================
@@ -57,11 +64,14 @@
%%--------------------------------------------------------------------
parse_transform(AST, _Options) ->
%io:format("PT: ~p~nOpts: ~p~n", [AST, Options]),
+ put(warnings, []),
NewAST = top_transform(AST),
%io:format("NewPT: ~p~n", [NewAST]),
- NewAST.
+ NewAST ++ get(warnings).
+format_error(no_server_host) ->
+ "server_host field is not used".
%%====================================================================
%% Internal functions
@@ -80,6 +90,12 @@ transform(Form) ->
S = erl_syntax:string_value(Arg),
Pos = erl_syntax:get_pos(Arg),
ParseRes = parse(S, Pos),
+ if
+ ParseRes#state.server_host_used ->
+ ok;
+ true ->
+ add_warning(Pos, no_server_host)
+ end,
set_pos(make_sql_query(ParseRes), Pos);
_ ->
throw({error, erl_syntax:get_pos(Form),
@@ -101,8 +117,17 @@ transform(Form) ->
parse_upsert(
erl_syntax:list_elements(FieldsArg)),
Pos = erl_syntax:get_pos(Form),
+ case lists:keymember(
+ "server_host", 1, ParseRes) of
+ true ->
+ ok;
+ false ->
+ add_warning(Pos, no_server_host)
+ end,
+ ParseRes2 =
+ filter_upsert_sh(Table, ParseRes),
set_pos(
- make_sql_upsert(Table, ParseRes, Pos),
+ make_sql_upsert(Table, ParseRes2, Pos),
Pos);
_ ->
throw({error, erl_syntax:get_pos(Form),
@@ -113,6 +138,38 @@ transform(Form) ->
throw({error, erl_syntax:get_pos(Form),
"wrong number of ?SQL_UPSERT args"})
end;
+ {?SQL_INSERT_MARK, 2} ->
+ case erl_syntax:application_arguments(Form) of
+ [TableArg, FieldsArg] ->
+ case {erl_syntax:type(TableArg),
+ erl_syntax:is_proper_list(FieldsArg)}of
+ {string, true} ->
+ Table = erl_syntax:string_value(TableArg),
+ ParseRes =
+ parse_insert(
+ erl_syntax:list_elements(FieldsArg)),
+ Pos = erl_syntax:get_pos(Form),
+ case lists:keymember(
+ "server_host", 1, ParseRes) of
+ true ->
+ ok;
+ false ->
+ add_warning(Pos, no_server_host)
+ end,
+ ParseRes2 =
+ filter_upsert_sh(Table, ParseRes),
+ set_pos(
+ make_sql_insert(Table, ParseRes2),
+ Pos);
+ _ ->
+ throw({error, erl_syntax:get_pos(Form),
+ "?SQL_INSERT arguments must be "
+ "a constant string and a list"})
+ end;
+ _ ->
+ throw({error, erl_syntax:get_pos(Form),
+ "wrong number of ?SQL_INSERT args"})
+ end;
_ ->
Form
end;
@@ -168,7 +225,7 @@ parse1([], Acc, State) ->
};
parse1([$@, $( | S], Acc, State) ->
State1 = append_string(lists:reverse(Acc), State),
- {Name, Type, S1, State2} = parse_name(S, State1),
+ {Name, Type, S1, State2} = parse_name(S, false, State1),
Var = "__V" ++ integer_to_list(State2#state.res_pos),
EVar = erl_syntax:variable(Var),
Convert =
@@ -192,21 +249,43 @@ parse1([$@, $( | S], Acc, State) ->
parse1(S1, [], State4);
parse1([$%, $( | S], Acc, State) ->
State1 = append_string(lists:reverse(Acc), State),
- {Name, Type, S1, State2} = parse_name(S, State1),
+ {Name, Type, S1, State2} = parse_name(S, true, State1),
Var = State2#state.param_pos,
- Convert =
- erl_syntax:application(
- erl_syntax:record_access(
- erl_syntax:variable(?ESCAPE_VAR),
- erl_syntax:atom(?ESCAPE_RECORD),
- erl_syntax:atom(Type)),
- [erl_syntax:variable(Name)]),
- State3 = State2,
State4 =
- State3#state{'query' = [{var, Var} | State3#state.'query'],
- args = [Convert | State3#state.args],
- params = [Var | State3#state.params],
- param_pos = State3#state.param_pos + 1},
+ case Type of
+ host ->
+ State3 = State2#state{server_host_used = true},
+ case ?USE_NEW_SCHEMA of
+ true ->
+ Convert =
+ erl_syntax:application(
+ erl_syntax:record_access(
+ erl_syntax:variable(?ESCAPE_VAR),
+ erl_syntax:atom(?ESCAPE_RECORD),
+ erl_syntax:atom(string)),
+ [erl_syntax:variable(Name)]),
+ State3#state{'query' = [{var, Var},
+ {str, "server_host="} |
+ State3#state.'query'],
+ args = [Convert | State3#state.args],
+ params = [Var | State3#state.params],
+ param_pos = State3#state.param_pos + 1};
+ false ->
+ append_string("0=0", State3)
+ end;
+ _ ->
+ Convert =
+ erl_syntax:application(
+ erl_syntax:record_access(
+ erl_syntax:variable(?ESCAPE_VAR),
+ erl_syntax:atom(?ESCAPE_RECORD),
+ erl_syntax:atom(Type)),
+ [erl_syntax:variable(Name)]),
+ State2#state{'query' = [{var, Var} | State2#state.'query'],
+ args = [Convert | State2#state.args],
+ params = [Var | State2#state.params],
+ param_pos = State2#state.param_pos + 1}
+ end,
parse1(S1, [], State4);
parse1([C | S], Acc, State) ->
parse1(S, [C | Acc], State).
@@ -216,32 +295,33 @@ append_string([], State) ->
append_string(S, State) ->
State#state{query = [{str, S} | State#state.query]}.
-parse_name(S, State) ->
- parse_name(S, [], 0, State).
+parse_name(S, IsArg, State) ->
+ parse_name(S, [], 0, IsArg, State).
-parse_name([], _Acc, _Depth, State) ->
+parse_name([], _Acc, _Depth, _IsArg, State) ->
throw({error, State#state.loc,
"expected ')', found end of string"});
-parse_name([$), T | S], Acc, 0, State) ->
+parse_name([$), T | S], Acc, 0, IsArg, State) ->
Type =
case T of
$d -> integer;
$s -> string;
$b -> boolean;
+ $H when IsArg -> host;
_ ->
throw({error, State#state.loc,
["unknown type specifier '", T, "'"]})
end,
{lists:reverse(Acc), Type, S, State};
-parse_name([$)], _Acc, 0, State) ->
+parse_name([$)], _Acc, 0, _IsArg, State) ->
throw({error, State#state.loc,
"expected type specifier, found end of string"});
-parse_name([$( = C | S], Acc, Depth, State) ->
- parse_name(S, [C | Acc], Depth + 1, State);
-parse_name([$) = C | S], Acc, Depth, State) ->
- parse_name(S, [C | Acc], Depth - 1, State);
-parse_name([C | S], Acc, Depth, State) ->
- parse_name(S, [C | Acc], Depth, State).
+parse_name([$( = C | S], Acc, Depth, IsArg, State) ->
+ parse_name(S, [C | Acc], Depth + 1, IsArg, State);
+parse_name([$) = C | S], Acc, Depth, IsArg, State) ->
+ parse_name(S, [C | Acc], Depth - 1, IsArg, State);
+parse_name([C | S], Acc, Depth, IsArg, State) ->
+ parse_name(S, [C | Acc], Depth, IsArg, State).
make_var(V) ->
@@ -444,7 +524,7 @@ make_sql_upsert_insert(Table, ParseRes) ->
join_states(Fields, ", "),
#state{'query' = [{str, ") VALUES ("}]},
join_states(Vals, ", "),
- #state{'query' = [{str, ")"}]}
+ #state{'query' = [{str, ");"}]}
]),
State.
@@ -498,6 +578,49 @@ check_upsert(ParseRes, Pos) ->
ok.
+parse_insert(Fields) ->
+ {Fs, _} =
+ lists:foldr(
+ fun(F, {Acc, Param}) ->
+ case erl_syntax:type(F) of
+ string ->
+ V = erl_syntax:string_value(F),
+ {_, _, State} = Res =
+ parse_insert_field(
+ V, Param, erl_syntax:get_pos(F)),
+ {[Res | Acc], State#state.param_pos};
+ _ ->
+ throw({error, erl_syntax:get_pos(F),
+ "?SQL_INSERT field must be "
+ "a constant string"})
+ end
+ end, {[], 0}, Fields),
+ Fs.
+
+parse_insert_field([$! | _S], _ParamPos, Loc) ->
+ throw({error, Loc,
+ "?SQL_INSERT fields must not start with \"!\""});
+parse_insert_field([$- | _S], _ParamPos, Loc) ->
+ throw({error, Loc,
+ "?SQL_INSERT fields must not start with \"-\""});
+parse_insert_field(S, ParamPos, Loc) ->
+ {Name, ParseState} = parse_insert_field1(S, [], ParamPos, Loc),
+ {Name, {true}, ParseState}.
+
+parse_insert_field1([], _Acc, _ParamPos, Loc) ->
+ throw({error, Loc,
+ "?SQL_INSERT fields must have the "
+ "following form: \"name=value\""});
+parse_insert_field1([$= | S], Acc, ParamPos, Loc) ->
+ {lists:reverse(Acc), parse(S, ParamPos, Loc)};
+parse_insert_field1([C | S], Acc, ParamPos, Loc) ->
+ parse_insert_field1(S, [C | Acc], ParamPos, Loc).
+
+
+make_sql_insert(Table, ParseRes) ->
+ make_sql_query(make_sql_upsert_insert(Table, ParseRes)).
+
+
concat_states(States) ->
lists:foldr(
fun(ST11, ST2) ->
@@ -566,3 +689,20 @@ set_pos(Tree, Pos) ->
_ -> Node
end
end, Tree).
+
+filter_upsert_sh(Table, ParseRes) ->
+ case ?USE_NEW_SCHEMA of
+ true ->
+ ParseRes;
+ false ->
+ lists:filter(
+ fun({Field, _Match, _ST}) ->
+ Field /= "server_host" orelse Table == "route"
+ end, ParseRes)
+ end.
+
+add_warning(Pos, Warning) ->
+ Marker = erl_syntax:revert(
+ erl_syntax:warning_marker({Pos, ?MODULE, Warning})),
+ put(warnings, [Marker | get(warnings)]),
+ ok.