diff options
Diffstat (limited to 'src/ejabberd_websocket.erl')
-rw-r--r-- | src/ejabberd_websocket.erl | 171 |
1 files changed, 113 insertions, 58 deletions
diff --git a/src/ejabberd_websocket.erl b/src/ejabberd_websocket.erl index 0cdd9bac5..b77e39820 100644 --- a/src/ejabberd_websocket.erl +++ b/src/ejabberd_websocket.erl @@ -33,21 +33,19 @@ %%% NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE %%% POSSIBILITY OF SUCH DAMAGE. %%% ========================================================================================================== -%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% ejabberd, Copyright (C) 2002-2019 ProcessOne %%%---------------------------------------------------------------------- -module(ejabberd_websocket). - -protocol({rfc, 6455}). -author('ecestari@process-one.net'). --export([check/2, socket_handoff/8]). +-export([socket_handoff/5]). --include("ejabberd.hrl"). -include("logger.hrl"). --include("jlib.hrl"). +-include("xmpp.hrl"). -include("ejabberd_http.hrl"). @@ -63,35 +61,47 @@ ?AC_ALLOW_HEADERS, ?AC_MAX_AGE]). -define(HEADER, [?CT_XML, ?AC_ALLOW_ORIGIN, ?AC_ALLOW_HEADERS]). -check(_Path, Headers) -> - RequiredHeaders = [{'Upgrade', <<"websocket">>}, - {'Connection', ignore}, {'Host', ignore}, - {<<"Sec-Websocket-Key">>, ignore}, - {<<"Sec-Websocket-Version">>, <<"13">>}], - - F = fun ({Tag, Val}) -> - case lists:keyfind(Tag, 1, Headers) of - false -> true; % header not found, keep in list - {_, HVal} -> - case Val of - ignore -> false; % ignore value -> ok, remove from list - _ -> - % expected value -> ok, remove from list (false) - % value is different, keep in list (true) - str:to_lower(HVal) /= Val - end - end - end, - case lists:filter(F, RequiredHeaders) of - [] -> true; - _MissingHeaders -> false +is_valid_websocket_upgrade(_Path, Headers) -> + HeadersToValidate = [{'Upgrade', <<"websocket">>}, + {'Connection', ignore}, + {'Host', ignore}, + {<<"Sec-Websocket-Key">>, ignore}, + {<<"Sec-Websocket-Version">>, <<"13">>}], + Res = lists:all( + fun({Tag, Val}) -> + case lists:keyfind(Tag, 1, Headers) of + false -> + false; + {_, _} when Val == ignore -> + true; + {_, HVal} -> + str:to_lower(HVal) == Val + end + end, HeadersToValidate), + + case {Res, lists:keyfind(<<"Origin">>, 1, Headers), get_origin()} of + {false, _, _} -> + false; + {true, _, []} -> + true; + {true, {_, HVal}, Origins} -> + HValLow = str:to_lower(HVal), + case lists:any(fun(V) -> V == HValLow end, Origins) of + true -> + true; + _ -> + invalid_origin + end; + {true, false, _} -> + true end. socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path, headers = Headers, host = Host, port = Port, - opts = HOpts}, - Socket, SockMod, Buf, _Opts, HandlerModule, InfoMsgFun) -> - case check(LocalPath, Headers) of + socket = Socket, sockmod = SockMod, + data = Buf, opts = HOpts}, + _Opts, HandlerModule, InfoMsgFun) -> + case is_valid_websocket_upgrade(LocalPath, Headers) of true -> WS = #ws{socket = Socket, sockmod = SockMod, @@ -106,14 +116,17 @@ socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path, http_opts = HOpts}, connect(WS, HandlerModule); - _ -> - {200, ?HEADER, InfoMsgFun()} + false -> + {200, ?HEADER, InfoMsgFun()}; + invalid_origin -> + {403, ?HEADER, #xmlel{name = <<"h1">>, + children = [{xmlcdata, <<"403 Bad Request - Invalid origin">>}]}} end; -socket_handoff(_, #request{method = 'OPTIONS'}, _, _, _, _, _, _) -> +socket_handoff(_, #request{method = 'OPTIONS'}, _, _, _) -> {200, ?OPTIONS_HEADER, []}; -socket_handoff(_, #request{method = 'HEAD'}, _, _, _, _, _, _) -> +socket_handoff(_, #request{method = 'HEAD'}, _, _, _) -> {200, ?HEADER, []}; -socket_handoff(_, _, _, _, _, _, _, _) -> +socket_handoff(_, _, _, _, _) -> {400, ?HEADER, #xmlel{name = <<"h1">>, children = [{xmlcdata, <<"400 Bad Request">>}]}}. @@ -141,7 +154,7 @@ connect(#ws{socket = Socket, sockmod = SockMod} = Ws, WsLoop) -> _ -> SockMod:setopts(Socket, [{packet, 0}, {active, true}]) end, - ws_loop(none, Socket, WsHandleLoopPid, SockMod). + ws_loop(none, Socket, WsHandleLoopPid, SockMod, none). handshake(#ws{headers = Headers} = State) -> {_, Key} = lists:keyfind(<<"Sec-Websocket-Key">>, 1, @@ -152,8 +165,8 @@ handshake(#ws{headers = Headers} = State) -> V -> [<<"Sec-Websocket-Protocol:">>, V, <<"\r\n">>] end, - Hash = jlib:encode_base64( - p1_sha:sha1(<<Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11">>)), + Hash = base64:encode( + crypto:hash(sha, <<Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11">>)), {State, [<<"HTTP/1.1 101 Switching Protocols\r\n">>, <<"Upgrade: websocket\r\n">>, <<"Connection: Upgrade\r\n">>, @@ -174,51 +187,66 @@ find_subprotocol(Headers) -> end. -ws_loop(FrameInfo, Socket, WsHandleLoopPid, SocketMode) -> +ws_loop(FrameInfo, Socket, WsHandleLoopPid, SocketMode, Shaper) -> receive {DataType, _Socket, Data} when DataType =:= tcp orelse DataType =:= raw -> - case handle_data(DataType, FrameInfo, Data, Socket, WsHandleLoopPid, SocketMode) of + case handle_data(DataType, FrameInfo, Data, Socket, WsHandleLoopPid, SocketMode, Shaper) of {error, Error} -> - ?DEBUG("tls decode error ~p", [Error]), + ?DEBUG("TLS decode error ~p", [Error]), websocket_close(Socket, WsHandleLoopPid, SocketMode, 1002); % protocol error - {NewFrameInfo, ToSend} -> + {NewFrameInfo, ToSend, NewShaper} -> lists:foreach(fun(Pkt) -> SocketMode:send(Socket, Pkt) end, ToSend), - ws_loop(NewFrameInfo, Socket, WsHandleLoopPid, SocketMode) + ws_loop(NewFrameInfo, Socket, WsHandleLoopPid, SocketMode, NewShaper) end; + {new_shaper, NewShaper} -> + NewShaper = case NewShaper of + none when Shaper /= none -> + activate(Socket, SocketMode, true), none; + _ -> + NewShaper + end, + ws_loop(FrameInfo, Socket, WsHandleLoopPid, SocketMode, NewShaper); {tcp_closed, _Socket} -> - ?DEBUG("tcp connection was closed, exit", []), + ?DEBUG("TCP connection was closed, exit", []), websocket_close(Socket, WsHandleLoopPid, SocketMode, 0); + {tcp_error, Socket, Reason} -> + ?DEBUG("TCP connection error: ~ts", [inet:format_error(Reason)]), + websocket_close(Socket, WsHandleLoopPid, SocketMode, 0); {'DOWN', Ref, process, WsHandleLoopPid, Reason} -> Code = case Reason of normal -> 1000; % normal close _ -> - ?ERROR_MSG("linked websocket controlling loop crashed " + ?ERROR_MSG("Linked websocket controlling loop crashed " "with reason: ~p", [Reason]), 1011 % internal error end, erlang:demonitor(Ref), websocket_close(Socket, WsHandleLoopPid, SocketMode, Code); - {send, Data} -> + {text, Data} -> SocketMode:send(Socket, encode_frame(Data, 1)), ws_loop(FrameInfo, Socket, WsHandleLoopPid, - SocketMode); + SocketMode, Shaper); + {data, Data} -> + SocketMode:send(Socket, encode_frame(Data, 2)), + ws_loop(FrameInfo, Socket, WsHandleLoopPid, + SocketMode, Shaper); {ping, Data} -> SocketMode:send(Socket, encode_frame(Data, 9)), ws_loop(FrameInfo, Socket, WsHandleLoopPid, - SocketMode); + SocketMode, Shaper); shutdown -> - ?DEBUG("shutdown request received, closing websocket " + ?DEBUG("Shutdown request received, closing websocket " "with pid ~p", [self()]), websocket_close(Socket, WsHandleLoopPid, SocketMode, 1001); % going away _Ignored -> - ?WARNING_MSG("received unexpected message, ignoring: ~p", + ?WARNING_MSG("Received unexpected message, ignoring: ~p", [_Ignored]), ws_loop(FrameInfo, Socket, WsHandleLoopPid, - SocketMode) + SocketMode, Shaper) end. encode_frame(Data, Opcode) -> @@ -328,7 +356,7 @@ process_frame(#frame_info{unprocessed = none, 8 -> % Close CloseCode = case Unmasked of <<Code:16/integer-big, Message/binary>> -> - ?DEBUG("WebSocket close op: ~p ~s", + ?DEBUG("WebSocket close op: ~p ~ts", [Code, Message]), Code; <<Code:16/integer-big>> -> @@ -373,17 +401,17 @@ process_frame(#frame_info{unprocessed = process_frame(FrameInfo#frame_info{unprocessed = <<>>}, <<UnprocessedPre/binary, Data/binary>>). -handle_data(tcp, FrameInfo, Data, Socket, WsHandleLoopPid, fast_tls) -> +handle_data(tcp, FrameInfo, Data, Socket, WsHandleLoopPid, fast_tls, Shaper) -> case fast_tls:recv_data(Socket, Data) of {ok, NewData} -> - handle_data_int(FrameInfo, NewData, Socket, WsHandleLoopPid, fast_tls); + handle_data_int(FrameInfo, NewData, Socket, WsHandleLoopPid, fast_tls, Shaper); {error, Error} -> {error, Error} end; -handle_data(_, FrameInfo, Data, Socket, WsHandleLoopPid, SockMod) -> - handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SockMod). +handle_data(_, FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper) -> + handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SockMod, Shaper). -handle_data_int(FrameInfo, Data, _Socket, WsHandleLoopPid, _SocketMode) -> +handle_data_int(FrameInfo, Data, Socket, WsHandleLoopPid, SocketMode, Shaper) -> {NewFrameInfo, Recv, Send} = process_frame(FrameInfo, Data), lists:foreach(fun (El) -> case El of @@ -396,7 +424,7 @@ handle_data_int(FrameInfo, Data, _Socket, WsHandleLoopPid, _SocketMode) -> end end, Recv), - {NewFrameInfo, Send}. + {NewFrameInfo, Send, handle_shaping(Data, Socket, SocketMode, Shaper)}. websocket_close(Socket, WsHandleLoopPid, SocketMode, CloseCode) when CloseCode > 0 -> @@ -406,3 +434,30 @@ websocket_close(Socket, WsHandleLoopPid, websocket_close(Socket, WsHandleLoopPid, SocketMode, _CloseCode) -> WsHandleLoopPid ! closed, SocketMode:close(Socket). + +get_origin() -> + ejabberd_option:websocket_origin(). + +handle_shaping(_Data, _Socket, _SocketMode, none) -> + none; +handle_shaping(Data, Socket, SocketMode, Shaper) -> + {NewShaper, Pause} = ejabberd_shaper:update(Shaper, byte_size(Data)), + if Pause > 0 -> + activate_after(Socket, self(), Pause); + true -> activate(Socket, SocketMode, once) + end, + NewShaper. + +activate(Socket, SockMod, ActiveState) -> + case SockMod of + gen_tcp -> inet:setopts(Socket, [{active, ActiveState}]); + _ -> SockMod:setopts(Socket, [{active, ActiveState}]) + end. + +activate_after(Socket, Pid, Pause) -> + if Pause > 0 -> + erlang:send_after(Pause, Pid, {tcp, Socket, <<>>}); + true -> + Pid ! {tcp, Socket, <<>>} + end, + ok. |