summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlexey Shchepin <alexey@process-one.net>2005-10-25 01:08:37 +0000
committerAlexey Shchepin <alexey@process-one.net>2005-10-25 01:08:37 +0000
commit1433dafe6bda9840cb687c5c3270584fb3ee55d1 (patch)
tree0da87012f614b5de26f1d8b28233fff4225e92c7 /src
parent* src/ejabberd_app.erl: Try to load tls_drv at startup to avoid (diff)
* src/tls/tls_drv.c: Support for "connect" method
* src/tls/tls.erl: Likewise * src/ejabberd_s2s_in.erl: Support for STARTTLS+Dialback * src/ejabberd_s2s_out.erl: Likewise * src/ejabberd_receiver.erl: Added a few hacks ({active,once} mode should be used instead of recv/3 call to avoid them) * src/ejabberd_config.erl: Added s2s_use_starttls and s2s_certfile options * src/ejabberd.cfg.example: Likewise SVN Revision: 426
Diffstat (limited to '')
-rw-r--r--src/ejabberd.cfg.example5
-rw-r--r--src/ejabberd_config.erl4
-rw-r--r--src/ejabberd_receiver.erl37
-rw-r--r--src/ejabberd_s2s_in.erl116
-rw-r--r--src/ejabberd_s2s_out.erl342
-rw-r--r--src/tls/tls.erl30
-rw-r--r--src/tls/tls_drv.c44
7 files changed, 414 insertions, 164 deletions
diff --git a/src/ejabberd.cfg.example b/src/ejabberd.cfg.example
index df3c3461..1cc468bf 100644
--- a/src/ejabberd.cfg.example
+++ b/src/ejabberd.cfg.example
@@ -115,6 +115,11 @@
[{password, "secret"}]}]}
]}.
+
+% Use STARTTLS+Dialback for S2S connections
+{s2s_use_starttls, true}.
+{s2s_certfile, "./ssl.pem"}.
+
% If SRV lookup fails, then port 5269 is used to communicate with remote server
{outgoing_s2s_port, 5269}.
diff --git a/src/ejabberd_config.erl b/src/ejabberd_config.erl
index 8bce8eae..51c183a2 100644
--- a/src/ejabberd_config.erl
+++ b/src/ejabberd_config.erl
@@ -108,6 +108,10 @@ process_term(Term, State) ->
add_option(listen, Val, State);
{outgoing_s2s_port, Port} ->
add_option(outgoing_s2s_port, Port, State);
+ {s2s_use_starttls, Port} ->
+ add_option(s2s_use_starttls, Port, State);
+ {s2s_certfile, Port} ->
+ add_option(s2s_certfile, Port, State);
{Opt, Val} ->
lists:foldl(fun(Host, S) -> process_host_term(Term, Host, S) end,
State, State#state.hosts)
diff --git a/src/ejabberd_receiver.erl b/src/ejabberd_receiver.erl
index 1f1897fb..204771c1 100644
--- a/src/ejabberd_receiver.erl
+++ b/src/ejabberd_receiver.erl
@@ -36,24 +36,27 @@ receiver(Socket, SockMod, Shaper, C2SPid) ->
receiver(Socket, SockMod, ShaperState, C2SPid, XMLStreamState, Timeout) ->
Res = (catch SockMod:recv(Socket, 0, Timeout)),
- case Res of
- {ok, Data} ->
- receive
- {starttls, TLSSocket} ->
- xml_stream:close(XMLStreamState),
- XMLStreamState1 = xml_stream:new(C2SPid),
- TLSRes = tls:recv_data(TLSSocket, Data),
- receiver1(TLSSocket, tls,
- ShaperState, C2SPid, XMLStreamState1, Timeout,
- TLSRes)
- after 0 ->
- receiver1(Socket, SockMod,
- ShaperState, C2SPid, XMLStreamState, Timeout,
- Res)
- end;
- _ ->
+ receive
+ {starttls, TLSSocket} ->
+ xml_stream:close(XMLStreamState),
+ XMLStreamState1 = xml_stream:new(C2SPid),
+ TLSRes = case Res of
+ {ok, Data} ->
+ tls:recv_data(TLSSocket, Data);
+ _ ->
+ tls:recv_data(TLSSocket, "")
+ end,
+ receiver1(TLSSocket, tls,
+ ShaperState, C2SPid, XMLStreamState1, Timeout,
+ TLSRes);
+ {change_timeout, NewTimeout} -> % Dirty hack
+ receiver1(Socket, SockMod,
+ ShaperState, C2SPid, XMLStreamState, NewTimeout,
+ Res)
+ after 0 ->
receiver1(Socket, SockMod,
- ShaperState, C2SPid, XMLStreamState, Timeout, Res)
+ ShaperState, C2SPid, XMLStreamState, Timeout,
+ Res)
end.
diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl
index 1c09c060..d2b61675 100644
--- a/src/ejabberd_s2s_in.erl
+++ b/src/ejabberd_s2s_in.erl
@@ -14,13 +14,12 @@
%% External exports
-export([start/2,
- start_link/2,
- send_text/2,
- send_element/2]).
+ start_link/2]).
%% gen_fsm callbacks
-export([init/1,
wait_for_stream/2,
+ wait_for_feature_request/2,
stream_established/2,
handle_event/3,
handle_sync_event/4,
@@ -34,9 +33,13 @@
-define(DICT, dict).
-record(state, {socket,
+ sockmod,
receiver,
streamid,
shaper,
+ tls = false,
+ tls_enabled = false,
+ tls_options = [],
connections = ?DICT:new(),
timer}).
@@ -49,13 +52,13 @@
-define(FSMOPTS, []).
-endif.
--define(STREAM_HEADER,
+-define(STREAM_HEADER(Version),
("<?xml version='1.0'?>"
"<stream:stream "
"xmlns:stream='http://etherx.jabber.org/streams' "
"xmlns='jabber:server' "
"xmlns:db='jabber:server:dialback' "
- "id='" ++ StateData#state.streamid ++ "'>")
+ "id='" ++ StateData#state.streamid ++ "'" ++ Version ++ ">")
).
-define(STREAM_TRAILER, "</stream:stream>").
@@ -96,12 +99,28 @@ init([{SockMod, Socket}, Opts]) ->
{value, {_, S}} -> S;
_ -> none
end,
+ StartTLS = case ejabberd_config:get_local_option(s2s_use_starttls) of
+ undefined ->
+ false;
+ UseStartTLS ->
+ UseStartTLS
+ end,
+ TLSOpts = case ejabberd_config:get_local_option(s2s_certfile) of
+ undefined ->
+ [];
+ CertFile ->
+ [{certfile, CertFile}]
+ end,
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
{ok, wait_for_stream,
#state{socket = Socket,
+ sockmod = SockMod,
receiver = ReceiverPid,
streamid = new_id(),
shaper = Shaper,
+ tls = StartTLS,
+ tls_enabled = false,
+ tls_options = TLSOpts,
timer = Timer}}.
%%----------------------------------------------------------------------
@@ -113,18 +132,28 @@ init([{SockMod, Socket}, Opts]) ->
wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
% TODO
- case {xml:get_attr_s("xmlns", Attrs), xml:get_attr_s("xmlns:db", Attrs)} of
- {"jabber:server", "jabber:server:dialback"} ->
- send_text(StateData#state.socket, ?STREAM_HEADER),
- {next_state, stream_established, StateData#state{}};
+ case {xml:get_attr_s("xmlns", Attrs),
+ xml:get_attr_s("xmlns:db", Attrs),
+ xml:get_attr_s("version", Attrs) == "1.0"} of
+ {"jabber:server", "jabber:server:dialback", true} when
+ StateData#state.tls ->
+ send_text(StateData, ?STREAM_HEADER(" version='1.0'")),
+ send_element(StateData,
+ {xmlelement, "stream:features", [],
+ [{xmlelement, "starttls",
+ [{"xmlns", ?NS_TLS}], []}]}),
+ {next_state, wait_for_feature_request, StateData};
+ {"jabber:server", "jabber:server:dialback", _} ->
+ send_text(StateData, ?STREAM_HEADER("")),
+ {next_state, stream_established, StateData};
_ ->
- send_text(StateData#state.socket, ?INVALID_NAMESPACE_ERR),
+ send_text(StateData, ?INVALID_NAMESPACE_ERR),
{stop, normal, StateData}
end;
wait_for_stream({xmlstreamerror, _}, StateData) ->
- send_text(StateData#state.socket,
- ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_text(StateData,
+ ?STREAM_HEADER("") ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_stream(timeout, StateData) ->
@@ -133,6 +162,45 @@ wait_for_stream(timeout, StateData) ->
wait_for_stream(closed, StateData) ->
{stop, normal, StateData}.
+
+wait_for_feature_request({xmlstreamelement, El}, StateData) ->
+ {xmlelement, Name, Attrs, Els} = El,
+ TLS = StateData#state.tls,
+ TLSEnabled = StateData#state.tls_enabled,
+ SockMod = StateData#state.sockmod,
+ case {xml:get_attr_s("xmlns", Attrs), Name} of
+ {?NS_TLS, "starttls"} when TLS == true,
+ TLSEnabled == false,
+ SockMod == gen_tcp ->
+ ?INFO_MSG("starttls", []),
+ Socket = StateData#state.socket,
+ TLSOpts = StateData#state.tls_options,
+ {ok, TLSSocket} = tls:tcp_to_tls(Socket, TLSOpts),
+ ejabberd_receiver:starttls(StateData#state.receiver, TLSSocket),
+ send_element(StateData,
+ {xmlelement, "proceed", [{"xmlns", ?NS_TLS}], []}),
+ {next_state, wait_for_stream,
+ StateData#state{sockmod = tls,
+ socket = TLSSocket,
+ streamid = new_id(),
+ tls_enabled = true
+ }};
+ _ ->
+ stream_established({xmlstreamelement, El}, StateData)
+ end;
+
+wait_for_feature_request({xmlstreamend, _Name}, StateData) ->
+ send_text(StateData, ?STREAM_TRAILER),
+ {stop, normal, StateData};
+
+wait_for_feature_request({xmlstreamerror, _}, StateData) ->
+ send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ {stop, normal, StateData};
+
+wait_for_feature_request(closed, StateData) ->
+ {stop, normal, StateData}.
+
+
stream_established({xmlstreamelement, El}, StateData) ->
cancel_timer(StateData#state.timer),
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
@@ -154,7 +222,7 @@ stream_established({xmlstreamelement, El}, StateData) ->
StateData#state{connections = Conns,
timer = Timer}};
_ ->
- send_text(StateData#state.socket, ?HOST_UNKNOWN_ERR),
+ send_text(StateData, ?HOST_UNKNOWN_ERR),
{stop, normal, StateData}
end;
{verify, To, From, Id, Key} ->
@@ -165,7 +233,7 @@ stream_established({xmlstreamelement, El}, StateData) ->
Type = if Key == Key1 -> "valid";
true -> "invalid"
end,
- send_element(StateData#state.socket,
+ send_element(StateData,
{xmlelement,
"db:verify",
[{"from", To},
@@ -204,7 +272,7 @@ stream_established({xmlstreamelement, El}, StateData) ->
end;
stream_established({valid, From, To}, StateData) ->
- send_element(StateData#state.socket,
+ send_element(StateData,
{xmlelement,
"db:result",
[{"from", To},
@@ -219,7 +287,7 @@ stream_established({valid, From, To}, StateData) ->
{next_state, stream_established, NSD};
stream_established({invalid, From, To}, StateData) ->
- send_element(StateData#state.socket,
+ send_element(StateData,
{xmlelement,
"db:result",
[{"from", To},
@@ -237,8 +305,8 @@ stream_established({xmlstreamend, _Name}, StateData) ->
{stop, normal, StateData};
stream_established({xmlstreamerror, _}, StateData) ->
- send_text(StateData#state.socket,
- ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_text(StateData,
+ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
stream_established(timeout, StateData) ->
@@ -294,7 +362,7 @@ code_change(_OldVsn, StateName, StateData, _Extra) ->
%% {stop, Reason, NewStateData}
%%----------------------------------------------------------------------
handle_info({send_text, Text}, StateName, StateData) ->
- send_text(StateData#state.socket, Text),
+ send_text(StateData, Text),
{next_state, StateName, StateData};
handle_info({timeout, Timer, _}, StateName,
@@ -312,18 +380,18 @@ handle_info(_, StateName, StateData) ->
%%----------------------------------------------------------------------
terminate(Reason, _StateName, StateData) ->
?INFO_MSG("terminated: ~p", [Reason]),
- gen_tcp:close(StateData#state.socket),
+ (StateData#state.sockmod):close(StateData#state.socket),
ok.
%%%----------------------------------------------------------------------
%%% Internal functions
%%%----------------------------------------------------------------------
-send_text(Socket, Text) ->
- gen_tcp:send(Socket,Text).
+send_text(StateData, Text) ->
+ (StateData#state.sockmod):send(StateData#state.socket, Text).
-send_element(Socket, El) ->
- send_text(Socket, xml:element_to_string(El)).
+send_element(StateData, El) ->
+ send_text(StateData, xml:element_to_string(El)).
change_shaper(StateData, Host, JID) ->
diff --git a/src/ejabberd_s2s_out.erl b/src/ejabberd_s2s_out.erl
index aec98ba6..0b3d46f1 100644
--- a/src/ejabberd_s2s_out.erl
+++ b/src/ejabberd_s2s_out.erl
@@ -13,13 +13,15 @@
-behaviour(gen_fsm).
%% External exports
--export([start/3, start_link/3, send_text/2, send_element/2]).
+-export([start/3, start_link/3]).
%% gen_fsm callbacks
-export([init/1,
open_socket/2,
wait_for_stream/2,
wait_for_validation/2,
+ wait_for_features/2,
+ wait_for_starttls_proceed/2,
stream_established/2,
handle_event/3,
handle_sync_event/4,
@@ -30,8 +32,15 @@
-include("ejabberd.hrl").
-include("jlib.hrl").
--record(state, {socket, receiver, streamid,
- myname, server, xmlpid, queue,
+-record(state, {socket, receiver,
+ sockmod,
+ streamid,
+ use_v10,
+ tls = false,
+ tls_required = false,
+ tls_enabled = false,
+ tls_options = [],
+ myname, server, queue,
new = false, verify = false,
timer}).
@@ -49,7 +58,7 @@
"xmlns:stream='http://etherx.jabber.org/streams' "
"xmlns='jabber:server' "
"xmlns:db='jabber:server:dialback' "
- "to='~s'>"
+ "to='~s'~s>"
).
-define(STREAM_TRAILER, "</stream:stream>").
@@ -86,6 +95,19 @@ start_link(From, Host, Type) ->
init([From, Server, Type]) ->
?INFO_MSG("started: ~p", [{From, Server, Type}]),
gen_fsm:send_event(self(), init),
+ TLS = case ejabberd_config:get_local_option(s2s_use_starttls) of
+ undefined ->
+ false;
+ UseStartTLS ->
+ UseStartTLS
+ end,
+ UseV10 = TLS,
+ TLSOpts = case ejabberd_config:get_local_option(s2s_certfile) of
+ undefined ->
+ [];
+ CertFile ->
+ [{certfile, CertFile}, connect]
+ end,
{New, Verify} = case Type of
{new, Key} ->
{Key, false};
@@ -93,7 +115,10 @@ init([From, Server, Type]) ->
{false, {Pid, Key, SID}}
end,
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
- {ok, open_socket, #state{queue = queue:new(),
+ {ok, open_socket, #state{use_v10 = UseV10,
+ tls = TLS,
+ tls_options = TLSOpts,
+ queue = queue:new(),
myname = From,
server = Server,
new = New,
@@ -113,23 +138,34 @@ open_socket(init, StateData) ->
ASCIIAddr ->
?DEBUG("s2s_out: connecting to ~s:~p~n", [ASCIIAddr, Port]),
case gen_tcp:connect(ASCIIAddr, Port,
- [binary, {packet, 0}]) of
+ [binary, {packet, 0},
+ {active, false}]) of
{ok, _Socket} = R -> R;
{error, Reason1} ->
?DEBUG("s2s_out: connect return ~p~n", [Reason1]),
catch gen_tcp:connect(Addr, Port,
- [binary, {packet, 0}, inet6])
+ [binary, {packet, 0},
+ {active, false}, inet6])
end
end,
case Res of
{ok, Socket} ->
- XMLStreamPid = xml_stream:start(self()),
- send_text(Socket, io_lib:format(?STREAM_HEADER,
- [StateData#state.server])),
- {next_state, wait_for_stream,
- StateData#state{socket = Socket,
- xmlpid = XMLStreamPid,
- streamid = new_id()}};
+ ReceiverPid = ejabberd_receiver:start(Socket, gen_tcp, none),
+ Version = if
+ StateData#state.use_v10 ->
+ " version='1.0'";
+ true ->
+ ""
+ end,
+ NewStateData = StateData#state{socket = Socket,
+ sockmod = gen_tcp,
+ tls_enabled = false,
+ receiver = ReceiverPid,
+ streamid = new_id()},
+ send_text(NewStateData, io_lib:format(?STREAM_HEADER,
+ [StateData#state.server,
+ Version])),
+ {next_state, wait_for_stream, NewStateData};
{error, Reason} ->
?DEBUG("s2s_out: inet6 connect return ~p~n", [Reason]),
Error = ?ERR_REMOTE_SERVER_NOT_FOUND,
@@ -140,58 +176,36 @@ open_socket(init, StateData) ->
Error = ?ERR_REMOTE_SERVER_NOT_FOUND,
bounce_messages(Error),
{stop, normal, StateData}
- end.
+ end;
+open_socket(_, StateData) ->
+ {next_state, open_socket, StateData}.
wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
- % TODO
- case {xml:get_attr_s("xmlns", Attrs), xml:get_attr_s("xmlns:db", Attrs)} of
- {"jabber:server", "jabber:server:dialback"} ->
- Server = StateData#state.server,
- New = case StateData#state.new of
- false ->
- case ejabberd_s2s:try_register(
- {StateData#state.myname, Server}) of
- {key, Key} ->
- Key;
- false ->
- false
- end;
- Key ->
- Key
- end,
- case New of
- false ->
- ok;
- Key1 ->
- send_element(StateData#state.socket,
- {xmlelement,
- "db:result",
- [{"from", StateData#state.myname},
- {"to", Server}],
- [{xmlcdata, Key1}]})
- end,
- case StateData#state.verify of
- false ->
- ok;
- {Pid, Key2, SID} ->
- send_element(StateData#state.socket,
- {xmlelement,
- "db:verify",
- [{"from", StateData#state.myname},
- {"to", StateData#state.server},
- {"id", SID}],
- [{xmlcdata, Key2}]})
- end,
- {next_state, wait_for_validation, StateData#state{new = New}};
+ case {xml:get_attr_s("xmlns", Attrs),
+ xml:get_attr_s("xmlns:db", Attrs),
+ xml:get_attr_s("version", Attrs) == "1.0"} of
+ {"jabber:server", "jabber:server:dialback", false} ->
+ send_db_request(StateData);
+ {"jabber:server", "jabber:server:dialback", true} when
+ StateData#state.use_v10 ->
+ {next_state, wait_for_features, StateData};
+ {"jabber:server", "", true} when StateData#state.use_v10 ->
+ ?INFO_MSG("restarted: ~p", [{StateData#state.myname,
+ StateData#state.server}]),
+ % TODO: clear message queue
+ (StateData#state.sockmod):close(StateData#state.socket),
+ gen_fsm:send_event(self(), init),
+ {next_state, open_socket, StateData#state{socket = undefined,
+ use_v10 = false}};
_ ->
- send_text(StateData#state.socket, ?INVALID_NAMESPACE_ERR),
+ send_text(StateData, ?INVALID_NAMESPACE_ERR),
{stop, normal, StateData}
end;
wait_for_stream({xmlstreamerror, _}, StateData) ->
- send_text(StateData#state.socket,
- ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_text(StateData,
+ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_stream(timeout, StateData) ->
@@ -208,7 +222,7 @@ wait_for_validation({xmlstreamelement, El}, StateData) ->
?INFO_MSG("recv result: ~p", [{From, To, Id, Type}]),
case Type of
"valid" ->
- send_queue(StateData#state.socket, StateData#state.queue),
+ send_queue(StateData, StateData#state.queue),
{next_state, stream_established,
StateData#state{queue = queue:new()}};
_ ->
@@ -248,8 +262,8 @@ wait_for_validation({xmlstreamend, Name}, StateData) ->
{stop, normal, StateData};
wait_for_validation({xmlstreamerror, _}, StateData) ->
- send_text(StateData#state.socket,
- ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_text(StateData,
+ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
wait_for_validation(timeout, StateData) ->
@@ -259,6 +273,111 @@ wait_for_validation(closed, StateData) ->
{stop, normal, StateData}.
+wait_for_features({xmlstreamelement, El}, StateData) ->
+ case El of
+ {xmlelement, "stream:features", _Attrs, Els} ->
+ {StartTLS, StartTLSRequired} =
+ lists:foldl(
+ fun({xmlelement, "starttls", Attrs1, Els1} = El1, Acc) ->
+ case xml:get_attr_s("xmlns", Attrs1) of
+ ?NS_TLS ->
+ Req = case xml:get_subtag(El1, "required") of
+ {xmlelement, _, _, _} -> true;
+ false -> false
+ end,
+ {true, Req};
+ _ ->
+ Acc
+ end;
+ (_, Acc) ->
+ Acc
+ end, {false, false}, Els),
+ if
+ StartTLS and StateData#state.tls and
+ (not StateData#state.tls_enabled) ->
+ StateData#state.receiver ! {change_timeout, 100},
+ send_element(StateData,
+ {xmlelement, "starttls",
+ [{"xmlns", ?NS_TLS}], []}),
+ {next_state, wait_for_starttls_proceed, StateData};
+ StartTLSRequired and (not StateData#state.tls) ->
+ ?INFO_MSG("restarted: ~p", [{StateData#state.myname,
+ StateData#state.server}]),
+ (StateData#state.sockmod):close(StateData#state.socket),
+ gen_fsm:send_event(self(), init),
+ {next_state, open_socket,
+ StateData#state{socket = undefined,
+ use_v10 = false}};
+ true ->
+ send_db_request(StateData)
+ end;
+ _ ->
+ send_text(StateData,
+ xml:element_to_string(?SERR_BAD_FORMAT) ++
+ ?STREAM_TRAILER),
+ {stop, normal, StateData}
+ end;
+
+wait_for_features({xmlstreamend, Name}, StateData) ->
+ {stop, normal, StateData};
+
+wait_for_features({xmlstreamerror, _}, StateData) ->
+ send_text(StateData,
+ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ {stop, normal, StateData};
+
+wait_for_features(timeout, StateData) ->
+ {stop, normal, StateData};
+
+wait_for_features(closed, StateData) ->
+ {stop, normal, StateData}.
+
+
+wait_for_starttls_proceed({xmlstreamelement, El}, StateData) ->
+ case El of
+ {xmlelement, "proceed", Attrs, _Els} ->
+ case xml:get_attr_s("xmlns", Attrs) of
+ ?NS_TLS ->
+ ?INFO_MSG("starttls: ~p", [{StateData#state.myname,
+ StateData#state.server}]),
+ Socket = StateData#state.socket,
+ TLSOpts = StateData#state.tls_options,
+ {ok, TLSSocket} = tls:tcp_to_tls(Socket, TLSOpts),
+ ejabberd_receiver:starttls(
+ StateData#state.receiver, TLSSocket),
+ StateData#state.receiver ! {change_timeout, infinity},
+ NewStateData = StateData#state{sockmod = tls,
+ socket = TLSSocket,
+ streamid = new_id(),
+ tls_enabled = true
+ },
+ R = send_text(NewStateData,
+ io_lib:format(?STREAM_HEADER,
+ [StateData#state.server,
+ " version='1.0'"])),
+ {next_state, wait_for_stream, NewStateData};
+ _ ->
+ {stop, normal, StateData}
+ end;
+ _ ->
+ {stop, normal, StateData}
+ end;
+
+wait_for_starttls_proceed({xmlstreamend, Name}, StateData) ->
+ {stop, normal, StateData};
+
+wait_for_starttls_proceed({xmlstreamerror, _}, StateData) ->
+ send_text(StateData,
+ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ {stop, normal, StateData};
+
+wait_for_starttls_proceed(timeout, StateData) ->
+ {stop, normal, StateData};
+
+wait_for_starttls_proceed(closed, StateData) ->
+ {stop, normal, StateData}.
+
+
stream_established({xmlstreamelement, El}, StateData) ->
?INFO_MSG("stream established", []),
case is_verify_res(El) of
@@ -290,8 +409,8 @@ stream_established({xmlstreamend, Name}, StateData) ->
{stop, normal, StateData};
stream_established({xmlstreamerror, _}, StateData) ->
- send_text(StateData#state.socket,
- ?STREAM_HEADER ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_text(StateData,
+ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
stream_established(timeout, StateData) ->
@@ -347,7 +466,7 @@ code_change(OldVsn, StateName, StateData, Extra) ->
%% {stop, Reason, NewStateData}
%%----------------------------------------------------------------------
handle_info({send_text, Text}, StateName, StateData) ->
- send_text(StateData#state.socket, Text),
+ send_text(StateData, Text),
cancel_timer(StateData#state.timer),
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
{next_state, StateName, StateData#state{timer = Timer}};
@@ -357,7 +476,7 @@ handle_info({send_element, El}, StateName, StateData) ->
Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
case StateName of
stream_established ->
- send_element(StateData#state.socket, El),
+ send_element(StateData, El),
{next_state, StateName, StateData#state{timer = Timer}};
_ ->
Q = queue:in(El, StateData#state.queue),
@@ -365,17 +484,17 @@ handle_info({send_element, El}, StateName, StateData) ->
timer = Timer}}
end;
-handle_info({tcp, Socket, Data}, StateName, StateData) ->
- xml_stream:send_text(StateData#state.xmlpid, Data),
- {next_state, StateName, StateData};
-
-handle_info({tcp_closed, Socket}, StateName, StateData) ->
- gen_fsm:send_event(self(), closed),
- {next_state, StateName, StateData};
-
-handle_info({tcp_error, Socket, Reason}, StateName, StateData) ->
- gen_fsm:send_event(self(), closed),
- {next_state, StateName, StateData};
+%handle_info({tcp, Socket, Data}, StateName, StateData) ->
+% xml_stream:send_text(StateData#state.xmlpid, Data),
+% {next_state, StateName, StateData};
+%
+%handle_info({tcp_closed, Socket}, StateName, StateData) ->
+% gen_fsm:send_event(self(), closed),
+% {next_state, StateName, StateData};
+%
+%handle_info({tcp_error, Socket, Reason}, StateName, StateData) ->
+% gen_fsm:send_event(self(), closed),
+% {next_state, StateName, StateData};
handle_info({timeout, Timer, _}, StateName,
#state{timer = Timer} = StateData) ->
@@ -404,8 +523,7 @@ terminate(Reason, StateName, StateData) ->
undefined ->
ok;
Socket ->
- gen_tcp:close(Socket),
- exit(StateData#state.xmlpid, closed)
+ (StateData#state.sockmod):close(Socket)
end,
ok.
@@ -413,17 +531,17 @@ terminate(Reason, StateName, StateData) ->
%%% Internal functions
%%%----------------------------------------------------------------------
-send_text(Socket, Text) ->
- gen_tcp:send(Socket,Text).
+send_text(StateData, Text) ->
+ (StateData#state.sockmod):send(StateData#state.socket, Text).
-send_element(Socket, El) ->
- send_text(Socket, xml:element_to_string(El)).
+send_element(StateData, El) ->
+ send_text(StateData, xml:element_to_string(El)).
-send_queue(Socket, Q) ->
+send_queue(StateData, Q) ->
case queue:out(Q) of
{{value, El}, Q1} ->
- send_element(Socket, El),
- send_queue(Socket, Q1);
+ send_element(StateData, El),
+ send_queue(StateData, Q1);
{empty, Q1} ->
ok
end.
@@ -470,20 +588,46 @@ bounce_messages(Error) ->
ok
end.
-%is_key_packet({xmlelement, Name, Attrs, Els}) when Name == "db:result" ->
-% {key,
-% xml:get_attr_s("to", Attrs),
-% xml:get_attr_s("from", Attrs),
-% xml:get_attr_s("id", Attrs),
-% xml:get_cdata(Els)};
-%is_key_packet({xmlelement, Name, Attrs, Els}) when Name == "db:verify" ->
-% {verify,
-% xml:get_attr_s("to", Attrs),
-% xml:get_attr_s("from", Attrs),
-% xml:get_attr_s("id", Attrs),
-% xml:get_cdata(Els)};
-%is_key_packet(_) ->
-% false.
+
+send_db_request(StateData) ->
+ Server = StateData#state.server,
+ New = case StateData#state.new of
+ false ->
+ case ejabberd_s2s:try_register(
+ {StateData#state.myname, Server}) of
+ {key, Key} ->
+ Key;
+ false ->
+ false
+ end;
+ Key ->
+ Key
+ end,
+ case New of
+ false ->
+ ok;
+ Key1 ->
+ send_element(StateData,
+ {xmlelement,
+ "db:result",
+ [{"from", StateData#state.myname},
+ {"to", Server}],
+ [{xmlcdata, Key1}]})
+ end,
+ case StateData#state.verify of
+ false ->
+ ok;
+ {Pid, Key2, SID} ->
+ send_element(StateData,
+ {xmlelement,
+ "db:verify",
+ [{"from", StateData#state.myname},
+ {"to", StateData#state.server},
+ {"id", SID}],
+ [{xmlcdata, Key2}]})
+ end,
+ {next_state, wait_for_validation, StateData#state{new = New}}.
+
is_verify_res({xmlelement, Name, Attrs, Els}) when Name == "db:result" ->
{result,
diff --git a/src/tls/tls.erl b/src/tls/tls.erl
index 361c92fc..e1925520 100644
--- a/src/tls/tls.erl
+++ b/src/tls/tls.erl
@@ -27,11 +27,12 @@
code_change/3,
terminate/2]).
--define(SET_CERTIFICATE_FILE, 1).
--define(SET_ENCRYPTED_INPUT, 2).
--define(SET_DECRYPTED_OUTPUT, 3).
--define(GET_ENCRYPTED_OUTPUT, 4).
--define(GET_DECRYPTED_INPUT, 5).
+-define(SET_CERTIFICATE_FILE_ACCEPT, 1).
+-define(SET_CERTIFICATE_FILE_CONNECT, 2).
+-define(SET_ENCRYPTED_INPUT, 3).
+-define(SET_DECRYPTED_OUTPUT, 4).
+-define(GET_ENCRYPTED_OUTPUT, 5).
+-define(GET_DECRYPTED_INPUT, 6).
-record(tlssock, {tcpsock, tlsport}).
@@ -44,7 +45,7 @@ start_link() ->
init([]) ->
ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv),
Port = open_port({spawn, tls_drv}, [binary]),
- Res = port_control(Port, ?SET_CERTIFICATE_FILE, "./ssl.pem" ++ [0]),
+ Res = port_control(Port, ?SET_CERTIFICATE_FILE_ACCEPT, "./ssl.pem" ++ [0]),
case Res of
<<0>> ->
%ets:new(iconv_table, [set, public, named_table]),
@@ -86,8 +87,13 @@ tcp_to_tls(TCPSocket, Options) ->
{value, {certfile, CertFile}} ->
ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv),
Port = open_port({spawn, tls_drv}, [binary]),
- case port_control(Port, ?SET_CERTIFICATE_FILE,
- CertFile ++ [0]) of
+ Command = case lists:member(connect, Options) of
+ true ->
+ ?SET_CERTIFICATE_FILE_CONNECT;
+ false ->
+ ?SET_CERTIFICATE_FILE_ACCEPT
+ end,
+ case port_control(Port, Command, CertFile ++ [0]) of
<<0>> ->
{ok, #tlssock{tcpsock = TCPSocket, tlsport = Port}};
<<1, Error/binary>> ->
@@ -145,7 +151,10 @@ send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) ->
{error, binary_to_list(Error)}
end;
<<1, Error/binary>> ->
- {error, binary_to_list(Error)}
+ {error, binary_to_list(Error)};
+ <<2>> -> % Dirty hack
+ receive after 100 -> ok end,
+ send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet)
end.
@@ -158,7 +167,8 @@ test() ->
ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv),
Port = open_port({spawn, tls_drv}, [binary]),
io:format("open_port: ~p~n", [Port]),
- PCRes = port_control(Port, ?SET_CERTIFICATE_FILE, "./ssl.pem" ++ [0]),
+ PCRes = port_control(Port, ?SET_CERTIFICATE_FILE_ACCEPT,
+ "./ssl.pem" ++ [0]),
io:format("port_control: ~p~n", [PCRes]),
{ok, ListenSocket} = gen_tcp:listen(1234, [binary,
{packet, 0},
diff --git a/src/tls/tls_drv.c b/src/tls/tls_drv.c
index f320ee31..608830ff 100644
--- a/src/tls/tls_drv.c
+++ b/src/tls/tls_drv.c
@@ -4,6 +4,7 @@
#include <string.h>
#include <erl_driver.h>
#include <openssl/ssl.h>
+#include <openssl/err.h>
#define BUF_SIZE 1024
@@ -45,11 +46,12 @@ static void tls_drv_stop(ErlDrvData handle)
}
-#define SET_CERTIFICATE_FILE 1
-#define SET_ENCRYPTED_INPUT 2
-#define SET_DECRYPTED_OUTPUT 3
-#define GET_ENCRYPTED_OUTPUT 4
-#define GET_DECRYPTED_INPUT 5
+#define SET_CERTIFICATE_FILE_ACCEPT 1
+#define SET_CERTIFICATE_FILE_CONNECT 2
+#define SET_ENCRYPTED_INPUT 3
+#define SET_DECRYPTED_OUTPUT 4
+#define GET_ENCRYPTED_OUTPUT 5
+#define GET_DECRYPTED_INPUT 6
#define die_unless(cond, errstr) \
@@ -76,8 +78,9 @@ static int tls_drv_control(ErlDrvData handle,
switch (command)
{
- case SET_CERTIFICATE_FILE:
- d->ctx = SSL_CTX_new(SSLv23_server_method());
+ case SET_CERTIFICATE_FILE_ACCEPT:
+ case SET_CERTIFICATE_FILE_CONNECT:
+ d->ctx = SSL_CTX_new(SSLv23_method());
die_unless(d->ctx, "SSL_CTX_new failed");
res = SSL_CTX_use_certificate_file(d->ctx, buf, SSL_FILETYPE_PEM);
@@ -97,7 +100,10 @@ static int tls_drv_control(ErlDrvData handle,
SSL_set_bio(d->ssl, d->bio_read, d->bio_write);
- SSL_set_accept_state(d->ssl);
+ if (command == SET_CERTIFICATE_FILE_ACCEPT)
+ SSL_set_accept_state(d->ssl);
+ else
+ SSL_set_connect_state(d->ssl);
break;
case SET_ENCRYPTED_INPUT:
die_unless(d->ssl, "SSL not initialized");
@@ -106,6 +112,19 @@ static int tls_drv_control(ErlDrvData handle,
case SET_DECRYPTED_OUTPUT:
die_unless(d->ssl, "SSL not initialized");
res = SSL_write(d->ssl, buf, len);
+ if (res <= 0)
+ {
+ res = SSL_get_error(d->ssl, res);
+ if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE)
+ {
+ b = driver_alloc_binary(1);
+ b->orig_bytes[0] = 2;
+ *rbuf = (char *)b;
+ return 1;
+ } else {
+ die_unless(0, "SSL_write failed");
+ }
+ }
break;
case GET_ENCRYPTED_OUTPUT:
die_unless(d->ssl, "SSL not initialized");
@@ -128,13 +147,10 @@ static int tls_drv_control(ErlDrvData handle,
case GET_DECRYPTED_INPUT:
if (!SSL_is_init_finished(d->ssl))
{
- //printf("Doing SSL_accept\r\n");
- res = SSL_accept(d->ssl);
- //if (res == 0)
- // printf("SSL_accept returned zero\r\n");
- if (res < 0)
+ res = SSL_do_handshake(d->ssl);
+ if (res <= 0)
die_unless(SSL_get_error(d->ssl, res) == SSL_ERROR_WANT_READ,
- "SSL_accept failed");
+ "SSL_do_handshake failed");
} else {
size = BUF_SIZE + 1;
rlen = 1;