%%%---------------------------------------------------------------------- %%% File : cyrsasl_scram.erl %%% Author : Stephen Röttger <stephen.roettger@googlemail.com> %%% Purpose : SASL SCRAM authentication %%% Created : 7 Aug 2011 by Stephen Röttger <stephen.roettger@googlemail.com> %%% %%% %%% ejabberd, Copyright (C) 2002-2017 ProcessOne %%% %%% This program is free software; you can redistribute it and/or %%% modify it under the terms of the GNU General Public License as %%% published by the Free Software Foundation; either version 2 of the %%% License, or (at your option) any later version. %%% %%% This program is distributed in the hope that it will be useful, %%% but WITHOUT ANY WARRANTY; without even the implied warranty of %%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU %%% General Public License for more details. %%% %%% You should have received a copy of the GNU General Public License along %%% with this program; if not, write to the Free Software Foundation, Inc., %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. %%% %%%---------------------------------------------------------------------- -module(cyrsasl_scram). -author('stephen.roettger@googlemail.com'). -protocol({rfc, 5802}). -export([start/1, stop/0, mech_new/4, mech_step/2, format_error/1]). -include("ejabberd.hrl"). -include("logger.hrl"). -behaviour(cyrsasl). -record(state, {step = 2 :: 2 | 4, stored_key = <<"">> :: binary(), server_key = <<"">> :: binary(), username = <<"">> :: binary(), auth_module :: module(), get_password :: fun((binary()) -> {false | ejabberd_auth:password(), module()}), auth_message = <<"">> :: binary(), client_nonce = <<"">> :: binary(), server_nonce = <<"">> :: binary()}). -define(SALT_LENGTH, 16). -define(NONCE_LENGTH, 16). -type error_reason() :: unsupported_extension | bad_username | not_authorized | saslprep_failed | parser_failed | bad_attribute | nonce_mismatch | bad_channel_binding. -export_type([error_reason/0]). start(_Opts) -> cyrsasl:register_mechanism(<<"SCRAM-SHA-1">>, ?MODULE, scram). stop() -> ok. -spec format_error(error_reason()) -> {atom(), binary()}. format_error(unsupported_extension) -> {'bad-protocol', <<"Unsupported extension">>}; format_error(bad_username) -> {'invalid-authzid', <<"Malformed username">>}; format_error(not_authorized) -> {'not-authorized', <<"Invalid username or password">>}; format_error(saslprep_failed) -> {'not-authorized', <<"SASLprep failed">>}; format_error(parser_failed) -> {'bad-protocol', <<"Response decoding failed">>}; format_error(bad_attribute) -> {'bad-protocol', <<"Malformed or unexpected attribute">>}; format_error(nonce_mismatch) -> {'bad-protocol', <<"Nonce mismatch">>}; format_error(bad_channel_binding) -> {'bad-protocol', <<"Invalid channel binding">>}. mech_new(_Host, GetPassword, _CheckPassword, _CheckPasswordDigest) -> {ok, #state{step = 2, get_password = GetPassword}}. mech_step(#state{step = 2} = State, ClientIn) -> case re:split(ClientIn, <<",">>, [{return, binary}]) of [_CBind, _AuthorizationIdentity, _UserNameAttribute, _ClientNonceAttribute, ExtensionAttribute | _] when ExtensionAttribute /= <<"">> -> {error, unsupported_extension}; [CBind, _AuthorizationIdentity, UserNameAttribute, ClientNonceAttribute | _] when (CBind == <<"y">>) or (CBind == <<"n">>) -> case parse_attribute(UserNameAttribute) of {error, Reason} -> {error, Reason}; {_, EscapedUserName} -> case unescape_username(EscapedUserName) of error -> {error, bad_username}; UserName -> case parse_attribute(ClientNonceAttribute) of {$r, ClientNonce} -> {Pass, AuthModule} = (State#state.get_password)(UserName), LPass = if is_binary(Pass) -> jid:resourceprep(Pass); true -> Pass end, if Pass == false -> {error, not_authorized, UserName}; LPass == error -> {error, saslprep_failed, UserName}; true -> {StoredKey, ServerKey, Salt, IterationCount} = if is_record(Pass, scram) -> {base64:decode(Pass#scram.storedkey), base64:decode(Pass#scram.serverkey), base64:decode(Pass#scram.salt), Pass#scram.iterationcount}; true -> TempSalt = randoms:bytes(?SALT_LENGTH), SaltedPassword = scram:salted_password(Pass, TempSalt, ?SCRAM_DEFAULT_ITERATION_COUNT), {scram:stored_key(scram:client_key(SaltedPassword)), scram:server_key(SaltedPassword), TempSalt, ?SCRAM_DEFAULT_ITERATION_COUNT} end, ClientFirstMessageBare = str:substr(ClientIn, str:str(ClientIn, <<"n=">>)), ServerNonce = base64:encode(randoms:bytes(?NONCE_LENGTH)), ServerFirstMessage = iolist_to_binary( ["r=", ClientNonce, ServerNonce, ",", "s=", base64:encode(Salt), ",", "i=", integer_to_list(IterationCount)]), {continue, ServerFirstMessage, State#state{step = 4, stored_key = StoredKey, server_key = ServerKey, auth_module = AuthModule, auth_message = <<ClientFirstMessageBare/binary, ",", ServerFirstMessage/binary>>, client_nonce = ClientNonce, server_nonce = ServerNonce, username = UserName}} end; _ -> {error, bad_attribute} end end end; _Else -> {error, parser_failed} end; mech_step(#state{step = 4} = State, ClientIn) -> case str:tokens(ClientIn, <<",">>) of [GS2ChannelBindingAttribute, NonceAttribute, ClientProofAttribute] -> case parse_attribute(GS2ChannelBindingAttribute) of {$c, CVal} -> ChannelBindingSupport = try binary:first(base64:decode(CVal)) catch _:badarg -> 0 end, if (ChannelBindingSupport == $n) or (ChannelBindingSupport == $y) -> Nonce = <<(State#state.client_nonce)/binary, (State#state.server_nonce)/binary>>, case parse_attribute(NonceAttribute) of {$r, CompareNonce} when CompareNonce == Nonce -> case parse_attribute(ClientProofAttribute) of {$p, ClientProofB64} -> ClientProof = try base64:decode(ClientProofB64) catch _:badarg -> <<>> end, AuthMessage = iolist_to_binary( [State#state.auth_message, ",", str:substr(ClientIn, 1, str:str(ClientIn, <<",p=">>) - 1)]), ClientSignature = scram:client_signature(State#state.stored_key, AuthMessage), ClientKey = scram:client_key(ClientProof, ClientSignature), CompareStoredKey = scram:stored_key(ClientKey), if CompareStoredKey == State#state.stored_key -> ServerSignature = scram:server_signature(State#state.server_key, AuthMessage), {ok, [{username, State#state.username}, {auth_module, State#state.auth_module}, {authzid, State#state.username}], <<"v=", (base64:encode(ServerSignature))/binary>>}; true -> {error, not_authorized, State#state.username} end; _ -> {error, bad_attribute} end; {$r, _} -> {error, nonce_mismatch}; _ -> {error, bad_attribute} end; true -> {error, bad_channel_binding} end; _ -> {error, bad_attribute} end; _ -> {error, parser_failed} end. parse_attribute(<<Name, $=, Val/binary>>) when Val /= <<>> -> case is_alpha(Name) of true -> {Name, Val}; false -> {error, bad_attribute} end; parse_attribute(_) -> {error, bad_attribute}. unescape_username(<<"">>) -> <<"">>; unescape_username(EscapedUsername) -> Pos = str:str(EscapedUsername, <<"=">>), if Pos == 0 -> EscapedUsername; true -> Start = str:substr(EscapedUsername, 1, Pos - 1), End = str:substr(EscapedUsername, Pos), EndLen = byte_size(End), if EndLen < 3 -> error; true -> case str:substr(End, 1, 3) of <<"=2C">> -> <<Start/binary, ",", (unescape_username(str:substr(End, 4)))/binary>>; <<"=3D">> -> <<Start/binary, "=", (unescape_username(str:substr(End, 4)))/binary>>; _Else -> error end end end. is_alpha(Char) when Char >= $a, Char =< $z -> true; is_alpha(Char) when Char >= $A, Char =< $Z -> true; is_alpha(_) -> false.