diff options
Diffstat (limited to 'src/ejabberd_sql_pt.erl')
-rw-r--r-- | src/ejabberd_sql_pt.erl | 71 |
1 files changed, 62 insertions, 9 deletions
diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl index 0ae04c64d..2497c2a74 100644 --- a/src/ejabberd_sql_pt.erl +++ b/src/ejabberd_sql_pt.erl @@ -42,7 +42,8 @@ res_pos = 0, server_host_used = false, used_vars = [], - use_new_schema}). + use_new_schema, + need_array_pass = false}). -define(QUERY_RECORD, "sql_query"). @@ -183,12 +184,24 @@ transform_sql(Arg) -> Pos, no_server_host), [] end, - set_pos( - make_schema_check( - make_sql_query(ParseRes), - make_sql_query(ParseResOld) - ), - Pos). + case ParseRes#state.need_array_pass of + true -> + {PR1, PR2} = perform_array_pass(ParseRes), + {PRO1, PRO2} = perform_array_pass(ParseResOld), + set_pos(make_schema_check( + erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PR2)]), + erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PR1)])]), + erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PRO2)]), + erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PRO1)])])), + Pos); + false -> + set_pos( + make_schema_check( + make_sql_query(ParseRes), + make_sql_query(ParseResOld) + ), + Pos) + end. transform_upsert(Form, TableArg, FieldsArg) -> Table = erl_syntax:string_value(TableArg), @@ -315,8 +328,23 @@ parse1([$%, $( | S], Acc, State) -> erl_syntax:atom(?ESCAPE_RECORD), erl_syntax:atom(InternalType)), erl_syntax:variable(Name)]), - State2#state{'query' = [{var, Var} | State2#state.'query'], - args = [Convert | State2#state.args], + IT2 = case InternalType of + string -> + in_array_string; + _ -> + InternalType + end, + ConvertArr = erl_syntax:application( + erl_syntax:atom(ejabberd_sql), + erl_syntax:atom(to_array), + [erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(IT2)), + erl_syntax:variable(Name)]), + State2#state{'query' = [[{var, Var}] | State2#state.'query'], + need_array_pass = true, + args = [[Convert, ConvertArr] | State2#state.args], params = [Var | State2#state.params], param_pos = State2#state.param_pos + 1, used_vars = [Name | State2#state.used_vars]}; @@ -389,6 +417,31 @@ make_var(V) -> Var = "__V" ++ integer_to_list(V), erl_syntax:variable(Var). +perform_array_pass(State) -> + {NQ, PQ, Rest} = lists:foldl( + fun([{var, _} = Var], {N, P, {str, Str} = Prev}) -> + Str2 = re:replace(Str, "(^|\s+)in\s*$", " = any(", [{return, list}]), + {[Var, Prev | N], [{str, ")"}, Var, {str, Str2} | P], none}; + ([{var, _}], _) -> + throw({error, State#state.loc, ["List variable not following 'in' operator"]}); + (Other, {N, P, none}) -> + {N, P, Other}; + (Other, {N, P, Prev}) -> + {[Prev | N], [Prev | P], Other} + end, {[], [], none}, State#state.query), + {NQ2, PQ2} = case Rest of + none -> + {NQ, PQ}; + _ -> {[Rest | NQ], [Rest | PQ]} + end, + {NA, PA} = lists:foldl( + fun([V1, V2], {N, P}) -> + {[V1 | N], [V2 | P]}; + (Other, {N, P}) -> + {[Other | N], [Other | P]} + end, {[], []}, State#state.args), + {State#state{query = lists:reverse(NQ2), args = lists:reverse(NA), need_array_pass = false}, + State#state{query = lists:reverse(PQ2), args = lists:reverse(PA), need_array_pass = false}}. make_sql_query(State) -> Hash = erlang:phash2(State#state{loc = undefined, use_new_schema = true}), |