aboutsummaryrefslogtreecommitdiff
path: root/src/stun/ejabberd_stun.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/stun/ejabberd_stun.erl')
-rw-r--r--src/stun/ejabberd_stun.erl224
1 files changed, 102 insertions, 122 deletions
diff --git a/src/stun/ejabberd_stun.erl b/src/stun/ejabberd_stun.erl
index dbc681cb7..1046fff11 100644
--- a/src/stun/ejabberd_stun.erl
+++ b/src/stun/ejabberd_stun.erl
@@ -30,35 +30,31 @@
-behaviour(gen_fsm).
%% API
--export([start_link/2,
- start/2,
- socket_type/0,
+-export([start_link/2, start/2, socket_type/0,
udp_recv/5]).
%% gen_fsm callbacks
--export([init/1,
- handle_event/3,
- handle_sync_event/4,
- handle_info/3,
- terminate/3,
- code_change/4]).
+-export([init/1, handle_event/3, handle_sync_event/4,
+ handle_info/3, terminate/3, code_change/4]).
%% gen_fsm states
--export([wait_for_tls/2,
- session_established/2]).
+-export([wait_for_tls/2, session_established/2]).
-include("ejabberd.hrl").
+
-include("stun.hrl").
--define(MAX_BUF_SIZE, 64*1024). %% 64kb
--define(TIMEOUT, 10000). %% 10 sec
+-define(MAX_BUF_SIZE, 64 * 1024).
+
+-define(TIMEOUT, 10000).
--record(state, {sock,
- sock_mod = gen_tcp,
- certfile,
- peer,
- tref,
- buf = <<>>}).
+-record(state,
+ {sock :: inet:socket() | tls:tls_socket(),
+ sock_mod = gen_tcp :: gen_udp | gen_tcp | tls,
+ certfile :: binary(),
+ peer = {{0,0,0,0}, 0} :: {inet:ip_address(), inet:port_number()},
+ tref = make_ref() :: reference(),
+ buf = <<>> :: binary()}).
%%====================================================================
%% API
@@ -69,23 +65,20 @@ start({gen_tcp, Sock}, Opts) ->
start_link(Sock, Opts) ->
gen_fsm:start_link(?MODULE, [Sock, Opts], []).
-socket_type() ->
- raw.
+socket_type() -> raw.
udp_recv(Sock, Addr, Port, Data, _Opts) ->
case stun_codec:decode(Data) of
- {ok, Msg, <<>>} ->
- ?DEBUG("got:~n~p", [Msg]),
- case process(Addr, Port, Msg) of
- RespMsg when is_record(RespMsg, stun) ->
- ?DEBUG("sent:~n~p", [RespMsg]),
- Data1 = stun_codec:encode(RespMsg),
- gen_udp:send(Sock, Addr, Port, Data1);
- _ ->
- ok
- end;
- _ ->
- ok
+ {ok, Msg, <<>>} ->
+ ?DEBUG("got:~n~p", [Msg]),
+ case process(Addr, Port, Msg) of
+ RespMsg when is_record(RespMsg, stun) ->
+ ?DEBUG("sent:~n~p", [RespMsg]),
+ Data1 = stun_codec:encode(RespMsg),
+ gen_udp:send(Sock, Addr, Port, Data1);
+ _ -> ok
+ end;
+ _ -> ok
end.
%%====================================================================
@@ -93,38 +86,38 @@ udp_recv(Sock, Addr, Port, Data, _Opts) ->
%%====================================================================
init([Sock, Opts]) ->
case inet:peername(Sock) of
- {ok, Addr} ->
- inet:setopts(Sock, [{active, once}]),
- TRef = erlang:start_timer(?TIMEOUT, self(), stop),
- State = #state{sock = Sock, peer = Addr, tref = TRef},
- case proplists:get_value(certfile, Opts) of
- undefined ->
- {ok, session_established, State};
- CertFile ->
- {ok, wait_for_tls, State#state{certfile = CertFile}}
- end;
- Err ->
- Err
+ {ok, Addr} ->
+ inet:setopts(Sock, [{active, once}]),
+ TRef = erlang:start_timer(?TIMEOUT, self(), stop),
+ State = #state{sock = Sock, peer = Addr, tref = TRef},
+ case proplists:get_value(certfile, Opts) of
+ undefined -> {ok, session_established, State};
+ CertFile ->
+ {ok, wait_for_tls, State#state{certfile = CertFile}}
+ end;
+ Err -> Err
end.
wait_for_tls(Event, State) ->
- ?INFO_MSG("unexpected event in wait_for_tls: ~p", [Event]),
+ ?INFO_MSG("unexpected event in wait_for_tls: ~p",
+ [Event]),
{next_state, wait_for_tls, State}.
-session_established(Msg, State) when is_record(Msg, stun) ->
+session_established(Msg, State)
+ when is_record(Msg, stun) ->
?DEBUG("got:~n~p", [Msg]),
{Addr, Port} = State#state.peer,
case process(Addr, Port, Msg) of
- Resp when is_record(Resp, stun) ->
- ?DEBUG("sent:~n~p", [Resp]),
- Data = stun_codec:encode(Resp),
- (State#state.sock_mod):send(State#state.sock, Data);
- _ ->
- ok
+ Resp when is_record(Resp, stun) ->
+ ?DEBUG("sent:~n~p", [Resp]),
+ Data = stun_codec:encode(Resp),
+ (State#state.sock_mod):send(State#state.sock, Data);
+ _ -> ok
end,
{next_state, session_established, State};
session_established(Event, State) ->
- ?INFO_MSG("unexpected event in session_established: ~p", [Event]),
+ ?INFO_MSG("unexpected event in session_established: ~p",
+ [Event]),
{next_state, session_established, State}.
handle_event(_Event, StateName, State) ->
@@ -133,42 +126,38 @@ handle_event(_Event, StateName, State) ->
handle_sync_event(_Event, _From, StateName, State) ->
{reply, {error, badarg}, StateName, State}.
-handle_info({tcp, Sock, TLSData}, wait_for_tls, State) ->
+handle_info({tcp, Sock, TLSData}, wait_for_tls,
+ State) ->
Buf = <<(State#state.buf)/binary, TLSData/binary>>,
- %% Check if the initial message is a TLS handshake
case Buf of
- _ when size(Buf) < 3 ->
- {next_state, wait_for_tls,
- update_state(State#state{buf = Buf})};
- <<_:16, 1, _/binary>> ->
- TLSOpts = [{certfile, State#state.certfile}],
- {ok, TLSSock} = tls:tcp_to_tls(Sock, TLSOpts),
- NewState = State#state{sock = TLSSock,
- buf = <<>>,
- sock_mod = tls},
- case tls:recv_data(TLSSock, Buf) of
- {ok, Data} ->
- process_data(session_established, NewState, Data);
- _Err ->
- {stop, normal, NewState}
- end;
- _ ->
- process_data(session_established, State, TLSData)
+ _ when byte_size(Buf) < 3 ->
+ {next_state, wait_for_tls,
+ update_state(State#state{buf = Buf})};
+ <<_:16, 1, _/binary>> ->
+ TLSOpts = [{certfile, State#state.certfile}],
+ {ok, TLSSock} = tls:tcp_to_tls(Sock, TLSOpts),
+ NewState = State#state{sock = TLSSock, buf = <<>>,
+ sock_mod = tls},
+ case tls:recv_data(TLSSock, Buf) of
+ {ok, Data} ->
+ process_data(session_established, NewState, Data);
+ _Err -> {stop, normal, NewState}
+ end;
+ _ -> process_data(session_established, State, TLSData)
end;
handle_info({tcp, _Sock, TLSData}, StateName,
#state{sock_mod = tls} = State) ->
case tls:recv_data(State#state.sock, TLSData) of
- {ok, Data} ->
- process_data(StateName, State, Data);
- _Err ->
- {stop, normal, State}
+ {ok, Data} -> process_data(StateName, State, Data);
+ _Err -> {stop, normal, State}
end;
handle_info({tcp, _Sock, Data}, StateName, State) ->
process_data(StateName, State, Data);
handle_info({tcp_closed, _Sock}, _StateName, State) ->
?DEBUG("connection reset by peer", []),
{stop, normal, State};
-handle_info({tcp_error, _Sock, Reason}, _StateName, State) ->
+handle_info({tcp_error, _Sock, Reason}, _StateName,
+ State) ->
?DEBUG("connection error: ~p", [Reason]),
{stop, normal, State};
handle_info({timeout, TRef, stop}, _StateName,
@@ -188,58 +177,55 @@ code_change(_OldVsn, StateName, State, _Extra) ->
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
-process(Addr, Port, #stun{class = request, unsupported = []} = Msg) ->
+process(Addr, Port,
+ #stun{class = request, unsupported = []} = Msg) ->
Resp = prepare_response(Msg),
- if Msg#stun.method == ?STUN_METHOD_BINDING ->
- case stun_codec:version(Msg) of
- old ->
- Resp#stun{class = response,
- 'MAPPED-ADDRESS' = {Addr, Port}};
- new ->
- Resp#stun{class = response,
- 'XOR-MAPPED-ADDRESS' = {Addr, Port}}
- end;
+ if Msg#stun.method == (?STUN_METHOD_BINDING) ->
+ case stun_codec:version(Msg) of
+ old ->
+ Resp#stun{class = response,
+ 'MAPPED-ADDRESS' = {Addr, Port}};
+ new ->
+ Resp#stun{class = response,
+ 'XOR-MAPPED-ADDRESS' = {Addr, Port}}
+ end;
true ->
- Resp#stun{class = error,
- 'ERROR-CODE' = {405, <<"Method Not Allowed">>}}
+ Resp#stun{class = error,
+ 'ERROR-CODE' = {405, <<"Method Not Allowed">>}}
end;
process(_Addr, _Port, #stun{class = request} = Msg) ->
Resp = prepare_response(Msg),
Resp#stun{class = error,
'UNKNOWN-ATTRIBUTES' = Msg#stun.unsupported,
'ERROR-CODE' = {420, stun_codec:reason(420)}};
-process(_Addr, _Port, _Msg) ->
- pass.
+process(_Addr, _Port, _Msg) -> pass.
prepare_response(Msg) ->
- Version = list_to_binary("ejabberd " ++ ?VERSION),
- #stun{method = Msg#stun.method,
- magic = Msg#stun.magic,
- trid = Msg#stun.trid,
- 'SOFTWARE' = Version}.
+ Version = <<"ejabberd ", (iolist_to_binary(?VERSION))/binary>>,
+ #stun{method = Msg#stun.method, magic = Msg#stun.magic,
+ trid = Msg#stun.trid, 'SOFTWARE' = Version}.
-process_data(NextStateName, #state{buf = Buf} = State, Data) ->
+process_data(NextStateName, #state{buf = Buf} = State,
+ Data) ->
NewBuf = <<Buf/binary, Data/binary>>,
case stun_codec:decode(NewBuf) of
- {ok, Msg, Tail} ->
- gen_fsm:send_event(self(), Msg),
- process_data(NextStateName, State#state{buf = <<>>}, Tail);
- empty ->
- NewState = State#state{buf = <<>>},
- {next_state, NextStateName, update_state(NewState)};
- more when size(NewBuf) < ?MAX_BUF_SIZE ->
- NewState = State#state{buf = NewBuf},
- {next_state, NextStateName, update_state(NewState)};
- _ ->
- {stop, normal, State}
+ {ok, Msg, Tail} ->
+ gen_fsm:send_event(self(), Msg),
+ process_data(NextStateName, State#state{buf = <<>>},
+ Tail);
+ empty ->
+ NewState = State#state{buf = <<>>},
+ {next_state, NextStateName, update_state(NewState)};
+ more when byte_size(NewBuf) < (?MAX_BUF_SIZE) ->
+ NewState = State#state{buf = NewBuf},
+ {next_state, NextStateName, update_state(NewState)};
+ _ -> {stop, normal, State}
end.
update_state(#state{sock = Sock} = State) ->
case State#state.sock_mod of
- gen_tcp ->
- inet:setopts(Sock, [{active, once}]);
- SockMod ->
- SockMod:setopts(Sock, [{active, once}])
+ gen_tcp -> inet:setopts(Sock, [{active, once}]);
+ SockMod -> SockMod:setopts(Sock, [{active, once}])
end,
cancel_timer(State#state.tref),
TRef = erlang:start_timer(?TIMEOUT, self(), stop),
@@ -247,13 +233,7 @@ update_state(#state{sock = Sock} = State) ->
cancel_timer(TRef) ->
case erlang:cancel_timer(TRef) of
- false ->
- receive
- {timeout, TRef, _} ->
- ok
- after 0 ->
- ok
- end;
- _ ->
- ok
+ false ->
+ receive {timeout, TRef, _} -> ok after 0 -> ok end;
+ _ -> ok
end.