summaryrefslogtreecommitdiff
path: root/src/ejabberd_oauth.erl
diff options
context:
space:
mode:
authorAlexey Shchepin <alexey@process-one.net>2016-07-20 16:55:45 +0300
committerAlexey Shchepin <alexey@process-one.net>2016-07-25 20:08:30 +0300
commit839490b0d9a8cbdcf13f5eb08412fae3c11ffcf6 (patch)
tree826f37a5538f4b71f83d5d6c2c7ff16abe952b30 /src/ejabberd_oauth.erl
parentExport acl:parse_ip_netmask/1 for mod_rest (ejabberd-contrib#175) (diff)
Add DB backend support for ejabberd_oauth
Diffstat (limited to 'src/ejabberd_oauth.erl')
-rw-r--r--src/ejabberd_oauth.erl65
1 files changed, 29 insertions, 36 deletions
diff --git a/src/ejabberd_oauth.erl b/src/ejabberd_oauth.erl
index 81b5f415..d4b1ff87 100644
--- a/src/ejabberd_oauth.erl
+++ b/src/ejabberd_oauth.erl
@@ -56,6 +56,7 @@
-include("ejabberd_http.hrl").
-include("ejabberd_web_admin.hrl").
+-include("ejabberd_oauth.hrl").
-include("ejabberd_commands.hrl").
@@ -64,17 +65,12 @@
%% * Using the web form/api results in the token being generated in behalf of the user providing the user/pass
%% * Using the command line and oauth_issue_token command, the token is generated in behalf of ejabberd' sysadmin
%% (as it has access to ejabberd command line).
--record(oauth_token, {
- token = {<<"">>, <<"">>} :: {binary(), binary()},
- us = {<<"">>, <<"">>} :: {binary(), binary()},
- scope = [] :: [binary()],
- expire :: integer()
- }).
-define(EXPIRE, 3600).
start() ->
- init_db(mnesia, ?MYNAME),
+ DBMod = get_db_backend(),
+ DBMod:init(),
Expire = expire(),
application:set_env(oauth2, backend, ejabberd_oauth),
application:set_env(oauth2, expiry_time, Expire),
@@ -172,15 +168,8 @@ handle_cast(_Msg, State) -> {noreply, State}.
handle_info(clean, State) ->
{MegaSecs, Secs, MiniSecs} = os:timestamp(),
TS = 1000000 * MegaSecs + Secs,
- F = fun() ->
- Ts = mnesia:select(
- oauth_token,
- [{#oauth_token{expire = '$1', _ = '_'},
- [{'<', '$1', TS}],
- ['$_']}]),
- lists:foreach(fun mnesia:delete_object/1, Ts)
- end,
- mnesia:async_dirty(F),
+ DBMod = get_db_backend(),
+ DBMod:clean(TS),
erlang:send_after(trunc(expire() * 1000 * (1 + MiniSecs / 1000000)),
self(), clean),
{noreply, State};
@@ -191,16 +180,6 @@ terminate(_Reason, _State) -> ok.
code_change(_OldVsn, State, _Extra) -> {ok, State}.
-init_db(mnesia, _Host) ->
- mnesia:create_table(oauth_token,
- [{disc_copies, [node()]},
- {attributes,
- record_info(fields, oauth_token)}]),
- mnesia:add_table_copy(oauth_token, node(), disc_copies);
-init_db(_, _) ->
- ok.
-
-
get_client_identity(Client, Ctx) -> {ok, {Ctx, {client, Client}}}.
verify_redirection_uri(_, _, Ctx) -> {ok, Ctx}.
@@ -305,7 +284,8 @@ associate_access_token(AccessToken, Context, AppContext) ->
scope = Scope,
expire = Expire
},
- mnesia:dirty_write(R),
+ DBMod = get_db_backend(),
+ DBMod:store(R),
{ok, AppContext}.
associate_refresh_token(_RefreshToken, _Context, AppContext) ->
@@ -315,10 +295,11 @@ associate_refresh_token(_RefreshToken, _Context, AppContext) ->
check_token(User, Server, ScopeList, Token) ->
LUser = jid:nodeprep(User),
LServer = jid:nameprep(Server),
- case catch mnesia:dirty_read(oauth_token, Token) of
- [#oauth_token{us = {LUser, LServer},
- scope = TokenScope,
- expire = Expire}] ->
+ DBMod = get_db_backend(),
+ case DBMod:lookup(Token) of
+ #oauth_token{us = {LUser, LServer},
+ scope = TokenScope,
+ expire = Expire} ->
{MegaSecs, Secs, _} = os:timestamp(),
TS = 1000000 * MegaSecs + Secs,
TokenScopeSet = oauth2_priv_set:new(TokenScope),
@@ -330,10 +311,11 @@ check_token(User, Server, ScopeList, Token) ->
end.
check_token(ScopeList, Token) ->
- case catch mnesia:dirty_read(oauth_token, Token) of
- [#oauth_token{us = US,
- scope = TokenScope,
- expire = Expire}] ->
+ DBMod = get_db_backend(),
+ case DBMod:lookup(Token) of
+ #oauth_token{us = US,
+ scope = TokenScope,
+ expire = Expire} ->
{MegaSecs, Secs, _} = os:timestamp(),
TS = 1000000 * MegaSecs + Secs,
TokenScopeSet = oauth2_priv_set:new(TokenScope),
@@ -548,6 +530,15 @@ process(_Handlers,
process(_Handlers, _Request) ->
ejabberd_web:error(not_found).
+-spec get_db_backend() -> module().
+
+get_db_backend() ->
+ DBType = ejabberd_config:get_option(
+ oauth_db_type,
+ fun(T) -> ejabberd_config:v_db(?MODULE, T) end,
+ mnesia),
+ list_to_atom("ejabberd_oauth_" ++ atom_to_list(DBType)).
+
%% Headers as per RFC 6749
json_response(Code, Body) ->
@@ -688,4 +679,6 @@ opt_type(oauth_expire) ->
fun(I) when is_integer(I), I >= 0 -> I end;
opt_type(oauth_access) ->
fun acl:access_rules_validator/1;
-opt_type(_) -> [oauth_expire, oauth_access].
+opt_type(oauth_db_type) ->
+ fun(T) -> ejabberd_config:v_db(?MODULE, T) end;
+opt_type(_) -> [oauth_expire, oauth_access, oauth_db_type].