aboutsummaryrefslogtreecommitdiff
path: root/src/cyrsasl_scram.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/cyrsasl_scram.erl')
-rw-r--r--src/cyrsasl_scram.erl297
1 files changed, 158 insertions, 139 deletions
diff --git a/src/cyrsasl_scram.erl b/src/cyrsasl_scram.erl
index dc671b243..33d18cd1a 100644
--- a/src/cyrsasl_scram.erl
+++ b/src/cyrsasl_scram.erl
@@ -25,166 +25,185 @@
%%%----------------------------------------------------------------------
-module(cyrsasl_scram).
+
-author('stephen.roettger@googlemail.com').
--export([start/1,
- stop/0,
- mech_new/4,
- mech_step/2]).
+-export([start/1, stop/0, mech_new/4, mech_step/2]).
-include("ejabberd.hrl").
-behaviour(cyrsasl).
--record(state, {step, stored_key, server_key, username, get_password, check_password,
- auth_message, client_nonce, server_nonce}).
+-record(state,
+ {step = 2 :: 2 | 4,
+ stored_key = <<"">> :: binary(),
+ server_key = <<"">> :: binary(),
+ username = <<"">> :: binary(),
+ get_password :: fun(),
+ check_password :: fun(),
+ auth_message = <<"">> :: binary(),
+ client_nonce = <<"">> :: binary(),
+ server_nonce = <<"">> :: binary()}).
-define(SALT_LENGTH, 16).
+
-define(NONCE_LENGTH, 16).
start(_Opts) ->
- cyrsasl:register_mechanism("SCRAM-SHA-1", ?MODULE, scram).
+ cyrsasl:register_mechanism(<<"SCRAM-SHA-1">>, ?MODULE,
+ scram).
-stop() ->
- ok.
+stop() -> ok.
-mech_new(_Host, GetPassword, _CheckPassword, _CheckPasswordDigest) ->
+mech_new(_Host, GetPassword, _CheckPassword,
+ _CheckPasswordDigest) ->
{ok, #state{step = 2, get_password = GetPassword}}.
mech_step(#state{step = 2} = State, ClientIn) ->
- case string:tokens(ClientIn, ",") of
- [CBind, UserNameAttribute, ClientNonceAttribute] when (CBind == "y") or (CBind == "n") ->
- case parse_attribute(UserNameAttribute) of
- {error, Reason} ->
- {error, Reason};
- {_, EscapedUserName} ->
- case unescape_username(EscapedUserName) of
- error ->
- {error, "protocol-error-bad-username"};
- UserName ->
- case parse_attribute(ClientNonceAttribute) of
- {$r, ClientNonce} ->
- case (State#state.get_password)(UserName) of
- {false, _} ->
- {error, "not-authorized", UserName};
- {Ret, _AuthModule} ->
- {StoredKey, ServerKey, Salt, IterationCount} = if
- is_tuple(Ret) ->
- Ret;
- true ->
- TempSalt = crypto:rand_bytes(?SALT_LENGTH),
- SaltedPassword = scram:salted_password(Ret, TempSalt, ?SCRAM_DEFAULT_ITERATION_COUNT),
- {scram:stored_key(scram:client_key(SaltedPassword)),
- scram:server_key(SaltedPassword), TempSalt, ?SCRAM_DEFAULT_ITERATION_COUNT}
- end,
- ClientFirstMessageBare = string:substr(ClientIn, string:str(ClientIn, "n=")),
- ServerNonce = base64:encode_to_string(crypto:rand_bytes(?NONCE_LENGTH)),
- ServerFirstMessage = "r=" ++ ClientNonce ++ ServerNonce ++ "," ++
- "s=" ++ base64:encode_to_string(Salt) ++ "," ++
- "i=" ++ integer_to_list(IterationCount),
- {continue,
- ServerFirstMessage,
- State#state{step = 4, stored_key = StoredKey, server_key = ServerKey,
- auth_message = ClientFirstMessageBare ++ "," ++ ServerFirstMessage,
- client_nonce = ClientNonce, server_nonce = ServerNonce, username = UserName}}
- end;
- _Else ->
- {error, "not-supported"}
- end
- end
- end;
- _Else ->
- {error, "bad-protocol"}
- end;
+ case str:tokens(ClientIn, <<",">>) of
+ [CBind, UserNameAttribute, ClientNonceAttribute]
+ when (CBind == <<"y">>) or (CBind == <<"n">>) ->
+ case parse_attribute(UserNameAttribute) of
+ {error, Reason} -> {error, Reason};
+ {_, EscapedUserName} ->
+ case unescape_username(EscapedUserName) of
+ error -> {error, <<"protocol-error-bad-username">>};
+ UserName ->
+ case parse_attribute(ClientNonceAttribute) of
+ {$r, ClientNonce} ->
+ case (State#state.get_password)(UserName) of
+ {false, _} -> {error, <<"not-authorized">>, UserName};
+ {Ret, _AuthModule} ->
+ {StoredKey, ServerKey, Salt, IterationCount} =
+ if is_tuple(Ret) -> Ret;
+ true ->
+ TempSalt =
+ crypto:rand_bytes(?SALT_LENGTH),
+ SaltedPassword =
+ scram:salted_password(Ret,
+ TempSalt,
+ ?SCRAM_DEFAULT_ITERATION_COUNT),
+ {scram:stored_key(scram:client_key(SaltedPassword)),
+ scram:server_key(SaltedPassword),
+ TempSalt,
+ ?SCRAM_DEFAULT_ITERATION_COUNT}
+ end,
+ ClientFirstMessageBare =
+ str:substr(ClientIn,
+ str:str(ClientIn, <<"n=">>)),
+ ServerNonce =
+ jlib:encode_base64(crypto:rand_bytes(?NONCE_LENGTH)),
+ ServerFirstMessage =
+ iolist_to_binary(
+ ["r=",
+ ClientNonce,
+ ServerNonce,
+ ",", "s=",
+ jlib:encode_base64(Salt),
+ ",", "i=",
+ integer_to_list(IterationCount)]),
+ {continue, ServerFirstMessage,
+ State#state{step = 4, stored_key = StoredKey,
+ server_key = ServerKey,
+ auth_message =
+ <<ClientFirstMessageBare/binary,
+ ",", ServerFirstMessage/binary>>,
+ client_nonce = ClientNonce,
+ server_nonce = ServerNonce,
+ username = UserName}}
+ end;
+ _Else -> {error, <<"not-supported">>}
+ end
+ end
+ end;
+ _Else -> {error, <<"bad-protocol">>}
+ end;
mech_step(#state{step = 4} = State, ClientIn) ->
- case string:tokens(ClientIn, ",") of
- [GS2ChannelBindingAttribute, NonceAttribute, ClientProofAttribute] ->
- case parse_attribute(GS2ChannelBindingAttribute) of
- {$c, CVal} when (CVal == "biws") or (CVal == "eSws") ->
- %% biws is base64 for n,, => channelbinding not supported
- %% eSws is base64 for y,, => channelbinding supported by client only
- Nonce = State#state.client_nonce ++ State#state.server_nonce,
- case parse_attribute(NonceAttribute) of
- {$r, CompareNonce} when CompareNonce == Nonce ->
- case parse_attribute(ClientProofAttribute) of
- {$p, ClientProofB64} ->
- ClientProof = base64:decode(ClientProofB64),
- AuthMessage = State#state.auth_message ++ "," ++ string:substr(ClientIn, 1, string:str(ClientIn, ",p=")-1),
- ClientSignature = scram:client_signature(State#state.stored_key, AuthMessage),
- ClientKey = scram:client_key(ClientProof, ClientSignature),
- CompareStoredKey = scram:stored_key(ClientKey),
- if CompareStoredKey == State#state.stored_key ->
- ServerSignature = scram:server_signature(State#state.server_key, AuthMessage),
- {ok, [{username, State#state.username}], "v=" ++ base64:encode_to_string(ServerSignature)};
- true ->
- {error, "bad-auth"}
- end;
- _Else ->
- {error, "bad-protocol"}
- end;
- {$r, _} ->
- {error, "bad-nonce"};
- _Else ->
- {error, "bad-protocol"}
- end;
- _Else ->
- {error, "bad-protocol"}
+ case str:tokens(ClientIn, <<",">>) of
+ [GS2ChannelBindingAttribute, NonceAttribute,
+ ClientProofAttribute] ->
+ case parse_attribute(GS2ChannelBindingAttribute) of
+ {$c, CVal} when (CVal == <<"biws">>) or (CVal == <<"eSws">>) ->
+ %% biws is base64 for n,, => channelbinding not supported
+ %% eSws is base64 for y,, => channelbinding supported by client only
+ Nonce = <<(State#state.client_nonce)/binary,
+ (State#state.server_nonce)/binary>>,
+ case parse_attribute(NonceAttribute) of
+ {$r, CompareNonce} when CompareNonce == Nonce ->
+ case parse_attribute(ClientProofAttribute) of
+ {$p, ClientProofB64} ->
+ ClientProof = jlib:decode_base64(ClientProofB64),
+ AuthMessage =
+ iolist_to_binary(
+ [State#state.auth_message,
+ ",",
+ str:substr(ClientIn, 1,
+ str:str(ClientIn, <<",p=">>)
+ - 1)]),
+ ClientSignature =
+ scram:client_signature(State#state.stored_key,
+ AuthMessage),
+ ClientKey = scram:client_key(ClientProof,
+ ClientSignature),
+ CompareStoredKey = scram:stored_key(ClientKey),
+ if CompareStoredKey == State#state.stored_key ->
+ ServerSignature =
+ scram:server_signature(State#state.server_key,
+ AuthMessage),
+ {ok, [{username, State#state.username}],
+ <<"v=",
+ (jlib:encode_base64(ServerSignature))/binary>>};
+ true -> {error, <<"bad-auth">>}
+ end;
+ _Else -> {error, <<"bad-protocol">>}
+ end;
+ {$r, _} -> {error, <<"bad-nonce">>};
+ _Else -> {error, <<"bad-protocol">>}
end;
- _Else ->
- {error, "bad-protocol"}
- end.
+ _Else -> {error, <<"bad-protocol">>}
+ end;
+ _Else -> {error, <<"bad-protocol">>}
+ end.
parse_attribute(Attribute) ->
- AttributeLen = string:len(Attribute),
- if
- AttributeLen >= 3 ->
- SecondChar = lists:nth(2, Attribute),
- case is_alpha(lists:nth(1, Attribute)) of
- true ->
- if
- SecondChar == $= ->
- String = string:substr(Attribute, 3),
- {lists:nth(1, Attribute), String};
- true ->
- {error, "bad-format second char not equal sign"}
- end;
- _Else ->
- {error, "bad-format first char not a letter"}
- end;
- true ->
- {error, "bad-format attribute too short"}
- end.
+ AttributeLen = byte_size(Attribute),
+ if AttributeLen >= 3 ->
+ AttributeS = binary_to_list(Attribute),
+ SecondChar = lists:nth(2, AttributeS),
+ case is_alpha(lists:nth(1, AttributeS)) of
+ true ->
+ if SecondChar == $= ->
+ String = str:substr(Attribute, 3),
+ {lists:nth(1, AttributeS), String};
+ true -> {error, <<"bad-format second char not equal sign">>}
+ end;
+ _Else -> {error, <<"bad-format first char not a letter">>}
+ end;
+ true -> {error, <<"bad-format attribute too short">>}
+ end.
-unescape_username("") ->
- "";
+unescape_username(<<"">>) -> <<"">>;
unescape_username(EscapedUsername) ->
- Pos = string:str(EscapedUsername, "="),
- if
- Pos == 0 ->
- EscapedUsername;
- true ->
- Start = string:substr(EscapedUsername, 1, Pos-1),
- End = string:substr(EscapedUsername, Pos),
- EndLen = string:len(End),
- if
- EndLen < 3 ->
- error;
- true ->
- case string:substr(End, 1, 3) of
- "=2C" ->
- Start ++ "," ++ unescape_username(string:substr(End, 4));
- "=3D" ->
- Start ++ "=" ++ unescape_username(string:substr(End, 4));
- _Else ->
- error
- end
- end
- end.
-
-is_alpha(Char) when Char >= $a, Char =< $z ->
- true;
-is_alpha(Char) when Char >= $A, Char =< $Z ->
- true;
-is_alpha(_) ->
- false.
+ Pos = str:str(EscapedUsername, <<"=">>),
+ if Pos == 0 -> EscapedUsername;
+ true ->
+ Start = str:substr(EscapedUsername, 1, Pos - 1),
+ End = str:substr(EscapedUsername, Pos),
+ EndLen = byte_size(End),
+ if EndLen < 3 -> error;
+ true ->
+ case str:substr(End, 1, 3) of
+ <<"=2C">> ->
+ <<Start/binary, ",",
+ (unescape_username(str:substr(End, 4)))/binary>>;
+ <<"=3D">> ->
+ <<Start/binary, "=",
+ (unescape_username(str:substr(End, 4)))/binary>>;
+ _Else -> error
+ end
+ end
+ end.
+is_alpha(Char) when Char >= $a, Char =< $z -> true;
+is_alpha(Char) when Char >= $A, Char =< $Z -> true;
+is_alpha(_) -> false.