diff options
Diffstat (limited to 'src/cyrsasl_scram.erl')
-rw-r--r-- | src/cyrsasl_scram.erl | 84 |
1 files changed, 51 insertions, 33 deletions
diff --git a/src/cyrsasl_scram.erl b/src/cyrsasl_scram.erl index 1e2a5c681..55e06fd25 100644 --- a/src/cyrsasl_scram.erl +++ b/src/cyrsasl_scram.erl @@ -29,7 +29,7 @@ -protocol({rfc, 5802}). --export([start/1, stop/0, mech_new/4, mech_step/2]). +-export([start/1, stop/0, mech_new/4, mech_step/2, format_error/1]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -41,6 +41,7 @@ stored_key = <<"">> :: binary(), server_key = <<"">> :: binary(), username = <<"">> :: binary(), + auth_module :: module(), get_password :: fun(), check_password :: fun(), auth_message = <<"">> :: binary(), @@ -48,15 +49,39 @@ server_nonce = <<"">> :: binary()}). -define(SALT_LENGTH, 16). - -define(NONCE_LENGTH, 16). +-type error_reason() :: unsupported_extension | bad_username | + not_authorized | saslprep_failed | + parser_failed | bad_attribute | + nonce_mismatch | bad_channel_binding. + +-export_type([error_reason/0]). + start(_Opts) -> cyrsasl:register_mechanism(<<"SCRAM-SHA-1">>, ?MODULE, scram). stop() -> ok. +-spec format_error(error_reason()) -> {atom(), binary()}. +format_error(unsupported_extension) -> + {'bad-protocol', <<"Unsupported extension">>}; +format_error(bad_username) -> + {'invalid-authzid', <<"Malformed username">>}; +format_error(not_authorized) -> + {'not-authorized', <<"Invalid username or password">>}; +format_error(saslprep_failed) -> + {'not-authorized', <<"SASLprep failed">>}; +format_error(parser_failed) -> + {'bad-protocol', <<"Response decoding failed">>}; +format_error(bad_attribute) -> + {'bad-protocol', <<"Malformed or unexpected attribute">>}; +format_error(nonce_mismatch) -> + {'bad-protocol', <<"Nonce mismatch">>}; +format_error(bad_channel_binding) -> + {'bad-protocol', <<"Invalid channel binding">>}. + mech_new(_Host, GetPassword, _CheckPassword, _CheckPasswordDigest) -> {ok, #state{step = 2, get_password = GetPassword}}. @@ -64,22 +89,22 @@ mech_new(_Host, GetPassword, _CheckPassword, mech_step(#state{step = 2} = State, ClientIn) -> case re:split(ClientIn, <<",">>, [{return, binary}]) of [_CBind, _AuthorizationIdentity, _UserNameAttribute, _ClientNonceAttribute, ExtensionAttribute | _] - when ExtensionAttribute /= [] -> - {error, 'protocol-error-extension-not-supported'}; + when ExtensionAttribute /= <<"">> -> + {error, unsupported_extension}; [CBind, _AuthorizationIdentity, UserNameAttribute, ClientNonceAttribute | _] when (CBind == <<"y">>) or (CBind == <<"n">>) -> case parse_attribute(UserNameAttribute) of {error, Reason} -> {error, Reason}; {_, EscapedUserName} -> case unescape_username(EscapedUserName) of - error -> {error, 'protocol-error-bad-username'}; + error -> {error, bad_username}; UserName -> case parse_attribute(ClientNonceAttribute) of {$r, ClientNonce} -> - {Ret, _AuthModule} = (State#state.get_password)(UserName), + {Ret, AuthModule} = (State#state.get_password)(UserName), case {Ret, jid:resourceprep(Ret)} of - {false, _} -> {error, 'not-authorized', UserName}; - {_, error} when is_binary(Ret) -> ?WARNING_MSG("invalid plain password", []), {error, 'not-authorized', UserName}; + {false, _} -> {error, not_authorized, UserName}; + {_, error} when is_binary(Ret) -> {error, saslprep_failed, UserName}; {Ret, _} -> {StoredKey, ServerKey, Salt, IterationCount} = if is_tuple(Ret) -> Ret; @@ -112,6 +137,7 @@ mech_step(#state{step = 2} = State, ClientIn) -> {continue, ServerFirstMessage, State#state{step = 4, stored_key = StoredKey, server_key = ServerKey, + auth_module = AuthModule, auth_message = <<ClientFirstMessageBare/binary, ",", ServerFirstMessage/binary>>, @@ -119,11 +145,11 @@ mech_step(#state{step = 2} = State, ClientIn) -> server_nonce = ServerNonce, username = UserName}} end; - _Else -> {error, 'not-supported'} + _ -> {error, bad_attribute} end end end; - _Else -> {error, 'bad-protocol'} + _Else -> {error, parser_failed} end; mech_step(#state{step = 4} = State, ClientIn) -> case str:tokens(ClientIn, <<",">>) of @@ -158,39 +184,31 @@ mech_step(#state{step = 4} = State, ClientIn) -> scram:server_signature(State#state.server_key, AuthMessage), {ok, [{username, State#state.username}, + {auth_module, State#state.auth_module}, {authzid, State#state.username}], <<"v=", (jlib:encode_base64(ServerSignature))/binary>>}; - true -> {error, 'bad-auth', State#state.username} + true -> {error, not_authorized, State#state.username} end; - _Else -> {error, 'bad-protocol'} + _ -> {error, bad_attribute} end; - {$r, _} -> {error, 'bad-nonce'}; - _Else -> {error, 'bad-protocol'} + {$r, _} -> {error, nonce_mismatch}; + _ -> {error, bad_attribute} end; - true -> {error, 'bad-channel-binding'} + true -> {error, bad_channel_binding} end; - _Else -> {error, 'bad-protocol'} + _ -> {error, bad_attribute} end; - _Else -> {error, 'bad-protocol'} + _ -> {error, parser_failed} end. -parse_attribute(Attribute) -> - AttributeLen = byte_size(Attribute), - if AttributeLen >= 3 -> - AttributeS = binary_to_list(Attribute), - SecondChar = lists:nth(2, AttributeS), - case is_alpha(lists:nth(1, AttributeS)) of - true -> - if SecondChar == $= -> - String = str:substr(Attribute, 3), - {lists:nth(1, AttributeS), String}; - true -> {error, 'bad-format-second-char-not-equal-sign'} - end; - _Else -> {error, 'bad-format-first-char-not-a-letter'} - end; - true -> {error, 'bad-format-attribute-too-short'} - end. +parse_attribute(<<Name, $=, Val/binary>>) when Val /= <<>> -> + case is_alpha(Name) of + true -> {Name, Val}; + false -> {error, bad_attribute} + end; +parse_attribute(_) -> + {error, bad_attribute}. unescape_username(<<"">>) -> <<"">>; unescape_username(EscapedUsername) -> |