diff options
author | Evgeniy Khramtsov <ekhramtsov@process-one.net> | 2017-05-11 14:37:21 +0300 |
---|---|---|
committer | Evgeniy Khramtsov <ekhramtsov@process-one.net> | 2017-05-11 14:37:21 +0300 |
commit | 633b68db1130c81551b063f3aa15d599b0d355e5 (patch) | |
tree | df2f0be4b75b001e8e47d1778e8e01637a9dfbcd /src/ejabberd_auth.erl | |
parent | Use misc:atom_to_binary/1 instead of the deprecated jlib.erl (#1510) (diff) |
Use cache for authentication backends
The commit introduces the following API incompatibilities:
In ejabberd_auth.erl:
* dirty_get_registered_users/0 is renamed to get_users/0
* get_vh_registered_users/1 is renamed to get_users/1
* get_vh_registered_users/2 is renamed to get_users/2
* get_vh_registered_users_number/1 is renamed to count_users/1
* get_vh_registered_users_number/2 is renamed to count_users/2
In ejabberd_auth callbacks
* plain_password_required/0 is replaced by plain_password_required/1
where the argument is a virtual host
* store_type/0 is replaced by store_type/1 where the argument is
a virtual host
* set_password/3 is now an optional callback
* remove_user/3 callback is no longer needed
* remove_user/2 now should return `ok | {error, atom()}`
* is_user_exists/2 now must only be implemented for backends
with `external` store type
* check_password/6 is no longer needed
* check_password/4 now must only be implemented for backends
with `external` store type
* try_register/3 is now an optional callback and should return
`ok | {error, atom()}`
* dirty_get_registered_users/0 is no longer needed
* get_vh_registered_users/1 is no longer needed
* get_vh_registered_users/2 is renamed to get_users/2
* get_vh_registered_users_number/1 is no longer needed
* get_vh_registered_users_number/2 is renamed to count_users/2
* get_password_s/2 is no longer needed
* get_password/2 now must only be implemented for backends with
`plain` or `scram` store type
Additionally, the commit introduces two new callbacks:
* use_cache/1 where the argument is a virtual host
* cache_nodes/1 where the argument is a virtual host
New options are also introduced: `auth_use_cache`, `auth_cache_missed`,
`auth_cache_life_time` and `auth_cache_size`.
Diffstat (limited to 'src/ejabberd_auth.erl')
-rw-r--r-- | src/ejabberd_auth.erl | 862 |
1 files changed, 582 insertions, 280 deletions
diff --git a/src/ejabberd_auth.erl b/src/ejabberd_auth.erl index 9751142a5..23ed0eeae 100644 --- a/src/ejabberd_auth.erl +++ b/src/ejabberd_auth.erl @@ -22,9 +22,6 @@ %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. %%% %%%---------------------------------------------------------------------- - -%% TODO: Use the functions in ejabberd auth to add and remove users. - -module(ejabberd_auth). -behaviour(gen_server). @@ -37,10 +34,10 @@ set_password/3, check_password/4, check_password/6, check_password_with_authmodule/4, check_password_with_authmodule/6, try_register/3, - dirty_get_registered_users/0, get_vh_registered_users/1, - get_vh_registered_users/2, export/1, import_info/0, - get_vh_registered_users_number/1, import/5, import_start/2, - get_vh_registered_users_number/2, get_password/2, + get_users/0, get_users/1, password_to_scram/1, + get_users/2, export/1, import_info/0, + count_users/1, import/5, import_start/2, + count_users/2, get_password/2, get_password_s/2, get_password_with_authmodule/2, is_user_exists/2, is_user_exists_in_other_modules/3, remove_user/2, remove_user/3, plain_password_required/1, @@ -54,10 +51,13 @@ -include("ejabberd.hrl"). -include("logger.hrl"). +-define(AUTH_CACHE, auth_cache). +-define(SALT_LENGTH, 16). + -record(state, {host_modules = #{} :: map()}). --type scrammed_password() :: {binary(), binary(), binary(), non_neg_integer()}. --type password() :: binary() | scrammed_password(). +-type password() :: binary() | #scram{}. +-type digest_fun() :: fun((binary()) -> binary()). -export_type([password/0]). %%%---------------------------------------------------------------------- @@ -69,24 +69,29 @@ -callback start(binary()) -> any(). -callback stop(binary()) -> any(). --callback plain_password_required() -> boolean(). --callback store_type() -> plain | external | scram. +-callback plain_password_required(binary()) -> boolean(). +-callback store_type(binary()) -> plain | external | scram. -callback set_password(binary(), binary(), binary()) -> ok | {error, atom()}. --callback remove_user(binary(), binary()) -> any(). --callback remove_user(binary(), binary(), binary()) -> any(). +-callback remove_user(binary(), binary()) -> ok | {error, any()}. -callback is_user_exists(binary(), binary()) -> boolean() | {error, atom()}. -callback check_password(binary(), binary(), binary(), binary()) -> boolean(). --callback check_password(binary(), binary(), binary(), binary(), binary(), - fun((binary()) -> binary())) -> boolean(). --callback try_register(binary(), binary(), binary()) -> {atomic, atom()} | - {error, atom()}. --callback dirty_get_registered_users() -> [{binary(), binary()}]. --callback get_vh_registered_users(binary()) -> [{binary(), binary()}]. --callback get_vh_registered_users(binary(), opts()) -> [{binary(), binary()}]. --callback get_vh_registered_users_number(binary()) -> number(). --callback get_vh_registered_users_number(binary(), opts()) -> number(). --callback get_password(binary(), binary()) -> false | password(). --callback get_password_s(binary(), binary()) -> password(). +-callback try_register(binary(), binary(), password()) -> ok | {error, atom()}. +-callback get_users(binary(), opts()) -> [{binary(), binary()}]. +-callback count_users(binary(), opts()) -> number(). +-callback get_password(binary(), binary()) -> {ok, password()} | error. +-callback use_cache(binary()) -> boolean(). +-callback cache_nodes(binary()) -> boolean(). + +-optional_callbacks([set_password/3, + remove_user/2, + is_user_exists/2, + check_password/4, + try_register/3, + get_users/2, + count_users/2, + get_password/2, + use_cache/1, + cache_nodes/1]). -spec start_link() -> {ok, pid()} | {error, any()}. start_link() -> @@ -99,9 +104,13 @@ init([]) -> HostModules = lists:foldl( fun(Host, Acc) -> Modules = auth_modules(Host), - start(Host, Modules), maps:put(Host, Modules, Acc) end, #{}, ?MYHOSTS), + lists:foreach( + fun({Host, Modules}) -> + start(Host, Modules) + end, maps:to_list(HostModules)), + init_cache(HostModules), {ok, #state{host_modules = HostModules}}. handle_call(_Request, _From, State) -> @@ -112,11 +121,13 @@ handle_cast({host_up, Host}, #state{host_modules = HostModules} = State) -> Modules = auth_modules(Host), start(Host, Modules), NewHostModules = maps:put(Host, Modules, HostModules), + init_cache(NewHostModules), {noreply, State#state{host_modules = NewHostModules}}; handle_cast({host_down, Host}, #state{host_modules = HostModules} = State) -> Modules = maps:get(Host, HostModules, []), stop(Host, Modules), NewHostModules = maps:remove(Host, HostModules), + init_cache(NewHostModules), {noreply, State#state{host_modules = NewHostModules}}; handle_cast(config_reloaded, #state{host_modules = HostModules} = State) -> NewHostModules = lists:foldl( @@ -127,6 +138,7 @@ handle_cast(config_reloaded, #state{host_modules = HostModules} = State) -> stop(Host, OldModules -- NewModules), maps:put(Host, NewModules, Acc) end, HostModules, ?MYHOSTS), + init_cache(NewHostModules), {noreply, State#state{host_modules = NewHostModules}}; handle_cast(Msg, State) -> ?WARNING_MSG("unexpected cast: ~p", [Msg]), @@ -162,306 +174,266 @@ host_down(Host) -> config_reloaded() -> gen_server:cast(?MODULE, config_reloaded). +-spec plain_password_required(binary()) -> boolean(). plain_password_required(Server) -> - lists:any(fun (M) -> M:plain_password_required() end, + lists:any(fun (M) -> M:plain_password_required(Server) end, auth_modules(Server)). +-spec store_type(binary()) -> plain | scram | external. store_type(Server) -> -%% @doc Check if the user and password can login in server. -%% @spec (User::string(), Server::string(), Password::string()) -> -%% true | false - lists:foldl(fun (_, external) -> external; - (M, scram) -> - case M:store_type() of - external -> external; - _Else -> scram - end; - (M, plain) -> M:store_type() - end, - plain, auth_modules(Server)). + lists:foldl( + fun(_, external) -> external; + (M, scram) -> + case M:store_type(Server) of + external -> external; + _ -> scram + end; + (M, plain) -> + M:store_type(Server) + end, plain, auth_modules(Server)). -spec check_password(binary(), binary(), binary(), binary()) -> boolean(). - check_password(User, AuthzId, Server, Password) -> - case check_password_with_authmodule(User, AuthzId, Server, - Password) - of - {true, _AuthModule} -> true; - false -> false - end. + check_password(User, AuthzId, Server, Password, <<"">>, undefined). -%% @doc Check if the user and password can login in server. -%% @spec (User::string(), AuthzId::string(), Server::string(), Password::string(), -%% Digest::string(), DigestGen::function()) -> -%% true | false -spec check_password(binary(), binary(), binary(), binary(), binary(), - fun((binary()) -> binary())) -> boolean(). - -check_password(User, AuthzId, Server, Password, Digest, - DigestGen) -> - case check_password_with_authmodule(User, AuthzId, Server, - Password, Digest, DigestGen) - of - {true, _AuthModule} -> true; - false -> false + digest_fun() | undefined) -> boolean(). +check_password(User, AuthzId, Server, Password, Digest, DigestGen) -> + case check_password_with_authmodule( + User, AuthzId, Server, Password, Digest, DigestGen) of + {true, _AuthModule} -> true; + false -> false end. -%% @doc Check if the user and password can login in server. -%% The user can login if at least an authentication method accepts the user -%% and the password. -%% The first authentication method that accepts the credentials is returned. -%% @spec (User::string(), AuthzId::string(), Server::string(), Password::string()) -> -%% {true, AuthModule} | false -%% where -%% AuthModule = ejabberd_auth_anonymous | ejabberd_auth_external -%% | ejabberd_auth_mnesia | ejabberd_auth_ldap -%% | ejabberd_auth_sql | ejabberd_auth_pam | ejabberd_auth_riak --spec check_password_with_authmodule(binary(), binary(), binary(), binary()) -> false | - {true, atom()}. - -check_password_with_authmodule(User, AuthzId, Server, - Password) -> - check_password_loop(auth_modules(Server), - [User, AuthzId, Server, Password]). - --spec check_password_with_authmodule(binary(), binary(), binary(), binary(), binary(), - fun((binary()) -> binary())) -> false | - {true, atom()}. - -check_password_with_authmodule(User, AuthzId, Server, Password, - Digest, DigestGen) -> - check_password_loop(auth_modules(Server), - [User, AuthzId, Server, Password, Digest, DigestGen]). - -check_password_loop([], _Args) -> false; -check_password_loop([AuthModule | AuthModules], Args) -> - case apply(AuthModule, check_password, Args) of - true -> {true, AuthModule}; - false -> check_password_loop(AuthModules, Args) +-spec check_password_with_authmodule(binary(), binary(), + binary(), binary()) -> false | {true, atom()}. +check_password_with_authmodule(User, AuthzId, Server, Password) -> + check_password_with_authmodule( + User, AuthzId, Server, Password, <<"">>, undefined). + +-spec check_password_with_authmodule( + binary(), binary(), binary(), binary(), binary(), + digest_fun() | undefined) -> false | {true, atom()}. +check_password_with_authmodule(User, AuthzId, Server, Password, Digest, DigestGen) -> + case validate_credentials(User, Server) of + {ok, LUser, LServer} -> + lists:foldl( + fun(Mod, false) -> + case db_check_password( + LUser, AuthzId, LServer, Password, + Digest, DigestGen, Mod) of + true -> {true, Mod}; + false -> false + end; + (_, Acc) -> + Acc + end, false, auth_modules(LServer)); + _ -> + false end. --spec set_password(binary(), binary(), binary()) -> ok | - {error, atom()}. - -%% @spec (User::string(), Server::string(), Password::string()) -> -%% ok | {error, ErrorType} -%% where ErrorType = empty_password | not_allowed | invalid_jid -set_password(_User, _Server, <<"">>) -> - {error, empty_password}; +-spec set_password(binary(), binary(), password()) -> ok | {error, atom()}. set_password(User, Server, Password) -> -%% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, not_allowed} - lists:foldl(fun (M, {error, _}) -> - M:set_password(User, Server, Password); - (_M, Res) -> Res - end, - {error, not_allowed}, auth_modules(Server)). - --spec try_register(binary(), binary(), binary()) -> {atomic, atom()} | - {error, atom()}. - -try_register(_User, _Server, <<"">>) -> - {error, not_allowed}; + case validate_credentials(User, Server, Password) of + {ok, LUser, LServer} -> + lists:foldl( + fun(M, {error, _}) -> + db_set_password(LUser, LServer, Password, M); + (_, ok) -> + ok + end, {error, not_allowed}, auth_modules(LServer)); + Err -> + Err + end. + +-spec try_register(binary(), binary(), password()) -> ok | {error, atom()}. try_register(User, Server, Password) -> - case is_user_exists(User, Server) of - true -> {atomic, exists}; - false -> - LServer = jid:nameprep(Server), - case ejabberd_router:is_my_host(LServer) of - true -> - Res = lists:foldl(fun (_M, {atomic, ok} = Res) -> Res; - (M, _) -> - M:try_register(User, Server, Password) - end, - {error, not_allowed}, auth_modules(Server)), - case Res of - {atomic, ok} -> - ejabberd_hooks:run(register_user, Server, - [User, Server]), - {atomic, ok}; - _ -> Res - end; - false -> {error, not_allowed} - end + case validate_credentials(User, Server, Password) of + {ok, LUser, LServer} -> + case is_user_exists(LUser, LServer) of + true -> + {error, exists}; + false -> + case ejabberd_router:is_my_host(LServer) of + true -> + case lists:foldl( + fun(_, ok) -> + ok; + (Mod, _) -> + db_try_register( + User, Server, Password, Mod) + end, {error, not_allowed}, + auth_modules(LServer)) of + ok -> + ejabberd_hooks:run( + register_user, Server, [User, Server]); + {error, _} = Err -> + Err + end; + false -> + {error, not_allowed} + end + end; + Err -> + Err end. -%% Registered users list do not include anonymous users logged --spec dirty_get_registered_users() -> [{binary(), binary()}]. - -dirty_get_registered_users() -> - lists:flatmap(fun (M) -> M:dirty_get_registered_users() - end, - auth_modules()). - --spec get_vh_registered_users(binary()) -> [{binary(), binary()}]. - -%% Registered users list do not include anonymous users logged -get_vh_registered_users(Server) -> - lists:flatmap(fun (M) -> - M:get_vh_registered_users(Server) - end, - auth_modules(Server)). - --spec get_vh_registered_users(binary(), opts()) -> [{binary(), binary()}]. - -get_vh_registered_users(Server, Opts) -> - lists:flatmap(fun (M) -> - case erlang:function_exported(M, - get_vh_registered_users, - 2) - of - true -> M:get_vh_registered_users(Server, Opts); - false -> M:get_vh_registered_users(Server) - end - end, - auth_modules(Server)). - -get_vh_registered_users_number(Server) -> - lists:sum(lists:map(fun (M) -> - case erlang:function_exported(M, - get_vh_registered_users_number, - 1) - of - true -> - M:get_vh_registered_users_number(Server); - false -> - length(M:get_vh_registered_users(Server)) - end - end, - auth_modules(Server))). - --spec get_vh_registered_users_number(binary(), opts()) -> number(). - -get_vh_registered_users_number(Server, Opts) -> -%% @doc Get the password of the user. -%% @spec (User::string(), Server::string()) -> Password::string() - lists:sum(lists:map(fun (M) -> - case erlang:function_exported(M, - get_vh_registered_users_number, - 2) - of - true -> - M:get_vh_registered_users_number(Server, - Opts); - false -> - length(M:get_vh_registered_users(Server)) - end - end, - auth_modules(Server))). +-spec get_users() -> [{binary(), binary()}]. +get_users() -> + lists:flatmap( + fun({Host, Mod}) -> + db_get_users(Host, [], Mod) + end, auth_modules()). + +-spec get_users(binary()) -> [{binary(), binary()}]. +get_users(Server) -> + get_users(Server, []). + +-spec get_users(binary(), opts()) -> [{binary(), binary()}]. +get_users(Server, Opts) -> + case jid:nameprep(Server) of + error -> []; + LServer -> + lists:flatmap( + fun(M) -> db_get_users(LServer, Opts, M) end, + auth_modules(LServer)) + end. --spec get_password(binary(), binary()) -> false | password(). +-spec count_users(binary()) -> non_neg_integer(). +count_users(Server) -> + count_users(Server, []). + +-spec count_users(binary(), opts()) -> non_neg_integer(). +count_users(Server, Opts) -> + case jid:nameprep(Server) of + error -> 0; + LServer -> + lists:sum( + lists:map( + fun(M) -> db_count_users(LServer, Opts, M) end, + auth_modules(LServer))) + end. +-spec get_password(binary(), binary()) -> false | password(). get_password(User, Server) -> - lists:foldl(fun (M, false) -> - M:get_password(User, Server); - (_M, Password) -> Password - end, - false, auth_modules(Server)). + case validate_credentials(User, Server) of + {ok, LUser, LServer} -> + case lists:foldl( + fun(M, error) -> db_get_password(LUser, LServer, M); + (_M, Acc) -> Acc + end, error, auth_modules(LServer)) of + {ok, Password} -> + Password; + error -> + false + end; + _ -> + false + end. -spec get_password_s(binary(), binary()) -> password(). - get_password_s(User, Server) -> case get_password(User, Server) of false -> <<"">>; Password -> Password end. -%% @doc Get the password of the user and the auth module. -%% @spec (User::string(), Server::string()) -> -%% {Password::string(), AuthModule::atom()} | {false, none} -spec get_password_with_authmodule(binary(), binary()) -> {false | password(), module()}. - get_password_with_authmodule(User, Server) -> -%% Returns true if the user exists in the DB or if an anonymous user is logged -%% under the given name - lists:foldl(fun (M, {false, _}) -> - {M:get_password(User, Server), M}; - (_M, {Password, AuthModule}) -> {Password, AuthModule} - end, - {false, none}, auth_modules(Server)). + case validate_credentials(User, Server) of + {ok, LUser, LServer} -> + case lists:foldl( + fun(M, {error, _}) -> + {db_get_password(LUser, LServer, M), M}; + (_M, Acc) -> + Acc + end, {error, undefined}, auth_modules(LServer)) of + {{ok, Password}, Module} -> + {Password, Module}; + {error, Module} -> + {false, Module} + end; + _ -> + {false, undefined} + end. -spec is_user_exists(binary(), binary()) -> boolean(). - is_user_exists(_User, <<"">>) -> false; - is_user_exists(User, Server) -> -%% Check if the user exists in all authentications module except the module -%% passed as parameter -%% @spec (Module::atom(), User, Server) -> true | false | maybe - lists:any(fun (M) -> - case M:is_user_exists(User, Server) of - {error, Error} -> - ?ERROR_MSG("The authentication module ~p returned " - "an error~nwhen checking user ~p in server " - "~p~nError message: ~p", - [M, User, Server, Error]), - false; - Else -> Else + case validate_credentials(User, Server) of + {ok, LUser, LServer} -> + lists:any( + fun(M) -> + case db_is_user_exists(LUser, LServer, M) of + {error, _} -> + false; + Else -> + Else end - end, - auth_modules(Server)). + end, auth_modules(LServer)); + _ -> + false + end. -spec is_user_exists_in_other_modules(atom(), binary(), binary()) -> boolean() | maybe. - is_user_exists_in_other_modules(Module, User, Server) -> - is_user_exists_in_other_modules_loop(auth_modules(Server) - -- [Module], - User, Server). + is_user_exists_in_other_modules_loop( + auth_modules(Server) -- [Module], User, Server). -is_user_exists_in_other_modules_loop([], _User, - _Server) -> +is_user_exists_in_other_modules_loop([], _User, _Server) -> false; -is_user_exists_in_other_modules_loop([AuthModule - | AuthModules], - User, Server) -> - case AuthModule:is_user_exists(User, Server) of - true -> true; - false -> - is_user_exists_in_other_modules_loop(AuthModules, User, - Server); - {error, Error} -> - ?DEBUG("The authentication module ~p returned " - "an error~nwhen checking user ~p in server " - "~p~nError message: ~p", - [AuthModule, User, Server, Error]), - maybe +is_user_exists_in_other_modules_loop([AuthModule | AuthModules], User, Server) -> + case db_is_user_exists(User, Server, AuthModule) of + true -> + true; + false -> + is_user_exists_in_other_modules_loop(AuthModules, User, Server); + {error, _} -> + maybe end. -spec remove_user(binary(), binary()) -> ok. - -%% @spec (User, Server) -> ok -%% @doc Remove user. -%% Note: it may return ok even if there was some problem removing the user. remove_user(User, Server) -> - lists:foreach(fun (M) -> M:remove_user(User, Server) - end, - auth_modules(Server)), - ejabberd_hooks:run(remove_user, jid:nameprep(Server), - [User, Server]), - ok. - -%% @spec (User, Server, Password) -> ok | not_exists | not_allowed | bad_request | error -%% @doc Try to remove user if the provided password is correct. -%% The removal is attempted in each auth method provided: -%% when one returns 'ok' the loop stops; -%% if no method returns 'ok' then it returns the error message indicated by the last method attempted. --spec remove_user(binary(), binary(), binary()) -> any(). + case validate_credentials(User, Server) of + {ok, LUser, LServer} -> + lists:foreach( + fun(Mod) -> db_remove_user(LUser, LServer, Mod) end, + auth_modules(LServer)), + ejabberd_hooks:run(remove_user, LServer, [LUser, LServer]); + _Err -> + ok + end. +-spec remove_user(binary(), binary(), password()) -> ok | {error, atom()}. remove_user(User, Server, Password) -> - R = lists:foldl(fun (_M, ok = Res) -> Res; - (M, _) -> M:remove_user(User, Server, Password) - end, - error, auth_modules(Server)), - case R of - ok -> - ejabberd_hooks:run(remove_user, jid:nameprep(Server), - [User, Server]); - _ -> none - end, - R. - -%% @spec (IOList) -> non_negative_float() + case validate_credentials(User, Server, Password) of + {ok, LUser, LServer} -> + case lists:foldl( + fun (_, ok) -> + ok; + (Mod, _) -> + case db_check_password( + LUser, <<"">>, LServer, Password, + <<"">>, undefined, Mod) of + true -> + db_remove_user(LUser, LServer, Mod); + false -> + {error, not_allowed} + end + end, {error, not_allowed}, auth_modules(Server)) of + ok -> + ejabberd_hooks:run( + remove_user, LServer, [LUser, LServer]); + Err -> + Err + end; + Err -> + Err + end. + %% @doc Calculate informational entropy. +-spec entropy(iodata()) -> float(). entropy(B) -> case binary_to_list(B) of "" -> 0.0; @@ -497,15 +469,266 @@ backend_type(Mod) -> _ -> Mod end. +-spec password_format(binary() | global) -> plain | scram. password_format(LServer) -> ejabberd_config:get_option({auth_password_format, LServer}, plain). %%%---------------------------------------------------------------------- +%%% Backend calls +%%%---------------------------------------------------------------------- +db_try_register(User, Server, Password, Mod) -> + case erlang:function_exported(Mod, try_register, 3) of + true -> + Password1 = case Mod:store_type(Server) of + scram -> password_to_scram(Password); + _ -> Password + end, + case use_cache(Mod, Server) of + true -> + case ets_cache:update( + ?AUTH_CACHE, {User, Server}, {ok, Password}, + fun() -> Mod:try_register(User, Server, Password1) end, + cache_nodes(Mod, Server)) of + {ok, _} -> ok; + {error, _} = Err -> Err + end; + false -> + Mod:try_register(User, Server, Password1) + end; + false -> + {error, not_allowed} + end. + +db_set_password(User, Server, Password, Mod) -> + case erlang:function_exported(Mod, set_password, 3) of + true -> + Password1 = case Mod:store_type(Server) of + scram -> password_to_scram(Password); + _ -> Password + end, + case use_cache(Mod, Server) of + true -> + case ets_cache:update( + ?AUTH_CACHE, {User, Server}, {ok, Password}, + fun() -> Mod:set_password(User, Server, Password1) end, + cache_nodes(Mod, Server)) of + {ok, _} -> ok; + {error, _} = Err -> Err + end; + false -> + Mod:set_password(User, Server, Password1) + end; + false -> + {error, not_allowed} + end. + +db_get_password(User, Server, Mod) -> + UseCache = use_cache(Mod, Server), + case erlang:function_exported(Mod, get_password, 2) of + false when UseCache -> + ets_cache:lookup(?AUTH_CACHE, {User, Server}); + false -> + error; + true when UseCache -> + ets_cache:lookup( + ?AUTH_CACHE, {User, Server}, + fun() -> Mod:get_password(User, Server) end); + true -> + Mod:get_password(User, Server) + end. + +db_is_user_exists(User, Server, Mod) -> + case db_get_password(User, Server, Mod) of + {ok, _} -> + true; + error -> + case Mod:store_type(Server) of + external -> + Mod:is_user_exists(User, Server); + _ -> + false + end + end. + +db_check_password(User, AuthzId, Server, ProvidedPassword, + Digest, DigestFun, Mod) -> + case db_get_password(User, Server, Mod) of + {ok, ValidPassword} -> + match_passwords(ProvidedPassword, ValidPassword, Digest, DigestFun); + error -> + case {Mod:store_type(Server), use_cache(Mod, Server)} of + {external, true} -> + case ets_cache:update( + ?AUTH_CACHE, {User, Server}, {ok, ProvidedPassword}, + fun() -> + case Mod:check_password( + User, AuthzId, Server, ProvidedPassword) of + true -> + {ok, ProvidedPassword}; + false -> + error + end + end, cache_nodes(Mod, Server)) of + {ok, _} -> + true; + error -> + false + end; + {external, false} -> + Mod:check_password(User, AuthzId, Server, ProvidedPassword); + _ -> + false + end + end. + +db_remove_user(User, Server, Mod) -> + case erlang:function_exported(Mod, remove_user, 2) of + true -> + case Mod:remove_user(User, Server) of + ok -> + case use_cache(Mod, Server) of + true -> + ets_cache:delete(?AUTH_CACHE, {User, Server}, + cache_nodes(Mod, Server)); + false -> + ok + end; + {error, _} = Err -> + Err + end; + false -> + {error, not_allowed} + end. + +db_get_users(Server, Opts, Mod) -> + case erlang:function_exported(Mod, get_users, 2) of + true -> + Mod:get_users(Server, Opts); + false -> + case use_cache(Mod, Server) of + true -> + ets_cache:fold( + fun({User, S}, {ok, _}, Users) when S == Server -> + [{User, Server}|Users]; + (_, _, Users) -> + Users + end, [], ?AUTH_CACHE); + false -> + [] + end + end. + +db_count_users(Server, Opts, Mod) -> + case erlang:function_exported(Mod, count_users, 2) of + true -> + Mod:count_users(Server, Opts); + false -> + case use_cache(Mod, Server) of + true -> + ets_cache:fold( + fun({_, S}, {ok, _}, Num) when S == Server -> + Num + 1; + (_, _, Num) -> + Num + end, 0, ?AUTH_CACHE); + false -> + 0 + end + end. + +%%%---------------------------------------------------------------------- +%%% SCRAM stuff +%%%---------------------------------------------------------------------- +is_password_scram_valid(Password, Scram) -> + case jid:resourceprep(Password) of + error -> + false; + _ -> + IterationCount = Scram#scram.iterationcount, + Salt = misc:decode_base64(Scram#scram.salt), + SaltedPassword = scram:salted_password(Password, Salt, IterationCount), + StoredKey = scram:stored_key(scram:client_key(SaltedPassword)), + misc:decode_base64(Scram#scram.storedkey) == StoredKey + end. + +password_to_scram(Password) -> + password_to_scram(Password, ?SCRAM_DEFAULT_ITERATION_COUNT). + +password_to_scram(#scram{} = Password, _IterationCount) -> + Password; +password_to_scram(Password, IterationCount) -> + Salt = randoms:bytes(?SALT_LENGTH), + SaltedPassword = scram:salted_password(Password, Salt, IterationCount), + StoredKey = scram:stored_key(scram:client_key(SaltedPassword)), + ServerKey = scram:server_key(SaltedPassword), + #scram{storedkey = misc:encode_base64(StoredKey), + serverkey = misc:encode_base64(ServerKey), + salt = misc:encode_base64(Salt), + iterationcount = IterationCount}. + +%%%---------------------------------------------------------------------- +%%% Cache stuff +%%%---------------------------------------------------------------------- +-spec init_cache(map()) -> ok. +init_cache(HostModules) -> + case use_cache(HostModules) of + true -> + ets_cache:new(?AUTH_CACHE, cache_opts()); + false -> + ets_cache:delete(?AUTH_CACHE) + end. + +-spec cache_opts() -> [proplists:property()]. +cache_opts() -> + MaxSize = ejabberd_config:get_option( + auth_cache_size, + ejabberd_config:cache_size(global)), + CacheMissed = ejabberd_config:get_option( + auth_cache_missed, + ejabberd_config:cache_missed(global)), + LifeTime = case ejabberd_config:get_option( + auth_cache_life_time, + ejabberd_config:cache_life_time(global)) of + infinity -> infinity; + I -> timer:seconds(I) + end, + [{max_size, MaxSize}, {cache_missed, CacheMissed}, {life_time, LifeTime}]. + +-spec use_cache(map()) -> boolean(). +use_cache(HostModules) -> + lists:any( + fun({Host, Modules}) -> + lists:any(fun(Module) -> + use_cache(Module, Host) + end, Modules) + end, maps:to_list(HostModules)). + +-spec use_cache(module(), binary()) -> boolean(). +use_cache(Mod, LServer) -> + case erlang:function_exported(Mod, use_cache, 1) of + true -> Mod:use_cache(LServer); + false -> + ejabberd_config:get_option( + {auth_use_cache, LServer}, + ejabberd_config:use_cache(LServer)) + end. + +-spec cache_nodes(module(), binary()) -> [node()]. +cache_nodes(Mod, LServer) -> + case erlang:function_exported(Mod, cache_nodes, 1) of + true -> Mod:cache_nodes(LServer); + false -> ejabberd_cluster:get_nodes() + end. + +%%%---------------------------------------------------------------------- %%% Internal functions %%%---------------------------------------------------------------------- --spec auth_modules() -> [module()]. +-spec auth_modules() -> [{binary(), module()}]. auth_modules() -> - lists:usort(lists:flatmap(fun auth_modules/1, ?MYHOSTS)). + lists:flatmap( + fun(Host) -> + [{Host, Mod} || Mod <- auth_modules(Host)] + end, ?MYHOSTS). -spec auth_modules(binary()) -> [module()]. auth_modules(Server) -> @@ -516,6 +739,65 @@ auth_modules(Server) -> (misc:atom_to_binary(M))/binary>>) || M <- Methods]. +-spec match_passwords(password(), password(), + binary(), digest_fun() | undefined) -> boolean(). +match_passwords(Password, #scram{} = Scram, <<"">>, undefined) -> + is_password_scram_valid(Password, Scram); +match_passwords(Password, #scram{} = Scram, Digest, DigestFun) -> + StoredKey = misc:decode_base64(Scram#scram.storedkey), + DigRes = if Digest /= <<"">> -> + Digest == DigestFun(StoredKey); + true -> false + end, + if DigRes -> + true; + true -> + StoredKey == Password andalso Password /= <<"">> + end; +match_passwords(ProvidedPassword, ValidPassword, <<"">>, undefined) -> + ProvidedPassword == ValidPassword andalso ProvidedPassword /= <<"">>; +match_passwords(ProvidedPassword, ValidPassword, Digest, DigestFun) -> + DigRes = if Digest /= <<"">> -> + Digest == DigestFun(ValidPassword); + true -> false + end, + if DigRes -> + true; + true -> + ValidPassword == ProvidedPassword andalso ProvidedPassword /= <<"">> + end. + +-spec validate_credentials(binary(), binary()) -> + {ok, binary(), binary()} | {error, invalid_jid}. +validate_credentials(User, Server) -> + validate_credentials(User, Server, #scram{}). + +-spec validate_credentials(binary(), binary(), password()) -> + {ok, binary(), binary()} | {error, invalid_jid | invalid_password}. +validate_credentials(_User, _Server, <<"">>) -> + {error, invalid_password}; +validate_credentials(User, Server, Password) -> + case jid:nodeprep(User) of + error -> + {error, invalid_jid}; + LUser -> + case jid:nameprep(Server) of + error -> + {error, invalid_jid}; + LServer -> + if is_record(Password, scram) -> + {ok, LUser, LServer}; + true -> + case jid:resourceprep(Password) of + error -> + {error, invalid_password}; + _ -> + {ok, LUser, LServer} + end + end + end + end. + export(Server) -> ejabberd_auth_mnesia:export(Server). @@ -536,6 +818,10 @@ import(_LServer, {sql, _}, sql, <<"users">>, _) -> -spec opt_type(auth_method) -> fun((atom() | [atom()]) -> [atom()]); (auth_password_format) -> fun((plain | scram) -> plain | scram); + (auth_use_cache) -> fun((boolean()) -> boolean()); + (auth_cache_missed) -> fun((boolean()) -> boolean()); + (auth_cache_life_time) -> fun((timeout()) -> timeout()); + (auth_cache_size) -> fun((timeout()) -> timeout()); (atom()) -> [atom()]. opt_type(auth_method) -> fun (V) when is_list(V) -> @@ -546,4 +832,20 @@ opt_type(auth_password_format) -> fun (plain) -> plain; (scram) -> scram end; -opt_type(_) -> [auth_method, auth_password_format]. +opt_type(auth_use_cache) -> + fun(B) when is_boolean(B) -> B end; +opt_type(auth_cache_missed) -> + fun(B) when is_boolean(B) -> B end; +opt_type(auth_cache_life_time) -> + fun(I) when is_integer(I), I>0 -> I; + (unlimited) -> infinity; + (infinity) -> infinity + end; +opt_type(auth_cache_size) -> + fun(I) when is_integer(I), I>0 -> I; + (unlimited) -> infinity; + (infinity) -> infinity + end; +opt_type(_) -> + [auth_method, auth_password_format, auth_use_cache, + auth_cache_missed, auth_cache_life_time, auth_cache_size]. |