aboutsummaryrefslogtreecommitdiff
path: root/src/ejabberd_websocket.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/ejabberd_websocket.erl')
-rw-r--r--src/ejabberd_websocket.erl171
1 files changed, 113 insertions, 58 deletions
diff --git a/src/ejabberd_websocket.erl b/src/ejabberd_websocket.erl
index 0cdd9bac5..b77e39820 100644
--- a/src/ejabberd_websocket.erl
+++ b/src/ejabberd_websocket.erl
@@ -33,21 +33,19 @@
%%% NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
%%% POSSIBILITY OF SUCH DAMAGE.
%%% ==========================================================================================================
-%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
+%%% ejabberd, Copyright (C) 2002-2019 ProcessOne
%%%----------------------------------------------------------------------
-module(ejabberd_websocket).
-
-protocol({rfc, 6455}).
-author('ecestari@process-one.net').
--export([check/2, socket_handoff/8]).
+-export([socket_handoff/5]).
--include("ejabberd.hrl").
-include("logger.hrl").
--include("jlib.hrl").
+-include("xmpp.hrl").
-include("ejabberd_http.hrl").
@@ -63,35 +61,47 @@
?AC_ALLOW_HEADERS, ?AC_MAX_AGE]).
-define(HEADER, [?CT_XML, ?AC_ALLOW_ORIGIN, ?AC_ALLOW_HEADERS]).
-check(_Path, Headers) ->
- RequiredHeaders = [{'Upgrade', <<"websocket">>},
- {'Connection', ignore}, {'Host', ignore},
- {<<"Sec-Websocket-Key">>, ignore},
- {<<"Sec-Websocket-Version">>, <<"13">>}],
-
- F = fun ({Tag, Val}) ->
- case lists:keyfind(Tag, 1, Headers) of
- false -> true; % header not found, keep in list
- {_, HVal} ->
- case Val of
- ignore -> false; % ignore value -> ok, remove from list
- _ ->
- % expected value -> ok, remove from list (false)
- % value is different, keep in list (true)
- str:to_lower(HVal) /= Val
- end
- end
- end,
- case lists:filter(F, RequiredHeaders) of
- [] -> true;
- _MissingHeaders -> false
+is_valid_websocket_upgrade(_Path, Headers) ->
+ HeadersToValidate = [{'Upgrade', <<"websocket">>},
+ {'Connection', ignore},
+ {'Host', ignore},
+ {<<"Sec-Websocket-Key">>, ignore},
+ {<<"Sec-Websocket-Version">>, <<"13">>}],
+ Res = lists:all(
+ fun({Tag, Val}) ->
+ case lists:keyfind(Tag, 1, Headers) of
+ false ->
+ false;
+ {_, _} when Val == ignore ->
+ true;
+ {_, HVal} ->
+ str:to_lower(HVal) == Val
+ end
+ end, HeadersToValidate),
+
+ case {Res, lists:keyfind(<<"Origin">>, 1, Headers), get_origin()} of
+ {false, _, _} ->
+ false;
+ {true, _, []} ->
+ true;
+ {true, {_, HVal}, Origins} ->
+ HValLow = str:to_lower(HVal),
+ case lists:any(fun(V) -> V == HValLow end, Origins) of
+ true ->
+ true;
+ _ ->
+ invalid_origin
+ end;
+ {true, false, _} ->
+ true
end.
socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path,
headers = Headers, host = Host, port = Port,
- opts = HOpts},
- Socket, SockMod, Buf, _Opts, HandlerModule, InfoMsgFun) ->
- case check(LocalPath, Headers) of
+ socket = Socket, sockmod = SockMod,
+ data = Buf, opts = HOpts},
+ _Opts, HandlerModule, InfoMsgFun) ->
+ case is_valid_websocket_upgrade(LocalPath, Headers) of
true ->
WS = #ws{socket = Socket,
sockmod = SockMod,
@@ -106,14 +116,17 @@ socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path,
http_opts = HOpts},
connect(WS, HandlerModule);
- _ ->
- {200, ?HEADER, InfoMsgFun()}
+ false ->
+ {200, ?HEADER, InfoMsgFun()};
+ invalid_origin ->
+ {403, ?HEADER, #xmlel{name = <<"h1">>,
+ children = [{xmlcdata, <<"403 Bad Request - Invalid origin">>}]}}
end;
-socket_handoff(_, #request{method = 'OPTIONS'}, _, _, _, _, _, _) ->
+socket_handoff(_, #request{method = 'OPTIONS'}, _, _, _) ->
{200, ?OPTIONS_HEADER, []};
-socket_handoff(_, #request{method = 'HEAD'}, _, _, _, _, _, _) ->
+socket_handoff(_, #request{method = 'HEAD'}, _, _, _) ->
{200, ?HEADER, []};
-socket_handoff(_, _, _, _, _, _, _, _) ->
+socket_handoff(_, _, _, _, _) ->
{400, ?HEADER, #xmlel{name = <<"h1">>,
children = [{xmlcdata, <<"400 Bad Request">>}]}}.
@@ -141,7 +154,7 @@ connect(#ws{socket = Socket, sockmod = SockMod} = Ws, WsLoop) ->
_ ->
SockMod:setopts(Socket, [{packet, 0}, {active, true}])
end,
- ws_loop(none, Socket, WsHandleLoopPid, SockMod).
+ ws_loop(none, Socket, WsHandleLoopPid, SockMod, none).
handshake(#ws{headers = Headers} = State) ->
{_, Key} = lists:keyfind(<<"Sec-Websocket-Key">>, 1,
@@ -152,8 +165,8 @@ handshake(#ws{headers = Headers} = State) ->
V ->
[<<"Sec-Websocket-Protocol:">>, V, <<"\r\n">>]
end,
- Hash = jlib:encode_base64(
- p1_sha:sha1(<<Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11">>)),
+ Hash = base64:encode(
+ crypto:hash(sha, <<Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11">>)),
{State, [<<"HTTP/1.1 101 Switching Protocols\r\n">>,
<<"Upgrade: websocket\r\n">>,
<<"Connection: Upgrade\r\n">>,
@@ -174,51 +187,66 @@ find_subprotocol(Headers) ->
end.
-ws_loop(FrameInfo, Socket, WsHandleLoopPid, SocketMode) ->
+ws_loop(FrameInfo, Socket, WsHandleLoopPid, SocketMode, Shaper) ->
receive
{DataType, _Socket, Data} when DataType =:= tcp orelse DataType =:= raw ->
- case handle_data(DataType, FrameInfo, Data, Socket, WsHandleLoopPid, SocketMode) of
+ case handle_data(DataType, FrameInfo, Data, Socket, WsHandleLoopPid, SocketMode, Shaper) of
{error, Error} ->
- ?DEBUG("tls decode error ~p", [Error]),
+ ?DEBUG("TLS decode error ~p", [Error]),
websocket_close(Socket, WsHandleLoopPid, SocketMode, 1002); % protocol error
- {NewFrameInfo, ToSend} ->
+ {NewFrameInfo, ToSend, NewShaper} ->
lists:foreach(fun(Pkt) -> SocketMode:send(Socket, Pkt)
end, ToSend),
- ws_loop(NewFrameInfo, Socket, WsHandleLoopPid, SocketMode)
+ ws_loop(NewFrameInfo, Socket, WsHandleLoopPid, SocketMode, NewShaper)
end;
+ {new_shaper, NewShaper} ->
+ NewShaper = case NewShaper of
+ none when Shaper /= none ->
+ activate(Socket, SocketMode, true), none;
+ _ ->
+ NewShaper
+ end,
+ ws_loop(FrameInfo, Socket, WsHandleLoopPid, SocketMode, NewShaper);
{tcp_closed, _Socket} ->
- ?DEBUG("tcp connection was closed, exit", []),
+ ?DEBUG("TCP connection was closed, exit", []),
websocket_close(Socket, WsHandleLoopPid, SocketMode, 0);
+ {tcp_error, Socket, Reason} ->
+ ?DEBUG("TCP connection error: ~ts", [inet:format_error(Reason)]),
+ websocket_close(Socket, WsHandleLoopPid, SocketMode, 0);
{'DOWN', Ref, process, WsHandleLoopPid, Reason} ->
Code = case Reason of
normal ->
1000; % normal close
_ ->
- ?ERROR_MSG("linked websocket controlling loop crashed "
+ ?ERROR_MSG("Linked websocket controlling loop crashed "
"with reason: ~p",
[Reason]),
1011 % internal error
end,
erlang:demonitor(Ref),
websocket_close(Socket, WsHandleLoopPid, SocketMode, Code);
- {send, Data} ->
+ {text, Data} ->
SocketMode:send(Socket, encode_frame(Data, 1)),
ws_loop(FrameInfo, Socket, WsHandleLoopPid,
- SocketMode);
+ SocketMode, Shaper);
+ {data, Data} ->
+ SocketMode:send(Socket, encode_frame(Data, 2)),
+ ws_loop(FrameInfo, Socket, WsHandleLoopPid,
+ SocketMode, Shaper);
{ping, Data} ->
SocketMode:send(Socket, encode_frame(Data, 9)),
ws_loop(FrameInfo, Socket, WsHandleLoopPid,
- SocketMode);
+ SocketMode, Shaper);
shutdown ->
- ?DEBUG("shutdown request received, closing websocket "
+ ?DEBUG("Shutdown request received, closing websocket "
"with pid ~p",
[self()]),
websocket_close(Socket, WsHandleLoopPid, SocketMode, 1001); % going away
_Ignored ->
- ?WARNING_MSG("received unexpected message, ignoring: ~p",
+ ?WARNING_MSG("Received unexpected message, ignoring: ~p",
[_Ignored]),
ws_loop(FrameInfo, Socket, WsHandleLoopPid,
- SocketMode)
+ SocketMode, Shaper)
end.
encode_frame(Data, Opcode) ->
@@ -328,7 +356,7 @@ process_frame(#frame_info{unprocessed = none,
8 -> % Close
CloseCode = case Unmasked of
<<Code:16/integer-big, Message/binary>> ->
- ?DEBUG("WebSocket close op: ~p ~s",
+ ?DEBUG("WebSocket close op: ~p ~ts",
[Code, Message]),
Code;
<<Code:16/integer-big>> ->
@@ -373,17 +401,17 @@ process_frame(#frame_info{unprocessed =
process_frame(FrameInfo#frame_info{unprocessed = <<>>},
<<UnprocessedPre/binary, Data/binary>>).
-handle_data(tcp, FrameInfo, Data, Socket, WsHandleLoopPid, fast_tls) ->
+handle_data(tcp, FrameInfo, Data, Socket, WsHandleLoopPid, fast_tls, Shaper) ->
case fast_tls:recv_data(Socket, Data) of
{ok, NewData} ->
- handle_data_int(FrameInfo, NewData, Socket, WsHandleLoopPid, fast_tls);
+ handle_data_int(FrameInfo, NewData, Socket, WsHandleLoopPid, fast_tls, Shaper);
{error, Error} ->
{error, Error}
end;
-handle_data(_, FrameInfo, Data, Socket, WsHandleLoopPid, SockMod) ->
- handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SockMod).
+handle_data(_, FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper) ->
+ handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper).
-handle_data_int(FrameInfo, Data, _Socket, WsHandleLoopPid, _SocketMode) ->
+handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SocketMode, Shaper) ->
{NewFrameInfo, Recv, Send} = process_frame(FrameInfo, Data),
lists:foreach(fun (El) ->
case El of
@@ -396,7 +424,7 @@ handle_data_int(FrameInfo, Data, _Socket, WsHandleLoopPid, _SocketMode) ->
end
end,
Recv),
- {NewFrameInfo, Send}.
+ {NewFrameInfo, Send, handle_shaping(Data, Socket, SocketMode, Shaper)}.
websocket_close(Socket, WsHandleLoopPid,
SocketMode, CloseCode) when CloseCode > 0 ->
@@ -406,3 +434,30 @@ websocket_close(Socket, WsHandleLoopPid,
websocket_close(Socket, WsHandleLoopPid, SocketMode, _CloseCode) ->
WsHandleLoopPid ! closed,
SocketMode:close(Socket).
+
+get_origin() ->
+ ejabberd_option:websocket_origin().
+
+handle_shaping(_Data, _Socket, _SocketMode, none) ->
+ none;
+handle_shaping(Data, Socket, SocketMode, Shaper) ->
+ {NewShaper, Pause} = ejabberd_shaper:update(Shaper, byte_size(Data)),
+ if Pause > 0 ->
+ activate_after(Socket, self(), Pause);
+ true -> activate(Socket, SocketMode, once)
+ end,
+ NewShaper.
+
+activate(Socket, SockMod, ActiveState) ->
+ case SockMod of
+ gen_tcp -> inet:setopts(Socket, [{active, ActiveState}]);
+ _ -> SockMod:setopts(Socket, [{active, ActiveState}])
+ end.
+
+activate_after(Socket, Pid, Pause) ->
+ if Pause > 0 ->
+ erlang:send_after(Pause, Pid, {tcp, Socket, <<>>});
+ true ->
+ Pid ! {tcp, Socket, <<>>}
+ end,
+ ok.