diff options
Diffstat (limited to 'src/stun/ejabberd_stun.erl')
-rw-r--r-- | src/stun/ejabberd_stun.erl | 224 |
1 files changed, 102 insertions, 122 deletions
diff --git a/src/stun/ejabberd_stun.erl b/src/stun/ejabberd_stun.erl index dbc681cb7..1046fff11 100644 --- a/src/stun/ejabberd_stun.erl +++ b/src/stun/ejabberd_stun.erl @@ -30,35 +30,31 @@ -behaviour(gen_fsm). %% API --export([start_link/2, - start/2, - socket_type/0, +-export([start_link/2, start/2, socket_type/0, udp_recv/5]). %% gen_fsm callbacks --export([init/1, - handle_event/3, - handle_sync_event/4, - handle_info/3, - terminate/3, - code_change/4]). +-export([init/1, handle_event/3, handle_sync_event/4, + handle_info/3, terminate/3, code_change/4]). %% gen_fsm states --export([wait_for_tls/2, - session_established/2]). +-export([wait_for_tls/2, session_established/2]). -include("ejabberd.hrl"). + -include("stun.hrl"). --define(MAX_BUF_SIZE, 64*1024). %% 64kb --define(TIMEOUT, 10000). %% 10 sec +-define(MAX_BUF_SIZE, 64 * 1024). + +-define(TIMEOUT, 10000). --record(state, {sock, - sock_mod = gen_tcp, - certfile, - peer, - tref, - buf = <<>>}). +-record(state, + {sock :: inet:socket() | tls:tls_socket(), + sock_mod = gen_tcp :: gen_udp | gen_tcp | tls, + certfile :: binary(), + peer = {{0,0,0,0}, 0} :: {inet:ip_address(), inet:port_number()}, + tref = make_ref() :: reference(), + buf = <<>> :: binary()}). %%==================================================================== %% API @@ -69,23 +65,20 @@ start({gen_tcp, Sock}, Opts) -> start_link(Sock, Opts) -> gen_fsm:start_link(?MODULE, [Sock, Opts], []). -socket_type() -> - raw. +socket_type() -> raw. udp_recv(Sock, Addr, Port, Data, _Opts) -> case stun_codec:decode(Data) of - {ok, Msg, <<>>} -> - ?DEBUG("got:~n~p", [Msg]), - case process(Addr, Port, Msg) of - RespMsg when is_record(RespMsg, stun) -> - ?DEBUG("sent:~n~p", [RespMsg]), - Data1 = stun_codec:encode(RespMsg), - gen_udp:send(Sock, Addr, Port, Data1); - _ -> - ok - end; - _ -> - ok + {ok, Msg, <<>>} -> + ?DEBUG("got:~n~p", [Msg]), + case process(Addr, Port, Msg) of + RespMsg when is_record(RespMsg, stun) -> + ?DEBUG("sent:~n~p", [RespMsg]), + Data1 = stun_codec:encode(RespMsg), + gen_udp:send(Sock, Addr, Port, Data1); + _ -> ok + end; + _ -> ok end. %%==================================================================== @@ -93,38 +86,38 @@ udp_recv(Sock, Addr, Port, Data, _Opts) -> %%==================================================================== init([Sock, Opts]) -> case inet:peername(Sock) of - {ok, Addr} -> - inet:setopts(Sock, [{active, once}]), - TRef = erlang:start_timer(?TIMEOUT, self(), stop), - State = #state{sock = Sock, peer = Addr, tref = TRef}, - case proplists:get_value(certfile, Opts) of - undefined -> - {ok, session_established, State}; - CertFile -> - {ok, wait_for_tls, State#state{certfile = CertFile}} - end; - Err -> - Err + {ok, Addr} -> + inet:setopts(Sock, [{active, once}]), + TRef = erlang:start_timer(?TIMEOUT, self(), stop), + State = #state{sock = Sock, peer = Addr, tref = TRef}, + case proplists:get_value(certfile, Opts) of + undefined -> {ok, session_established, State}; + CertFile -> + {ok, wait_for_tls, State#state{certfile = CertFile}} + end; + Err -> Err end. wait_for_tls(Event, State) -> - ?INFO_MSG("unexpected event in wait_for_tls: ~p", [Event]), + ?INFO_MSG("unexpected event in wait_for_tls: ~p", + [Event]), {next_state, wait_for_tls, State}. -session_established(Msg, State) when is_record(Msg, stun) -> +session_established(Msg, State) + when is_record(Msg, stun) -> ?DEBUG("got:~n~p", [Msg]), {Addr, Port} = State#state.peer, case process(Addr, Port, Msg) of - Resp when is_record(Resp, stun) -> - ?DEBUG("sent:~n~p", [Resp]), - Data = stun_codec:encode(Resp), - (State#state.sock_mod):send(State#state.sock, Data); - _ -> - ok + Resp when is_record(Resp, stun) -> + ?DEBUG("sent:~n~p", [Resp]), + Data = stun_codec:encode(Resp), + (State#state.sock_mod):send(State#state.sock, Data); + _ -> ok end, {next_state, session_established, State}; session_established(Event, State) -> - ?INFO_MSG("unexpected event in session_established: ~p", [Event]), + ?INFO_MSG("unexpected event in session_established: ~p", + [Event]), {next_state, session_established, State}. handle_event(_Event, StateName, State) -> @@ -133,42 +126,38 @@ handle_event(_Event, StateName, State) -> handle_sync_event(_Event, _From, StateName, State) -> {reply, {error, badarg}, StateName, State}. -handle_info({tcp, Sock, TLSData}, wait_for_tls, State) -> +handle_info({tcp, Sock, TLSData}, wait_for_tls, + State) -> Buf = <<(State#state.buf)/binary, TLSData/binary>>, - %% Check if the initial message is a TLS handshake case Buf of - _ when size(Buf) < 3 -> - {next_state, wait_for_tls, - update_state(State#state{buf = Buf})}; - <<_:16, 1, _/binary>> -> - TLSOpts = [{certfile, State#state.certfile}], - {ok, TLSSock} = tls:tcp_to_tls(Sock, TLSOpts), - NewState = State#state{sock = TLSSock, - buf = <<>>, - sock_mod = tls}, - case tls:recv_data(TLSSock, Buf) of - {ok, Data} -> - process_data(session_established, NewState, Data); - _Err -> - {stop, normal, NewState} - end; - _ -> - process_data(session_established, State, TLSData) + _ when byte_size(Buf) < 3 -> + {next_state, wait_for_tls, + update_state(State#state{buf = Buf})}; + <<_:16, 1, _/binary>> -> + TLSOpts = [{certfile, State#state.certfile}], + {ok, TLSSock} = tls:tcp_to_tls(Sock, TLSOpts), + NewState = State#state{sock = TLSSock, buf = <<>>, + sock_mod = tls}, + case tls:recv_data(TLSSock, Buf) of + {ok, Data} -> + process_data(session_established, NewState, Data); + _Err -> {stop, normal, NewState} + end; + _ -> process_data(session_established, State, TLSData) end; handle_info({tcp, _Sock, TLSData}, StateName, #state{sock_mod = tls} = State) -> case tls:recv_data(State#state.sock, TLSData) of - {ok, Data} -> - process_data(StateName, State, Data); - _Err -> - {stop, normal, State} + {ok, Data} -> process_data(StateName, State, Data); + _Err -> {stop, normal, State} end; handle_info({tcp, _Sock, Data}, StateName, State) -> process_data(StateName, State, Data); handle_info({tcp_closed, _Sock}, _StateName, State) -> ?DEBUG("connection reset by peer", []), {stop, normal, State}; -handle_info({tcp_error, _Sock, Reason}, _StateName, State) -> +handle_info({tcp_error, _Sock, Reason}, _StateName, + State) -> ?DEBUG("connection error: ~p", [Reason]), {stop, normal, State}; handle_info({timeout, TRef, stop}, _StateName, @@ -188,58 +177,55 @@ code_change(_OldVsn, StateName, State, _Extra) -> %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- -process(Addr, Port, #stun{class = request, unsupported = []} = Msg) -> +process(Addr, Port, + #stun{class = request, unsupported = []} = Msg) -> Resp = prepare_response(Msg), - if Msg#stun.method == ?STUN_METHOD_BINDING -> - case stun_codec:version(Msg) of - old -> - Resp#stun{class = response, - 'MAPPED-ADDRESS' = {Addr, Port}}; - new -> - Resp#stun{class = response, - 'XOR-MAPPED-ADDRESS' = {Addr, Port}} - end; + if Msg#stun.method == (?STUN_METHOD_BINDING) -> + case stun_codec:version(Msg) of + old -> + Resp#stun{class = response, + 'MAPPED-ADDRESS' = {Addr, Port}}; + new -> + Resp#stun{class = response, + 'XOR-MAPPED-ADDRESS' = {Addr, Port}} + end; true -> - Resp#stun{class = error, - 'ERROR-CODE' = {405, <<"Method Not Allowed">>}} + Resp#stun{class = error, + 'ERROR-CODE' = {405, <<"Method Not Allowed">>}} end; process(_Addr, _Port, #stun{class = request} = Msg) -> Resp = prepare_response(Msg), Resp#stun{class = error, 'UNKNOWN-ATTRIBUTES' = Msg#stun.unsupported, 'ERROR-CODE' = {420, stun_codec:reason(420)}}; -process(_Addr, _Port, _Msg) -> - pass. +process(_Addr, _Port, _Msg) -> pass. prepare_response(Msg) -> - Version = list_to_binary("ejabberd " ++ ?VERSION), - #stun{method = Msg#stun.method, - magic = Msg#stun.magic, - trid = Msg#stun.trid, - 'SOFTWARE' = Version}. + Version = <<"ejabberd ", (iolist_to_binary(?VERSION))/binary>>, + #stun{method = Msg#stun.method, magic = Msg#stun.magic, + trid = Msg#stun.trid, 'SOFTWARE' = Version}. -process_data(NextStateName, #state{buf = Buf} = State, Data) -> +process_data(NextStateName, #state{buf = Buf} = State, + Data) -> NewBuf = <<Buf/binary, Data/binary>>, case stun_codec:decode(NewBuf) of - {ok, Msg, Tail} -> - gen_fsm:send_event(self(), Msg), - process_data(NextStateName, State#state{buf = <<>>}, Tail); - empty -> - NewState = State#state{buf = <<>>}, - {next_state, NextStateName, update_state(NewState)}; - more when size(NewBuf) < ?MAX_BUF_SIZE -> - NewState = State#state{buf = NewBuf}, - {next_state, NextStateName, update_state(NewState)}; - _ -> - {stop, normal, State} + {ok, Msg, Tail} -> + gen_fsm:send_event(self(), Msg), + process_data(NextStateName, State#state{buf = <<>>}, + Tail); + empty -> + NewState = State#state{buf = <<>>}, + {next_state, NextStateName, update_state(NewState)}; + more when byte_size(NewBuf) < (?MAX_BUF_SIZE) -> + NewState = State#state{buf = NewBuf}, + {next_state, NextStateName, update_state(NewState)}; + _ -> {stop, normal, State} end. update_state(#state{sock = Sock} = State) -> case State#state.sock_mod of - gen_tcp -> - inet:setopts(Sock, [{active, once}]); - SockMod -> - SockMod:setopts(Sock, [{active, once}]) + gen_tcp -> inet:setopts(Sock, [{active, once}]); + SockMod -> SockMod:setopts(Sock, [{active, once}]) end, cancel_timer(State#state.tref), TRef = erlang:start_timer(?TIMEOUT, self(), stop), @@ -247,13 +233,7 @@ update_state(#state{sock = Sock} = State) -> cancel_timer(TRef) -> case erlang:cancel_timer(TRef) of - false -> - receive - {timeout, TRef, _} -> - ok - after 0 -> - ok - end; - _ -> - ok + false -> + receive {timeout, TRef, _} -> ok after 0 -> ok end; + _ -> ok end. |