summaryrefslogtreecommitdiff
path: root/src/ejabberd_c2s.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/ejabberd_c2s.erl')
-rw-r--r--src/ejabberd_c2s.erl192
1 files changed, 118 insertions, 74 deletions
diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl
index 88d26b1d..e9f53b6e 100644
--- a/src/ejabberd_c2s.erl
+++ b/src/ejabberd_c2s.erl
@@ -67,6 +67,7 @@
-record(state, {socket,
sockmod,
socket_monitor,
+ xml_socket,
streamid,
sasl_state,
access,
@@ -124,16 +125,12 @@
-define(STREAM_TRAILER, "</stream:stream>").
--define(INVALID_NS_ERR,
- xml:element_to_string(?SERR_INVALID_NAMESPACE)).
--define(INVALID_XML_ERR,
- xml:element_to_string(?SERR_XML_NOT_WELL_FORMED)).
--define(HOST_UNKNOWN_ERR,
- xml:element_to_string(?SERR_HOST_UNKNOWN)).
+-define(INVALID_NS_ERR, ?SERR_INVALID_NAMESPACE).
+-define(INVALID_XML_ERR, ?SERR_XML_NOT_WELL_FORMED).
+-define(HOST_UNKNOWN_ERR, ?SERR_HOST_UNKNOWN).
-define(POLICY_VIOLATION_ERR(Lang, Text),
- xml:element_to_string(?SERRT_POLICY_VIOLATION(Lang, Text))).
--define(INVALID_FROM,
- xml:element_to_string(?SERR_INVALID_FROM)).
+ ?SERRT_POLICY_VIOLATION(Lang, Text)).
+-define(INVALID_FROM, ?SERR_INVALID_FROM).
%%%----------------------------------------------------------------------
@@ -175,6 +172,11 @@ init([{SockMod, Socket}, Opts]) ->
{value, {_, S}} -> S;
_ -> none
end,
+ XMLSocket =
+ case lists:keysearch(xml_socket, 1, Opts) of
+ {value, {_, XS}} -> XS;
+ _ -> false
+ end,
Zlib = lists:member(zlib, Opts),
StartTLS = lists:member(starttls, Opts),
StartTLSRequired = lists:member(starttls_required, Opts),
@@ -205,6 +207,7 @@ init([{SockMod, Socket}, Opts]) ->
{ok, wait_for_stream, #state{socket = Socket1,
sockmod = SockMod,
socket_monitor = SocketMonitor,
+ xml_socket = XMLSocket,
zlib = Zlib,
tls = TLS,
tls_required = StartTLSRequired,
@@ -231,9 +234,9 @@ get_subscribed(FsmRef) ->
wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
DefaultLang = case ?MYLANG of
undefined ->
- " xml:lang='en'";
+ "en";
DL ->
- " xml:lang='" ++ DL ++ "'"
+ DL
end,
case xml:get_attr_s("xmlns:stream", Attrs) of
?NS_STREAM ->
@@ -244,12 +247,7 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
change_shaper(StateData, jlib:make_jid("", Server, "")),
case xml:get_attr_s("version", Attrs) of
"1.0" ->
- Header = io_lib:format(?STREAM_HEADER,
- [StateData#state.streamid,
- Server,
- " version='1.0'",
- DefaultLang]),
- send_text(StateData, Header),
+ send_header(StateData, Server, "1.0", DefaultLang),
case StateData#state.authenticated of
false ->
SASLState =
@@ -351,22 +349,18 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
end
end;
_ ->
- Header = io_lib:format(
- ?STREAM_HEADER,
- [StateData#state.streamid, Server, "",
- DefaultLang]),
+ send_header(StateData, Server, "", DefaultLang),
if
(not StateData#state.tls_enabled) and
StateData#state.tls_required ->
- send_text(StateData,
- Header ++
- ?POLICY_VIOLATION_ERR(
- Lang,
- "Use of STARTTLS required") ++
- ?STREAM_TRAILER),
+ send_element(
+ StateData,
+ ?POLICY_VIOLATION_ERR(
+ Lang,
+ "Use of STARTTLS required")),
+ send_trailer(StateData),
{stop, normal, StateData};
true ->
- send_text(StateData, Header),
fsm_next_state(wait_for_auth,
StateData#state{
server = Server,
@@ -374,20 +368,15 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
end
end;
_ ->
- Header = io_lib:format(
- ?STREAM_HEADER,
- [StateData#state.streamid, ?MYNAME, "",
- DefaultLang]),
- send_text(StateData,
- Header ++ ?HOST_UNKNOWN_ERR ++ ?STREAM_TRAILER),
+ send_header(StateData, ?MYNAME, "", DefaultLang),
+ send_element(StateData, ?HOST_UNKNOWN_ERR),
+ send_trailer(StateData),
{stop, normal, StateData}
end;
_ ->
- Header = io_lib:format(
- ?STREAM_HEADER,
- [StateData#state.streamid, ?MYNAME, "", DefaultLang]),
- send_text(StateData,
- Header ++ ?INVALID_NS_ERR ++ ?STREAM_TRAILER),
+ send_header(StateData, ?MYNAME, "", DefaultLang),
+ send_element(StateData, ?INVALID_NS_ERR),
+ send_trailer(StateData),
{stop, normal, StateData}
end;
@@ -395,18 +384,19 @@ wait_for_stream(timeout, StateData) ->
{stop, normal, StateData};
wait_for_stream({xmlstreamelement, _}, StateData) ->
- send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_element(StateData, ?INVALID_XML_ERR),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_stream({xmlstreamend, _}, StateData) ->
- send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_element(StateData, ?INVALID_XML_ERR),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_stream({xmlstreamerror, _}, StateData) ->
- Header = io_lib:format(?STREAM_HEADER,
- ["none", ?MYNAME, " version='1.0'", ""]),
- send_text(StateData,
- Header ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_header(StateData, ?MYNAME, "1.0", ""),
+ send_element(StateData, ?INVALID_XML_ERR),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_stream(closed, StateData) ->
@@ -538,11 +528,12 @@ wait_for_auth(timeout, StateData) ->
{stop, normal, StateData};
wait_for_auth({xmlstreamend, _Name}, StateData) ->
- send_text(StateData, ?STREAM_TRAILER),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_auth({xmlstreamerror, _}, StateData) ->
- send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_element(StateData, ?INVALID_XML_ERR),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_auth(closed, StateData) ->
@@ -665,10 +656,10 @@ wait_for_feature_request({xmlstreamelement, El}, StateData) ->
if
(SockMod == gen_tcp) and TLSRequired ->
Lang = StateData#state.lang,
- send_text(StateData, ?POLICY_VIOLATION_ERR(
- Lang,
- "Use of STARTTLS required") ++
- ?STREAM_TRAILER),
+ send_element(StateData, ?POLICY_VIOLATION_ERR(
+ Lang,
+ "Use of STARTTLS required")),
+ send_trailer(StateData),
{stop, normal, StateData};
true ->
process_unauthenticated_stanza(StateData, El),
@@ -680,11 +671,12 @@ wait_for_feature_request(timeout, StateData) ->
{stop, normal, StateData};
wait_for_feature_request({xmlstreamend, _Name}, StateData) ->
- send_text(StateData, ?STREAM_TRAILER),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_feature_request({xmlstreamerror, _}, StateData) ->
- send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_element(StateData, ?INVALID_XML_ERR),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_feature_request(closed, StateData) ->
@@ -748,11 +740,12 @@ wait_for_sasl_response(timeout, StateData) ->
{stop, normal, StateData};
wait_for_sasl_response({xmlstreamend, _Name}, StateData) ->
- send_text(StateData, ?STREAM_TRAILER),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_sasl_response({xmlstreamerror, _}, StateData) ->
- send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_element(StateData, ?INVALID_XML_ERR),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_sasl_response(closed, StateData) ->
@@ -797,11 +790,12 @@ wait_for_bind(timeout, StateData) ->
{stop, normal, StateData};
wait_for_bind({xmlstreamend, _Name}, StateData) ->
- send_text(StateData, ?STREAM_TRAILER),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_bind({xmlstreamerror, _}, StateData) ->
- send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_element(StateData, ?INVALID_XML_ERR),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_bind(closed, StateData) ->
@@ -868,11 +862,12 @@ wait_for_session(timeout, StateData) ->
{stop, normal, StateData};
wait_for_session({xmlstreamend, _Name}, StateData) ->
- send_text(StateData, ?STREAM_TRAILER),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_session({xmlstreamerror, _}, StateData) ->
- send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_element(StateData, ?INVALID_XML_ERR),
+ send_trailer(StateData),
{stop, normal, StateData};
wait_for_session(closed, StateData) ->
@@ -884,7 +879,8 @@ session_established({xmlstreamelement, El}, StateData) ->
% Check 'from' attribute in stanza RFC 3920 Section 9.1.2
case check_from(El, FromJID) of
'invalid-from' ->
- send_text(StateData, ?INVALID_FROM ++ ?STREAM_TRAILER),
+ send_element(StateData, ?INVALID_FROM),
+ send_trailer(StateData),
{stop, normal, StateData};
_NewEl ->
session_established2(El, StateData)
@@ -900,16 +896,17 @@ session_established(timeout, StateData) ->
fsm_next_state(session_established, StateData);
session_established({xmlstreamend, _Name}, StateData) ->
- send_text(StateData, ?STREAM_TRAILER),
+ send_trailer(StateData),
{stop, normal, StateData};
session_established({xmlstreamerror, "XML stanza is too big" = E}, StateData) ->
- Text = ?POLICY_VIOLATION_ERR(StateData#state.lang, E) ++ ?STREAM_TRAILER,
- send_text(StateData, Text),
+ send_element(StateData, ?POLICY_VIOLATION_ERR(StateData#state.lang, E)),
+ send_trailer(StateData),
{stop, normal, StateData};
session_established({xmlstreamerror, _}, StateData) ->
- send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
+ send_element(StateData, ?INVALID_XML_ERR),
+ send_trailer(StateData),
{stop, normal, StateData};
session_established(closed, StateData) ->
@@ -1070,10 +1067,9 @@ handle_info({send_text, Text}, StateName, StateData) ->
fsm_next_state(StateName, StateData);
handle_info(replaced, _StateName, StateData) ->
Lang = StateData#state.lang,
- send_text(StateData,
- xml:element_to_string(
- ?SERRT_CONFLICT(Lang, "Replaced by new connection"))
- ++ ?STREAM_TRAILER),
+ send_element(StateData,
+ ?SERRT_CONFLICT(Lang, "Replaced by new connection")),
+ send_trailer(StateData),
{stop, normal, StateData#state{authenticated = replaced}};
%% Process Packets that are to be send to the user
handle_info({route, From, To, Packet}, StateName, StateData) ->
@@ -1273,18 +1269,15 @@ handle_info({route, From, To, Packet}, StateName, StateData) ->
Pass == exit ->
%% When Pass==exit, NewState contains a string instead of a #state{}
Lang = StateData#state.lang,
- catch send_text(StateData,
- xml:element_to_string(
- ?SERRT_CONFLICT(Lang, NewState))
- ++ ?STREAM_TRAILER),
+ send_element(StateData, ?SERRT_CONFLICT(Lang, NewState)),
+ send_trailer(StateData),
{stop, normal, StateData};
Pass ->
Attrs2 = jlib:replace_from_to_attrs(jlib:jid_to_string(From),
jlib:jid_to_string(To),
NewAttrs),
FixedPacket = {xmlelement, Name, Attrs2, Els},
- Text = xml:element_to_string(FixedPacket),
- send_text(StateData, Text),
+ send_element(StateData, FixedPacket),
ejabberd_hooks:run(user_receive_packet,
StateData#state.server,
[StateData#state.jid, From, To, FixedPacket]),
@@ -1379,9 +1372,60 @@ send_text(StateData, Text) ->
?DEBUG("Send XML on stream = ~p", [lists:flatten(Text)]),
(StateData#state.sockmod):send(StateData#state.socket, Text).
+send_element(StateData, El) when StateData#state.xml_socket ->
+ (StateData#state.sockmod):send_xml(StateData#state.socket,
+ {xmlstreamelement, El});
send_element(StateData, El) ->
send_text(StateData, xml:element_to_string(El)).
+send_header(StateData, Server, Version, Lang)
+ when StateData#state.xml_socket ->
+ VersionAttr =
+ case Version of
+ "" -> [];
+ _ -> [{"version", Version}]
+ end,
+ LangAttr =
+ case Lang of
+ "" -> [];
+ _ -> [{"xml:lang", Lang}]
+ end,
+ Header =
+ {xmlstreamstart,
+ "stream:stream",
+ VersionAttr ++
+ LangAttr ++
+ [{"xmlns", "jabber:client"},
+ {"xmlns:stream", "http://etherx.jabber.org/streams"},
+ {"id", StateData#state.streamid},
+ {"from", Server}]},
+ (StateData#state.sockmod):send_xml(
+ StateData#state.socket, Header);
+send_header(StateData, Server, Version, Lang) ->
+ VersionStr =
+ case Version of
+ "" -> "";
+ _ -> [" version='", Version, "'"]
+ end,
+ LangStr =
+ case Lang of
+ "" -> "";
+ _ -> [" xml:lang='", Lang, "'"]
+ end,
+ Header = io_lib:format(?STREAM_HEADER,
+ [StateData#state.streamid,
+ Server,
+ VersionStr,
+ LangStr]),
+ send_text(StateData, Header).
+
+send_trailer(StateData) when StateData#state.xml_socket ->
+ (StateData#state.sockmod):send_xml(
+ StateData#state.socket,
+ {xmlstreamend, "stream:stream"});
+send_trailer(StateData) ->
+ send_text(StateData, ?STREAM_TRAILER).
+
new_id() ->
randoms:get_string().