diff options
Diffstat (limited to 'src/ejabberd_auth_mnesia.erl')
-rw-r--r-- | src/ejabberd_auth_mnesia.erl | 453 |
1 files changed, 127 insertions, 326 deletions
diff --git a/src/ejabberd_auth_mnesia.erl b/src/ejabberd_auth_mnesia.erl index 592b9c566..02c22f9d5 100644 --- a/src/ejabberd_auth_mnesia.erl +++ b/src/ejabberd_auth_mnesia.erl @@ -31,15 +31,11 @@ -behaviour(ejabberd_auth). --export([start/1, stop/1, set_password/3, check_password/4, - check_password/6, try_register/3, - dirty_get_registered_users/0, get_vh_registered_users/1, - get_vh_registered_users/2, init_db/0, - get_vh_registered_users_number/1, - get_vh_registered_users_number/2, get_password/2, - get_password_s/2, is_user_exists/2, remove_user/2, - remove_user/3, store_type/0, export/1, import/2, - plain_password_required/0]). +-export([start/1, stop/1, set_password/3, try_register/3, + get_users/2, init_db/0, + count_users/2, get_password/2, + remove_user/2, store_type/1, export/1, import/2, + plain_password_required/1, use_cache/1]). -export([need_transform/1, transform/1]). -include("ejabberd.hrl"). @@ -52,8 +48,6 @@ -record(reg_users_counter, {vhost = <<"">> :: binary(), count = 0 :: integer() | '$1'}). --define(SALT_LENGTH, 16). - %%%---------------------------------------------------------------------- %%% API %%%---------------------------------------------------------------------- @@ -67,14 +61,14 @@ stop(_Host) -> init_db() -> ejabberd_mnesia:create(?MODULE, passwd, - [{disc_copies, [node()]}, + [{disc_only_copies, [node()]}, {attributes, record_info(fields, passwd)}]), ejabberd_mnesia:create(?MODULE, reg_users_counter, [{ram_copies, [node()]}, {attributes, record_info(fields, reg_users_counter)}]). update_reg_users_counter_table(Server) -> - Set = get_vh_registered_users(Server), + Set = get_users(Server, []), Size = length(Set), LServer = jid:nameprep(Server), F = fun () -> @@ -83,309 +77,153 @@ update_reg_users_counter_table(Server) -> end, mnesia:sync_dirty(F). -plain_password_required() -> - is_scrammed(). - -store_type() -> - ejabberd_auth:password_format(?MYNAME). - -check_password(User, AuthzId, Server, Password) -> - if AuthzId /= <<>> andalso AuthzId /= User -> - false; - true -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - US = {LUser, LServer}, - case catch mnesia:dirty_read({passwd, US}) of - [#passwd{password = Password}] when is_binary(Password) -> - Password /= <<"">>; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - is_password_scram_valid(Password, Scram); - _ -> false - end +use_cache(_) -> + case mnesia:table_info(passwd, storage_type) of + disc_only_copies -> true; + _ -> false end. -check_password(User, AuthzId, Server, Password, Digest, - DigestGen) -> - if AuthzId /= <<>> andalso AuthzId /= User -> - false; - true -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - US = {LUser, LServer}, - case catch mnesia:dirty_read({passwd, US}) of - [#passwd{password = Passwd}] when is_binary(Passwd) -> - DigRes = if Digest /= <<"">> -> - Digest == DigestGen(Passwd); - true -> false - end, - if DigRes -> true; - true -> (Passwd == Password) and (Password /= <<"">>) - end; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - Passwd = misc:decode_base64(Scram#scram.storedkey), - DigRes = if Digest /= <<"">> -> - Digest == DigestGen(Passwd); - true -> false - end, - if DigRes -> true; - true -> (Passwd == Password) and (Password /= <<"">>) - end; - _ -> false - end - end. +plain_password_required(Server) -> + store_type(Server) == scram. + +store_type(Server) -> + ejabberd_auth:password_format(Server). -%% @spec (User::string(), Server::string(), Password::string()) -> -%% ok | {error, invalid_jid} set_password(User, Server, Password) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - LPassword = jid:resourceprep(Password), - US = {LUser, LServer}, - if (LUser == error) or (LServer == error) -> - {error, invalid_jid}; - LPassword == error -> - {error, invalid_password}; - true -> - F = fun () -> - Password2 = case is_scrammed() and is_binary(Password) - of - true -> password_to_scram(Password); - false -> Password - end, - mnesia:write(#passwd{us = US, password = Password2}) - end, - {atomic, ok} = mnesia:transaction(F), - ok + US = {User, Server}, + F = fun () -> + mnesia:write(#passwd{us = US, password = Password}) + end, + case mnesia:transaction(F) of + {atomic, ok} -> + ok; + {aborted, Reason} -> + ?ERROR_MSG("Mnesia transaction failed: ~p", [Reason]), + {error, db_failure} end. -%% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid} | {error, not_allowed} | {error, Reason} -try_register(User, Server, PasswordList) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - Password = if is_list(PasswordList); is_binary(PasswordList) -> - iolist_to_binary(PasswordList); - true -> PasswordList - end, - LPassword = jid:resourceprep(Password), - US = {LUser, LServer}, - if (LUser == error) or (LServer == error) -> - {error, invalid_jid}; - (LPassword == error) and not is_record(Password, scram) -> - {error, invalid_password}; - true -> - F = fun () -> - case mnesia:read({passwd, US}) of - [] -> - Password2 = case is_scrammed() and - is_binary(Password) - of - true -> password_to_scram(Password); - false -> Password - end, - mnesia:write(#passwd{us = US, - password = Password2}), - mnesia:dirty_update_counter(reg_users_counter, - LServer, 1), - ok; - [_E] -> exists - end - end, - mnesia:transaction(F) +try_register(User, Server, Password) -> + US = {User, Server}, + F = fun () -> + case mnesia:read({passwd, US}) of + [] -> + mnesia:write(#passwd{us = US, password = Password}), + mnesia:dirty_update_counter(reg_users_counter, Server, 1), + ok; + [_] -> + {error, exists} + end + end, + case mnesia:transaction(F) of + {atomic, Res} -> + Res; + {aborted, Reason} -> + ?ERROR_MSG("Mnesia transaction failed: ~p", [Reason]), + {error, db_failure} end. -%% Get all registered users in Mnesia -dirty_get_registered_users() -> - mnesia:dirty_all_keys(passwd). - -get_vh_registered_users(Server) -> - LServer = jid:nameprep(Server), +get_users(Server, []) -> mnesia:dirty_select(passwd, [{#passwd{us = '$1', _ = '_'}, - [{'==', {element, 2, '$1'}, LServer}], ['$1']}]). - -get_vh_registered_users(Server, - [{from, Start}, {to, End}]) - when is_integer(Start) and is_integer(End) -> - get_vh_registered_users(Server, - [{limit, End - Start + 1}, {offset, Start}]); -get_vh_registered_users(Server, - [{limit, Limit}, {offset, Offset}]) - when is_integer(Limit) and is_integer(Offset) -> - case get_vh_registered_users(Server) of - [] -> []; - Users -> - Set = lists:keysort(1, Users), - L = length(Set), - Start = if Offset < 1 -> 1; - Offset > L -> L; - true -> Offset - end, - lists:sublist(Set, Start, Limit) + [{'==', {element, 2, '$1'}, Server}], ['$1']}]); +get_users(Server, [{from, Start}, {to, End}]) + when is_integer(Start) and is_integer(End) -> + get_users(Server, [{limit, End - Start + 1}, {offset, Start}]); +get_users(Server, [{limit, Limit}, {offset, Offset}]) + when is_integer(Limit) and is_integer(Offset) -> + case get_users(Server, []) of + [] -> + []; + Users -> + Set = lists:keysort(1, Users), + L = length(Set), + Start = if Offset < 1 -> 1; + Offset > L -> L; + true -> Offset + end, + lists:sublist(Set, Start, Limit) end; -get_vh_registered_users(Server, [{prefix, Prefix}]) - when is_binary(Prefix) -> - Set = [{U, S} - || {U, S} <- get_vh_registered_users(Server), - str:prefix(Prefix, U)], +get_users(Server, [{prefix, Prefix}]) when is_binary(Prefix) -> + Set = [{U, S} || {U, S} <- get_users(Server, []), str:prefix(Prefix, U)], lists:keysort(1, Set); -get_vh_registered_users(Server, - [{prefix, Prefix}, {from, Start}, {to, End}]) - when is_binary(Prefix) and is_integer(Start) and - is_integer(End) -> - get_vh_registered_users(Server, - [{prefix, Prefix}, {limit, End - Start + 1}, - {offset, Start}]); -get_vh_registered_users(Server, - [{prefix, Prefix}, {limit, Limit}, {offset, Offset}]) - when is_binary(Prefix) and is_integer(Limit) and - is_integer(Offset) -> - case [{U, S} - || {U, S} <- get_vh_registered_users(Server), - str:prefix(Prefix, U)] - of - [] -> []; - Users -> - Set = lists:keysort(1, Users), - L = length(Set), - Start = if Offset < 1 -> 1; - Offset > L -> L; - true -> Offset - end, - lists:sublist(Set, Start, Limit) +get_users(Server, [{prefix, Prefix}, {from, Start}, {to, End}]) + when is_binary(Prefix) and is_integer(Start) and is_integer(End) -> + get_users(Server, [{prefix, Prefix}, {limit, End - Start + 1}, + {offset, Start}]); +get_users(Server, [{prefix, Prefix}, {limit, Limit}, {offset, Offset}]) + when is_binary(Prefix) and is_integer(Limit) and is_integer(Offset) -> + case [{U, S} || {U, S} <- get_users(Server, []), str:prefix(Prefix, U)] of + [] -> + []; + Users -> + Set = lists:keysort(1, Users), + L = length(Set), + Start = if Offset < 1 -> 1; + Offset > L -> L; + true -> Offset + end, + lists:sublist(Set, Start, Limit) end; -get_vh_registered_users(Server, _) -> - get_vh_registered_users(Server). - -get_vh_registered_users_number(Server) -> - LServer = jid:nameprep(Server), - Query = mnesia:dirty_select(reg_users_counter, - [{#reg_users_counter{vhost = LServer, - count = '$1'}, - [], ['$1']}]), - case Query of - [Count] -> Count; - _ -> 0 - end. - -get_vh_registered_users_number(Server, - [{prefix, Prefix}]) - when is_binary(Prefix) -> - Set = [{U, S} - || {U, S} <- get_vh_registered_users(Server), - str:prefix(Prefix, U)], +get_users(Server, _) -> + get_users(Server, []). + +count_users(Server, []) -> + case mnesia:dirty_select( + reg_users_counter, + [{#reg_users_counter{vhost = Server, count = '$1'}, + [], ['$1']}]) of + [Count] -> Count; + _ -> 0 + end; +count_users(Server, [{prefix, Prefix}]) when is_binary(Prefix) -> + Set = [{U, S} || {U, S} <- get_users(Server, []), str:prefix(Prefix, U)], length(Set); -get_vh_registered_users_number(Server, _) -> - get_vh_registered_users_number(Server). +count_users(Server, _) -> + count_users(Server, []). get_password(User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - US = {LUser, LServer}, - case catch mnesia:dirty_read(passwd, US) of - [#passwd{password = Password}] - when is_binary(Password) -> - Password; - [#passwd{password = Scram}] - when is_record(Scram, scram) -> - {misc:decode_base64(Scram#scram.storedkey), - misc:decode_base64(Scram#scram.serverkey), - misc:decode_base64(Scram#scram.salt), - Scram#scram.iterationcount}; - _ -> false - end. - -get_password_s(User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - US = {LUser, LServer}, - case catch mnesia:dirty_read(passwd, US) of - [#passwd{password = Password}] - when is_binary(Password) -> - Password; - [#passwd{password = Scram}] - when is_record(Scram, scram) -> - <<"">>; - _ -> <<"">> - end. - -%% @spec (User, Server) -> true | false | {error, Error} -is_user_exists(User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - US = {LUser, LServer}, - case catch mnesia:dirty_read({passwd, US}) of - [] -> false; - [_] -> true; - Other -> {error, Other} + case mnesia:dirty_read(passwd, {User, Server}) of + [#passwd{password = Password}] -> + {ok, Password}; + _ -> + error end. -%% @spec (User, Server) -> ok -%% @doc Remove user. -%% Note: it returns ok even if there was some problem removing the user. remove_user(User, Server) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - US = {LUser, LServer}, + US = {User, Server}, F = fun () -> mnesia:delete({passwd, US}), - mnesia:dirty_update_counter(reg_users_counter, LServer, - -1) - end, - mnesia:transaction(F), - ok. - -%% @spec (User, Server, Password) -> ok | not_exists | not_allowed | bad_request -%% @doc Remove user if the provided password is correct. -remove_user(User, Server, Password) -> - LUser = jid:nodeprep(User), - LServer = jid:nameprep(Server), - US = {LUser, LServer}, - F = fun () -> - case mnesia:read({passwd, US}) of - [#passwd{password = Password}] - when is_binary(Password) -> - mnesia:delete({passwd, US}), - mnesia:dirty_update_counter(reg_users_counter, LServer, - -1), - ok; - [#passwd{password = Scram}] - when is_record(Scram, scram) -> - case is_password_scram_valid(Password, Scram) of - true -> - mnesia:delete({passwd, US}), - mnesia:dirty_update_counter(reg_users_counter, - LServer, -1), - ok; - false -> not_allowed - end; - _ -> not_exists - end + mnesia:dirty_update_counter(reg_users_counter, Server, -1), + ok end, case mnesia:transaction(F) of - {atomic, ok} -> ok; - {atomic, Res} -> Res; - _ -> bad_request + {atomic, ok} -> + ok; + {aborted, Reason} -> + ?ERROR_MSG("Mnesia transaction failed: ~p", [Reason]), + {error, db_failure} end. need_transform(#passwd{us = {U, S}, password = Pass}) -> if is_binary(Pass) -> - IsScrammed = is_scrammed(), - if IsScrammed -> + case store_type(S) of + scram -> ?INFO_MSG("Passwords in Mnesia table 'passwd' " - "will be SCRAM'ed", []); - true -> - ok - end, - IsScrammed; + "will be SCRAM'ed", []), + true; + plain -> + false + end; is_record(Pass, scram) -> - case is_scrammed() of - true -> - next; - false -> + case store_type(S) of + scram -> + false; + plain -> ?WARNING_MSG("Some passwords were stored in the database " "as SCRAM, but 'auth_password_format' " - "is not configured as 'scram'.", []), + "is not configured as 'scram': some " + "authentication mechanisms such as DIGEST-MD5 " + "would *fail*", []), false end; is_list(U) orelse is_list(S) orelse is_list(Pass) -> @@ -410,61 +248,24 @@ transform(#passwd{us = {U, S}, password = Pass} = R) transform(R#passwd{us = NewUS, password = NewPass}); transform(#passwd{us = {U, S}, password = Password} = P) when is_binary(Password) -> - case is_scrammed() of - true -> + case store_type(S) of + scram -> case jid:resourceprep(Password) of error -> ?ERROR_MSG("SASLprep failed for password of user ~s@~s", [U, S]), P; _ -> - Scram = password_to_scram(Password), + Scram = ejabberd_auth:password_to_scram(Password), P#passwd{password = Scram} end; - false -> + plain -> P end; transform(#passwd{password = Password} = P) when is_record(Password, scram) -> P. -%%% -%%% SCRAM -%%% - -is_scrammed() -> - scram == store_type(). - -password_to_scram(Password) -> - password_to_scram(Password, - ?SCRAM_DEFAULT_ITERATION_COUNT). - -password_to_scram(Password, IterationCount) -> - Salt = randoms:bytes(?SALT_LENGTH), - SaltedPassword = scram:salted_password(Password, Salt, - IterationCount), - StoredKey = - scram:stored_key(scram:client_key(SaltedPassword)), - ServerKey = scram:server_key(SaltedPassword), - #scram{storedkey = misc:encode_base64(StoredKey), - serverkey = misc:encode_base64(ServerKey), - salt = misc:encode_base64(Salt), - iterationcount = IterationCount}. - -is_password_scram_valid(Password, Scram) -> - case jid:resourceprep(Password) of - error -> - false; - _ -> - IterationCount = Scram#scram.iterationcount, - Salt = misc:decode_base64(Scram#scram.salt), - SaltedPassword = scram:salted_password(Password, Salt, - IterationCount), - StoredKey = - scram:stored_key(scram:client_key(SaltedPassword)), - misc:decode_base64(Scram#scram.storedkey) == StoredKey - end. - export(_Server) -> [{passwd, fun(Host, #passwd{us = {LUser, LServer}, password = Password}) |