diff options
Diffstat (limited to 'src/ejabberd_auth_internal.erl')
-rw-r--r-- | src/ejabberd_auth_internal.erl | 547 |
1 files changed, 276 insertions, 271 deletions
diff --git a/src/ejabberd_auth_internal.erl b/src/ejabberd_auth_internal.erl index 4b5bcd327..b3587e211 100644 --- a/src/ejabberd_auth_internal.erl +++ b/src/ejabberd_auth_internal.erl @@ -25,32 +25,29 @@ %%%---------------------------------------------------------------------- -module(ejabberd_auth_internal). + -author('alexey@process-one.net'). +-behaviour(ejabberd_auth). + %% External exports --export([start/1, - set_password/3, - check_password/3, - check_password/5, - try_register/3, - dirty_get_registered_users/0, - get_vh_registered_users/1, +-export([start/1, set_password/3, check_password/3, + check_password/5, try_register/3, + dirty_get_registered_users/0, get_vh_registered_users/1, get_vh_registered_users/2, get_vh_registered_users_number/1, - get_vh_registered_users_number/2, - get_password/2, - get_password_s/2, - is_user_exists/2, - remove_user/2, - remove_user/3, - store_type/0, - plain_password_required/0 - ]). + get_vh_registered_users_number/2, get_password/2, + get_password_s/2, is_user_exists/2, remove_user/2, + remove_user/3, store_type/0, export/1, + plain_password_required/0]). -include("ejabberd.hrl"). --record(passwd, {us, password}). --record(reg_users_counter, {vhost, count}). +-record(passwd, {us = {<<"">>, <<"">>} :: {binary(), binary()} | '$1', + password = <<"">> :: binary() | scram() | '_'}). + +-record(reg_users_counter, {vhost = <<"">> :: binary(), + count = 0 :: integer() | '$1'}). -define(SALT_LENGTH, 16). @@ -58,8 +55,9 @@ %%% API %%%---------------------------------------------------------------------- start(Host) -> - mnesia:create_table(passwd, [{disc_copies, [node()]}, - {attributes, record_info(fields, passwd)}]), + mnesia:create_table(passwd, + [{disc_copies, [node()]}, + {attributes, record_info(fields, passwd)}]), mnesia:create_table(reg_users_counter, [{ram_copies, [node()]}, {attributes, record_info(fields, reg_users_counter)}]), @@ -72,22 +70,22 @@ update_reg_users_counter_table(Server) -> Set = get_vh_registered_users(Server), Size = length(Set), LServer = jlib:nameprep(Server), - F = fun() -> - mnesia:write(#reg_users_counter{vhost = LServer, - count = Size}) + F = fun () -> + mnesia:write(#reg_users_counter{vhost = LServer, + count = Size}) end, mnesia:sync_dirty(F). plain_password_required() -> case is_scrammed() of - false -> false; - true -> true + false -> false; + true -> true end. store_type() -> case is_scrammed() of - false -> plain; %% allows: PLAIN DIGEST-MD5 SCRAM - true -> scram %% allows: PLAIN SCRAM + false -> plain; %% allows: PLAIN DIGEST-MD5 SCRAM + true -> scram %% allows: PLAIN SCRAM end. check_password(User, Server, Password) -> @@ -95,46 +93,40 @@ check_password(User, Server, Password) -> LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read({passwd, US}) of - [#passwd{password = Password}] when is_list(Password) -> - Password /= ""; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - is_password_scram_valid(Password, Scram); - _ -> - false + [#passwd{password = Password}] + when is_binary(Password) -> + Password /= <<"">>; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + is_password_scram_valid(Password, Scram); + _ -> false end. -check_password(User, Server, Password, Digest, DigestGen) -> +check_password(User, Server, Password, Digest, + DigestGen) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read({passwd, US}) of - [#passwd{password = Passwd}] when is_list(Passwd) -> - DigRes = if - Digest /= "" -> - Digest == DigestGen(Passwd); - true -> - false - end, - if DigRes -> - true; - true -> - (Passwd == Password) and (Password /= "") - end; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - Passwd = base64:decode(Scram#scram.storedkey), - DigRes = if - Digest /= "" -> - Digest == DigestGen(Passwd); - true -> - false - end, - if DigRes -> - true; - true -> - (Passwd == Password) and (Password /= "") - end; - _ -> - false + [#passwd{password = Passwd}] when is_binary(Passwd) -> + DigRes = if Digest /= <<"">> -> + Digest == DigestGen(Passwd); + true -> false + end, + if DigRes -> true; + true -> (Passwd == Password) and (Password /= <<"">>) + end; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + Passwd = jlib:decode_base64(Scram#scram.storedkey), + DigRes = if Digest /= <<"">> -> + Digest == DigestGen(Passwd); + true -> false + end, + if DigRes -> true; + true -> (Passwd == Password) and (Password /= <<"">>) + end; + _ -> false end. %% @spec (User::string(), Server::string(), Password::string()) -> @@ -143,49 +135,48 @@ set_password(User, Server, Password) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), US = {LUser, LServer}, - if - (LUser == error) or (LServer == error) -> - {error, invalid_jid}; - true -> - F = fun() -> - Password2 = case is_scrammed() and is_list(Password) of - true -> password_to_scram(Password); - false -> Password - end, - mnesia:write(#passwd{us = US, - password = Password2}) - end, - {atomic, ok} = mnesia:transaction(F), - ok + if (LUser == error) or (LServer == error) -> + {error, invalid_jid}; + true -> + F = fun () -> + Password2 = case is_scrammed() and is_binary(Password) + of + true -> password_to_scram(Password); + false -> Password + end, + mnesia:write(#passwd{us = US, password = Password2}) + end, + {atomic, ok} = mnesia:transaction(F), + ok end. %% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid} | {aborted, Reason} -try_register(User, Server, Password) -> +try_register(User, Server, PasswordList) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), + Password = iolist_to_binary(PasswordList), US = {LUser, LServer}, - if - (LUser == error) or (LServer == error) -> - {error, invalid_jid}; - true -> - F = fun() -> - case mnesia:read({passwd, US}) of - [] -> - Password2 = case is_scrammed() and is_list(Password) of - true -> password_to_scram(Password); - false -> Password - end, - mnesia:write(#passwd{us = US, - password = Password2}), - mnesia:dirty_update_counter( - reg_users_counter, - LServer, 1), - ok; - [_E] -> - exists - end - end, - mnesia:transaction(F) + if (LUser == error) or (LServer == error) -> + {error, invalid_jid}; + true -> + F = fun () -> + case mnesia:read({passwd, US}) of + [] -> + Password2 = case is_scrammed() and + is_binary(Password) + of + true -> password_to_scram(Password); + false -> Password + end, + mnesia:write(#passwd{us = US, + password = Password2}), + mnesia:dirty_update_counter(reg_users_counter, + LServer, 1), + ok; + [_E] -> exists + end + end, + mnesia:transaction(F) end. %% Get all registered users in Mnesia @@ -194,75 +185,81 @@ dirty_get_registered_users() -> get_vh_registered_users(Server) -> LServer = jlib:nameprep(Server), - mnesia:dirty_select( - passwd, - [{#passwd{us = '$1', _ = '_'}, - [{'==', {element, 2, '$1'}, LServer}], - ['$1']}]). - -get_vh_registered_users(Server, [{from, Start}, {to, End}]) - when is_integer(Start) and is_integer(End) -> - get_vh_registered_users(Server, [{limit, End-Start+1}, {offset, Start}]); - -get_vh_registered_users(Server, [{limit, Limit}, {offset, Offset}]) - when is_integer(Limit) and is_integer(Offset) -> + mnesia:dirty_select(passwd, + [{#passwd{us = '$1', _ = '_'}, + [{'==', {element, 2, '$1'}, LServer}], ['$1']}]). + +get_vh_registered_users(Server, + [{from, Start}, {to, End}]) + when is_integer(Start) and is_integer(End) -> + get_vh_registered_users(Server, + [{limit, End - Start + 1}, {offset, Start}]); +get_vh_registered_users(Server, + [{limit, Limit}, {offset, Offset}]) + when is_integer(Limit) and is_integer(Offset) -> case get_vh_registered_users(Server) of - [] -> - []; - Users -> - Set = lists:keysort(1, Users), - L = length(Set), - Start = if Offset < 1 -> 1; - Offset > L -> L; - true -> Offset - end, - lists:sublist(Set, Start, Limit) + [] -> []; + Users -> + Set = lists:keysort(1, Users), + L = length(Set), + Start = if Offset < 1 -> 1; + Offset > L -> L; + true -> Offset + end, + lists:sublist(Set, Start, Limit) end; - -get_vh_registered_users(Server, [{prefix, Prefix}]) - when is_list(Prefix) -> - Set = [{U,S} || {U, S} <- get_vh_registered_users(Server), lists:prefix(Prefix, U)], +get_vh_registered_users(Server, [{prefix, Prefix}]) + when is_binary(Prefix) -> + Set = [{U, S} + || {U, S} <- get_vh_registered_users(Server), + str:prefix(Prefix, U)], lists:keysort(1, Set); - -get_vh_registered_users(Server, [{prefix, Prefix}, {from, Start}, {to, End}]) - when is_list(Prefix) and is_integer(Start) and is_integer(End) -> - get_vh_registered_users(Server, [{prefix, Prefix}, {limit, End-Start+1}, {offset, Start}]); - -get_vh_registered_users(Server, [{prefix, Prefix}, {limit, Limit}, {offset, Offset}]) - when is_list(Prefix) and is_integer(Limit) and is_integer(Offset) -> - case [{U,S} || {U, S} <- get_vh_registered_users(Server), lists:prefix(Prefix, U)] of - [] -> - []; - Users -> - Set = lists:keysort(1, Users), - L = length(Set), - Start = if Offset < 1 -> 1; - Offset > L -> L; - true -> Offset - end, - lists:sublist(Set, Start, Limit) +get_vh_registered_users(Server, + [{prefix, Prefix}, {from, Start}, {to, End}]) + when is_binary(Prefix) and is_integer(Start) and + is_integer(End) -> + get_vh_registered_users(Server, + [{prefix, Prefix}, {limit, End - Start + 1}, + {offset, Start}]); +get_vh_registered_users(Server, + [{prefix, Prefix}, {limit, Limit}, {offset, Offset}]) + when is_binary(Prefix) and is_integer(Limit) and + is_integer(Offset) -> + case [{U, S} + || {U, S} <- get_vh_registered_users(Server), + str:prefix(Prefix, U)] + of + [] -> []; + Users -> + Set = lists:keysort(1, Users), + L = length(Set), + Start = if Offset < 1 -> 1; + Offset > L -> L; + true -> Offset + end, + lists:sublist(Set, Start, Limit) end; - get_vh_registered_users(Server, _) -> get_vh_registered_users(Server). get_vh_registered_users_number(Server) -> LServer = jlib:nameprep(Server), - Query = mnesia:dirty_select( - reg_users_counter, - [{#reg_users_counter{vhost = LServer, count = '$1'}, - [], - ['$1']}]), + Query = mnesia:dirty_select(reg_users_counter, + [{#reg_users_counter{vhost = LServer, + count = '$1'}, + [], ['$1']}]), case Query of - [Count] -> - Count; - _ -> 0 + [Count] -> Count; + _ -> 0 end. -get_vh_registered_users_number(Server, [{prefix, Prefix}]) when is_list(Prefix) -> - Set = [{U, S} || {U, S} <- get_vh_registered_users(Server), lists:prefix(Prefix, U)], +get_vh_registered_users_number(Server, + [{prefix, Prefix}]) + when is_binary(Prefix) -> + Set = [{U, S} + || {U, S} <- get_vh_registered_users(Server), + str:prefix(Prefix, U)], length(Set); - get_vh_registered_users_number(Server, _) -> get_vh_registered_users_number(Server). @@ -271,15 +268,16 @@ get_password(User, Server) -> LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read(passwd, US) of - [#passwd{password = Password}] when is_list(Password) -> - Password; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - {base64:decode(Scram#scram.storedkey), - base64:decode(Scram#scram.serverkey), - base64:decode(Scram#scram.salt), - Scram#scram.iterationcount}; - _ -> - false + [#passwd{password = Password}] + when is_binary(Password) -> + Password; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + {jlib:decode_base64(Scram#scram.storedkey), + jlib:decode_base64(Scram#scram.serverkey), + jlib:decode_base64(Scram#scram.salt), + Scram#scram.iterationcount}; + _ -> false end. get_password_s(User, Server) -> @@ -287,12 +285,13 @@ get_password_s(User, Server) -> LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read(passwd, US) of - [#passwd{password = Password}] when is_list(Password) -> - Password; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - []; - _ -> - [] + [#passwd{password = Password}] + when is_binary(Password) -> + Password; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + <<"">>; + _ -> <<"">> end. %% @spec (User, Server) -> true | false | {error, Error} @@ -301,12 +300,9 @@ is_user_exists(User, Server) -> LServer = jlib:nameprep(Server), US = {LUser, LServer}, case catch mnesia:dirty_read({passwd, US}) of - [] -> - false; - [_] -> - true; - Other -> - {error, Other} + [] -> false; + [_] -> true; + Other -> {error, Other} end. %% @spec (User, Server) -> ok @@ -316,13 +312,13 @@ remove_user(User, Server) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), US = {LUser, LServer}, - F = fun() -> + F = fun () -> mnesia:delete({passwd, US}), - mnesia:dirty_update_counter(reg_users_counter, - LServer, -1) - end, + mnesia:dirty_update_counter(reg_users_counter, LServer, + -1) + end, mnesia:transaction(F), - ok. + ok. %% @spec (User, Server, Password) -> ok | not_exists | not_allowed | bad_request %% @doc Remove user if the provided password is correct. @@ -330,79 +326,65 @@ remove_user(User, Server, Password) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), US = {LUser, LServer}, - F = fun() -> + F = fun () -> case mnesia:read({passwd, US}) of - [#passwd{password = Password}] when is_list(Password) -> - mnesia:delete({passwd, US}), - mnesia:dirty_update_counter(reg_users_counter, - LServer, -1), - ok; - [#passwd{password = Scram}] when is_record(Scram, scram) -> - case is_password_scram_valid(Password, Scram) of - true -> - mnesia:delete({passwd, US}), - mnesia:dirty_update_counter(reg_users_counter, - LServer, -1), - ok; - false -> - not_allowed - end; - _ -> - not_exists + [#passwd{password = Password}] + when is_binary(Password) -> + mnesia:delete({passwd, US}), + mnesia:dirty_update_counter(reg_users_counter, LServer, + -1), + ok; + [#passwd{password = Scram}] + when is_record(Scram, scram) -> + case is_password_scram_valid(Password, Scram) of + true -> + mnesia:delete({passwd, US}), + mnesia:dirty_update_counter(reg_users_counter, + LServer, -1), + ok; + false -> not_allowed + end; + _ -> not_exists end - end, + end, case mnesia:transaction(F) of - {atomic, ok} -> - ok; - {atomic, Res} -> - Res; - _ -> - bad_request + {atomic, ok} -> ok; + {atomic, Res} -> Res; + _ -> bad_request end. update_table() -> Fields = record_info(fields, passwd), case mnesia:table_info(passwd, attributes) of - Fields -> - maybe_scram_passwords(), - ok; - [user, password] -> - ?INFO_MSG("Converting passwd table from " - "{user, password} format", []), - Host = ?MYNAME, - {atomic, ok} = mnesia:create_table( - ejabberd_auth_internal_tmp_table, - [{disc_only_copies, [node()]}, - {type, bag}, - {local_content, true}, - {record_name, passwd}, - {attributes, record_info(fields, passwd)}]), - mnesia:transform_table(passwd, ignore, Fields), - F1 = fun() -> - mnesia:write_lock_table(ejabberd_auth_internal_tmp_table), - mnesia:foldl( - fun(#passwd{us = U} = R, _) -> - mnesia:dirty_write( - ejabberd_auth_internal_tmp_table, - R#passwd{us = {U, Host}}) - end, ok, passwd) - end, - mnesia:transaction(F1), - mnesia:clear_table(passwd), - F2 = fun() -> - mnesia:write_lock_table(passwd), - mnesia:foldl( - fun(R, _) -> - mnesia:dirty_write(R) - end, ok, ejabberd_auth_internal_tmp_table) - end, - mnesia:transaction(F2), - mnesia:delete_table(ejabberd_auth_internal_tmp_table); - _ -> - ?INFO_MSG("Recreating passwd table", []), - mnesia:transform_table(passwd, ignore, Fields) + Fields -> + convert_to_binary(Fields), + maybe_scram_passwords(), + ok; + _ -> + ?INFO_MSG("Recreating passwd table", []), + mnesia:transform_table(passwd, ignore, Fields) end. +convert_to_binary(Fields) -> + ejabberd_config:convert_table_to_binary( + passwd, Fields, set, + fun(#passwd{us = {U, _}}) -> U end, + fun(#passwd{us = {U, S}, password = Pass} = R) -> + NewUS = {iolist_to_binary(U), iolist_to_binary(S)}, + NewPass = case Pass of + #scram{storedkey = StoredKey, + serverkey = ServerKey, + salt = Salt} -> + Pass#scram{ + storedkey = iolist_to_binary(StoredKey), + serverkey = iolist_to_binary(ServerKey), + salt = iolist_to_binary(Salt)}; + _ -> + iolist_to_binary(Pass) + end, + R#passwd{us = NewUS, password = NewPass} + end). + %%% %%% SCRAM %%% @@ -411,38 +393,43 @@ update_table() -> %% or if at least the first password is scrammed. is_scrammed() -> OptionScram = is_option_scram(), - FirstElement = mnesia:dirty_read(passwd, mnesia:dirty_first(passwd)), + FirstElement = mnesia:dirty_read(passwd, + mnesia:dirty_first(passwd)), case {OptionScram, FirstElement} of - {true, _} -> - true; - {false, [#passwd{password = Scram}]} when is_record(Scram, scram) -> - true; - _ -> - false + {true, _} -> true; + {false, [#passwd{password = Scram}]} + when is_record(Scram, scram) -> + true; + _ -> false end. is_option_scram() -> - scram == ejabberd_config:get_local_option({auth_password_format, ?MYNAME}). + scram == + ejabberd_config:get_local_option({auth_password_format, ?MYNAME}, + fun(V) -> V end). maybe_alert_password_scrammed_without_option() -> case is_scrammed() andalso not is_option_scram() of - true -> - ?ERROR_MSG("Some passwords were stored in the database as SCRAM, " - "but 'auth_password_format' is not configured 'scram'. " - "The option will now be considered to be 'scram'.", []); - false -> - ok + true -> + ?ERROR_MSG("Some passwords were stored in the database " + "as SCRAM, but 'auth_password_format' " + "is not configured 'scram'. The option " + "will now be considered to be 'scram'.", + []); + false -> ok end. maybe_scram_passwords() -> case is_scrammed() of - true -> scram_passwords(); - false -> ok + true -> scram_passwords(); + false -> ok end. scram_passwords() -> - ?INFO_MSG("Converting the stored passwords into SCRAM bits", []), - Fun = fun(#passwd{password = Password} = P) -> + ?INFO_MSG("Converting the stored passwords into " + "SCRAM bits", + []), + Fun = fun (#passwd{password = Password} = P) -> Scram = password_to_scram(Password), P#passwd{password = Scram} end, @@ -450,21 +437,39 @@ scram_passwords() -> mnesia:transform_table(passwd, Fun, Fields). password_to_scram(Password) -> - password_to_scram(Password, ?SCRAM_DEFAULT_ITERATION_COUNT). + password_to_scram(Password, + ?SCRAM_DEFAULT_ITERATION_COUNT). password_to_scram(Password, IterationCount) -> Salt = crypto:rand_bytes(?SALT_LENGTH), - SaltedPassword = scram:salted_password(Password, Salt, IterationCount), - StoredKey = scram:stored_key(scram:client_key(SaltedPassword)), + SaltedPassword = scram:salted_password(Password, Salt, + IterationCount), + StoredKey = + scram:stored_key(scram:client_key(SaltedPassword)), ServerKey = scram:server_key(SaltedPassword), - #scram{storedkey = base64:encode(StoredKey), - serverkey = base64:encode(ServerKey), - salt = base64:encode(Salt), + #scram{storedkey = jlib:encode_base64(StoredKey), + serverkey = jlib:encode_base64(ServerKey), + salt = jlib:encode_base64(Salt), iterationcount = IterationCount}. is_password_scram_valid(Password, Scram) -> IterationCount = Scram#scram.iterationcount, - Salt = base64:decode(Scram#scram.salt), - SaltedPassword = scram:salted_password(Password, Salt, IterationCount), - StoredKey = scram:stored_key(scram:client_key(SaltedPassword)), - (base64:decode(Scram#scram.storedkey) == StoredKey). + Salt = jlib:decode_base64(Scram#scram.salt), + SaltedPassword = scram:salted_password(Password, Salt, + IterationCount), + StoredKey = + scram:stored_key(scram:client_key(SaltedPassword)), + jlib:decode_base64(Scram#scram.storedkey) == StoredKey. + +export(_Server) -> + [{passwd, + fun(Host, #passwd{us = {LUser, LServer}, password = Password}) + when LServer == Host -> + Username = ejabberd_odbc:escape(LUser), + Pass = ejabberd_odbc:escape(Password), + [[<<"delete from users where username='">>, Username, <<"';">>], + [<<"insert into users(username, password) " + "values ('">>, Username, <<"', '">>, Pass, <<"');">>]]; + (_Host, _R) -> + [] + end}]. |