summaryrefslogtreecommitdiff
path: root/src/ejabberd_websocket.erl
diff options
context:
space:
mode:
authorPaweł Chmielowski <pchmielowski@process-one.net>2019-04-26 15:29:43 +0200
committerPaweł Chmielowski <pchmielowski@process-one.net>2019-04-26 15:29:43 +0200
commitbcfe50f817b6365b2cada08e05cc8f59f5d00980 (patch)
treefad3fe97e903b433bf5c45a13f2aaa85bd96d75c /src/ejabberd_websocket.erl
parentAllow non-moderator subscribers to get list of room subscribers (diff)
Return "Bad request" error when origin in websocket connection doesn't match
This also allow websocket_origin option to accept multiple values instead of just single one.
Diffstat (limited to 'src/ejabberd_websocket.erl')
-rw-r--r--src/ejabberd_websocket.erl92
1 files changed, 55 insertions, 37 deletions
diff --git a/src/ejabberd_websocket.erl b/src/ejabberd_websocket.erl
index 2b5a0146..e954b42c 100644
--- a/src/ejabberd_websocket.erl
+++ b/src/ejabberd_websocket.erl
@@ -42,7 +42,7 @@
-author('ecestari@process-one.net').
--export([check/2, socket_handoff/5, opt_type/1]).
+-export([socket_handoff/5, opt_type/1]).
-include("logger.hrl").
@@ -62,29 +62,39 @@
?AC_ALLOW_HEADERS, ?AC_MAX_AGE]).
-define(HEADER, [?CT_XML, ?AC_ALLOW_ORIGIN, ?AC_ALLOW_HEADERS]).
-check(_Path, Headers) ->
- HeadersValidators = [{'Upgrade', <<"websocket">>, true},
- {'Connection', ignore, true}, {'Host', ignore, true},
- {<<"Sec-Websocket-Key">>, ignore, true},
- {<<"Sec-Websocket-Version">>, <<"13">>, true},
- {<<"Origin">>, get_origin(), false}],
-
- F = fun ({Tag, Val, Required}) ->
- case lists:keyfind(Tag, 1, Headers) of
- false -> Required; % header not found, keep in list if required
- {_, HVal} ->
- case Val of
- ignore -> false; % ignore value -> ok, remove from list
- _ ->
- % expected value -> ok, remove from list (false)
- % value is different, keep in list (true)
- str:to_lower(HVal) /= Val
- end
- end
- end,
- case lists:filter(F, HeadersValidators) of
- [] -> true;
- _InvalidHeaders -> false
+is_valid_websocket_upgrade(_Path, Headers) ->
+ HeadersToValidate = [{'Upgrade', <<"websocket">>},
+ {'Connection', ignore},
+ {'Host', ignore},
+ {<<"Sec-Websocket-Key">>, ignore},
+ {<<"Sec-Websocket-Version">>, <<"13">>}],
+ Res = lists:all(
+ fun({Tag, Val}) ->
+ case lists:keyfind(Tag, 1, Headers) of
+ false ->
+ false;
+ {_, _} when Val == ignore ->
+ true;
+ {_, HVal} ->
+ str:to_lower(HVal) == Val
+ end
+ end, HeadersToValidate),
+
+ case {Res, lists:keyfind(<<"Origin">>, 1, Headers), get_origin()} of
+ {false, _, _} ->
+ false;
+ {true, _, []} ->
+ true;
+ {true, {_, HVal}, Origins} ->
+ HValLow = str:to_lower(HVal),
+ case lists:any(fun(V) -> V == HValLow end, Origins) of
+ true ->
+ true;
+ _ ->
+ invalid_origin
+ end;
+ {true, false, _} ->
+ true
end.
socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path,
@@ -92,7 +102,7 @@ socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path,
socket = Socket, sockmod = SockMod,
data = Buf, opts = HOpts},
_Opts, HandlerModule, InfoMsgFun) ->
- case check(LocalPath, Headers) of
+ case is_valid_websocket_upgrade(LocalPath, Headers) of
true ->
WS = #ws{socket = Socket,
sockmod = SockMod,
@@ -107,8 +117,11 @@ socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path,
http_opts = HOpts},
connect(WS, HandlerModule);
- _ ->
- {200, ?HEADER, InfoMsgFun()}
+ false ->
+ {200, ?HEADER, InfoMsgFun()};
+ invalid_origin ->
+ {403, ?HEADER, #xmlel{name = <<"h1">>,
+ children = [{xmlcdata, <<"403 Bad Request - Invalid origin">>}]}}
end;
socket_handoff(_, #request{method = 'OPTIONS'}, _, _, _) ->
{200, ?OPTIONS_HEADER, []};
@@ -413,22 +426,27 @@ websocket_close(Socket, WsHandleLoopPid, SocketMode, _CloseCode) ->
SocketMode:close(Socket).
get_origin() ->
- ejabberd_config:get_option(websocket_origin, ignore).
+ ejabberd_config:get_option(websocket_origin, []).
opt_type(websocket_ping_interval) ->
fun (I) when is_integer(I), I >= 0 -> I end;
opt_type(websocket_timeout) ->
fun (I) when is_integer(I), I > 0 -> I end;
opt_type(websocket_origin) ->
- %% Accept only values conforming to RFC6454 section 7.1
- fun (<<"null">>) -> <<"null">>;
- (null) -> <<"null">>;
- (Origin) ->
- URIs = [_|_] = lists:flatmap(
- fun(<<>>) -> [];
- (URI) -> [misc:try_url(URI)]
- end, re:split(Origin, "\\s")),
- str:join(URIs, <<" ">>)
+ fun Verify(V) when is_binary(V) ->
+ Verify([V]);
+ Verify([]) ->
+ [];
+ Verify([<<"null">> | R]) ->
+ [<<"null">> | Verify(R)];
+ Verify([null | R]) ->
+ [<<"null">> | Verify(R)];
+ Verify([V | R]) when is_binary(V) ->
+ URIs = [_|_] = lists:filtermap(
+ fun(<<>>) -> false;
+ (URI) -> {true, misc:try_url(URI)}
+ end, re:split(V, "\\s+")),
+ [str:join(URIs, <<" ">>) | Verify(R)]
end;
opt_type(_) ->
[websocket_ping_interval, websocket_timeout, websocket_origin].