diff options
Diffstat (limited to 'src/ejabberd_c2s.erl')
-rw-r--r-- | src/ejabberd_c2s.erl | 192 |
1 files changed, 118 insertions, 74 deletions
diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index 88d26b1d..e9f53b6e 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -67,6 +67,7 @@ -record(state, {socket, sockmod, socket_monitor, + xml_socket, streamid, sasl_state, access, @@ -124,16 +125,12 @@ -define(STREAM_TRAILER, "</stream:stream>"). --define(INVALID_NS_ERR, - xml:element_to_string(?SERR_INVALID_NAMESPACE)). --define(INVALID_XML_ERR, - xml:element_to_string(?SERR_XML_NOT_WELL_FORMED)). --define(HOST_UNKNOWN_ERR, - xml:element_to_string(?SERR_HOST_UNKNOWN)). +-define(INVALID_NS_ERR, ?SERR_INVALID_NAMESPACE). +-define(INVALID_XML_ERR, ?SERR_XML_NOT_WELL_FORMED). +-define(HOST_UNKNOWN_ERR, ?SERR_HOST_UNKNOWN). -define(POLICY_VIOLATION_ERR(Lang, Text), - xml:element_to_string(?SERRT_POLICY_VIOLATION(Lang, Text))). --define(INVALID_FROM, - xml:element_to_string(?SERR_INVALID_FROM)). + ?SERRT_POLICY_VIOLATION(Lang, Text)). +-define(INVALID_FROM, ?SERR_INVALID_FROM). %%%---------------------------------------------------------------------- @@ -175,6 +172,11 @@ init([{SockMod, Socket}, Opts]) -> {value, {_, S}} -> S; _ -> none end, + XMLSocket = + case lists:keysearch(xml_socket, 1, Opts) of + {value, {_, XS}} -> XS; + _ -> false + end, Zlib = lists:member(zlib, Opts), StartTLS = lists:member(starttls, Opts), StartTLSRequired = lists:member(starttls_required, Opts), @@ -205,6 +207,7 @@ init([{SockMod, Socket}, Opts]) -> {ok, wait_for_stream, #state{socket = Socket1, sockmod = SockMod, socket_monitor = SocketMonitor, + xml_socket = XMLSocket, zlib = Zlib, tls = TLS, tls_required = StartTLSRequired, @@ -231,9 +234,9 @@ get_subscribed(FsmRef) -> wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> DefaultLang = case ?MYLANG of undefined -> - " xml:lang='en'"; + "en"; DL -> - " xml:lang='" ++ DL ++ "'" + DL end, case xml:get_attr_s("xmlns:stream", Attrs) of ?NS_STREAM -> @@ -244,12 +247,7 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> change_shaper(StateData, jlib:make_jid("", Server, "")), case xml:get_attr_s("version", Attrs) of "1.0" -> - Header = io_lib:format(?STREAM_HEADER, - [StateData#state.streamid, - Server, - " version='1.0'", - DefaultLang]), - send_text(StateData, Header), + send_header(StateData, Server, "1.0", DefaultLang), case StateData#state.authenticated of false -> SASLState = @@ -351,22 +349,18 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> end end; _ -> - Header = io_lib:format( - ?STREAM_HEADER, - [StateData#state.streamid, Server, "", - DefaultLang]), + send_header(StateData, Server, "", DefaultLang), if (not StateData#state.tls_enabled) and StateData#state.tls_required -> - send_text(StateData, - Header ++ - ?POLICY_VIOLATION_ERR( - Lang, - "Use of STARTTLS required") ++ - ?STREAM_TRAILER), + send_element( + StateData, + ?POLICY_VIOLATION_ERR( + Lang, + "Use of STARTTLS required")), + send_trailer(StateData), {stop, normal, StateData}; true -> - send_text(StateData, Header), fsm_next_state(wait_for_auth, StateData#state{ server = Server, @@ -374,20 +368,15 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> end end; _ -> - Header = io_lib:format( - ?STREAM_HEADER, - [StateData#state.streamid, ?MYNAME, "", - DefaultLang]), - send_text(StateData, - Header ++ ?HOST_UNKNOWN_ERR ++ ?STREAM_TRAILER), + send_header(StateData, ?MYNAME, "", DefaultLang), + send_element(StateData, ?HOST_UNKNOWN_ERR), + send_trailer(StateData), {stop, normal, StateData} end; _ -> - Header = io_lib:format( - ?STREAM_HEADER, - [StateData#state.streamid, ?MYNAME, "", DefaultLang]), - send_text(StateData, - Header ++ ?INVALID_NS_ERR ++ ?STREAM_TRAILER), + send_header(StateData, ?MYNAME, "", DefaultLang), + send_element(StateData, ?INVALID_NS_ERR), + send_trailer(StateData), {stop, normal, StateData} end; @@ -395,18 +384,19 @@ wait_for_stream(timeout, StateData) -> {stop, normal, StateData}; wait_for_stream({xmlstreamelement, _}, StateData) -> - send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_element(StateData, ?INVALID_XML_ERR), + send_trailer(StateData), {stop, normal, StateData}; wait_for_stream({xmlstreamend, _}, StateData) -> - send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_element(StateData, ?INVALID_XML_ERR), + send_trailer(StateData), {stop, normal, StateData}; wait_for_stream({xmlstreamerror, _}, StateData) -> - Header = io_lib:format(?STREAM_HEADER, - ["none", ?MYNAME, " version='1.0'", ""]), - send_text(StateData, - Header ++ ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_header(StateData, ?MYNAME, "1.0", ""), + send_element(StateData, ?INVALID_XML_ERR), + send_trailer(StateData), {stop, normal, StateData}; wait_for_stream(closed, StateData) -> @@ -538,11 +528,12 @@ wait_for_auth(timeout, StateData) -> {stop, normal, StateData}; wait_for_auth({xmlstreamend, _Name}, StateData) -> - send_text(StateData, ?STREAM_TRAILER), + send_trailer(StateData), {stop, normal, StateData}; wait_for_auth({xmlstreamerror, _}, StateData) -> - send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_element(StateData, ?INVALID_XML_ERR), + send_trailer(StateData), {stop, normal, StateData}; wait_for_auth(closed, StateData) -> @@ -665,10 +656,10 @@ wait_for_feature_request({xmlstreamelement, El}, StateData) -> if (SockMod == gen_tcp) and TLSRequired -> Lang = StateData#state.lang, - send_text(StateData, ?POLICY_VIOLATION_ERR( - Lang, - "Use of STARTTLS required") ++ - ?STREAM_TRAILER), + send_element(StateData, ?POLICY_VIOLATION_ERR( + Lang, + "Use of STARTTLS required")), + send_trailer(StateData), {stop, normal, StateData}; true -> process_unauthenticated_stanza(StateData, El), @@ -680,11 +671,12 @@ wait_for_feature_request(timeout, StateData) -> {stop, normal, StateData}; wait_for_feature_request({xmlstreamend, _Name}, StateData) -> - send_text(StateData, ?STREAM_TRAILER), + send_trailer(StateData), {stop, normal, StateData}; wait_for_feature_request({xmlstreamerror, _}, StateData) -> - send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_element(StateData, ?INVALID_XML_ERR), + send_trailer(StateData), {stop, normal, StateData}; wait_for_feature_request(closed, StateData) -> @@ -748,11 +740,12 @@ wait_for_sasl_response(timeout, StateData) -> {stop, normal, StateData}; wait_for_sasl_response({xmlstreamend, _Name}, StateData) -> - send_text(StateData, ?STREAM_TRAILER), + send_trailer(StateData), {stop, normal, StateData}; wait_for_sasl_response({xmlstreamerror, _}, StateData) -> - send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_element(StateData, ?INVALID_XML_ERR), + send_trailer(StateData), {stop, normal, StateData}; wait_for_sasl_response(closed, StateData) -> @@ -797,11 +790,12 @@ wait_for_bind(timeout, StateData) -> {stop, normal, StateData}; wait_for_bind({xmlstreamend, _Name}, StateData) -> - send_text(StateData, ?STREAM_TRAILER), + send_trailer(StateData), {stop, normal, StateData}; wait_for_bind({xmlstreamerror, _}, StateData) -> - send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_element(StateData, ?INVALID_XML_ERR), + send_trailer(StateData), {stop, normal, StateData}; wait_for_bind(closed, StateData) -> @@ -868,11 +862,12 @@ wait_for_session(timeout, StateData) -> {stop, normal, StateData}; wait_for_session({xmlstreamend, _Name}, StateData) -> - send_text(StateData, ?STREAM_TRAILER), + send_trailer(StateData), {stop, normal, StateData}; wait_for_session({xmlstreamerror, _}, StateData) -> - send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_element(StateData, ?INVALID_XML_ERR), + send_trailer(StateData), {stop, normal, StateData}; wait_for_session(closed, StateData) -> @@ -884,7 +879,8 @@ session_established({xmlstreamelement, El}, StateData) -> % Check 'from' attribute in stanza RFC 3920 Section 9.1.2 case check_from(El, FromJID) of 'invalid-from' -> - send_text(StateData, ?INVALID_FROM ++ ?STREAM_TRAILER), + send_element(StateData, ?INVALID_FROM), + send_trailer(StateData), {stop, normal, StateData}; _NewEl -> session_established2(El, StateData) @@ -900,16 +896,17 @@ session_established(timeout, StateData) -> fsm_next_state(session_established, StateData); session_established({xmlstreamend, _Name}, StateData) -> - send_text(StateData, ?STREAM_TRAILER), + send_trailer(StateData), {stop, normal, StateData}; session_established({xmlstreamerror, "XML stanza is too big" = E}, StateData) -> - Text = ?POLICY_VIOLATION_ERR(StateData#state.lang, E) ++ ?STREAM_TRAILER, - send_text(StateData, Text), + send_element(StateData, ?POLICY_VIOLATION_ERR(StateData#state.lang, E)), + send_trailer(StateData), {stop, normal, StateData}; session_established({xmlstreamerror, _}, StateData) -> - send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER), + send_element(StateData, ?INVALID_XML_ERR), + send_trailer(StateData), {stop, normal, StateData}; session_established(closed, StateData) -> @@ -1070,10 +1067,9 @@ handle_info({send_text, Text}, StateName, StateData) -> fsm_next_state(StateName, StateData); handle_info(replaced, _StateName, StateData) -> Lang = StateData#state.lang, - send_text(StateData, - xml:element_to_string( - ?SERRT_CONFLICT(Lang, "Replaced by new connection")) - ++ ?STREAM_TRAILER), + send_element(StateData, + ?SERRT_CONFLICT(Lang, "Replaced by new connection")), + send_trailer(StateData), {stop, normal, StateData#state{authenticated = replaced}}; %% Process Packets that are to be send to the user handle_info({route, From, To, Packet}, StateName, StateData) -> @@ -1273,18 +1269,15 @@ handle_info({route, From, To, Packet}, StateName, StateData) -> Pass == exit -> %% When Pass==exit, NewState contains a string instead of a #state{} Lang = StateData#state.lang, - catch send_text(StateData, - xml:element_to_string( - ?SERRT_CONFLICT(Lang, NewState)) - ++ ?STREAM_TRAILER), + send_element(StateData, ?SERRT_CONFLICT(Lang, NewState)), + send_trailer(StateData), {stop, normal, StateData}; Pass -> Attrs2 = jlib:replace_from_to_attrs(jlib:jid_to_string(From), jlib:jid_to_string(To), NewAttrs), FixedPacket = {xmlelement, Name, Attrs2, Els}, - Text = xml:element_to_string(FixedPacket), - send_text(StateData, Text), + send_element(StateData, FixedPacket), ejabberd_hooks:run(user_receive_packet, StateData#state.server, [StateData#state.jid, From, To, FixedPacket]), @@ -1379,9 +1372,60 @@ send_text(StateData, Text) -> ?DEBUG("Send XML on stream = ~p", [lists:flatten(Text)]), (StateData#state.sockmod):send(StateData#state.socket, Text). +send_element(StateData, El) when StateData#state.xml_socket -> + (StateData#state.sockmod):send_xml(StateData#state.socket, + {xmlstreamelement, El}); send_element(StateData, El) -> send_text(StateData, xml:element_to_string(El)). +send_header(StateData, Server, Version, Lang) + when StateData#state.xml_socket -> + VersionAttr = + case Version of + "" -> []; + _ -> [{"version", Version}] + end, + LangAttr = + case Lang of + "" -> []; + _ -> [{"xml:lang", Lang}] + end, + Header = + {xmlstreamstart, + "stream:stream", + VersionAttr ++ + LangAttr ++ + [{"xmlns", "jabber:client"}, + {"xmlns:stream", "http://etherx.jabber.org/streams"}, + {"id", StateData#state.streamid}, + {"from", Server}]}, + (StateData#state.sockmod):send_xml( + StateData#state.socket, Header); +send_header(StateData, Server, Version, Lang) -> + VersionStr = + case Version of + "" -> ""; + _ -> [" version='", Version, "'"] + end, + LangStr = + case Lang of + "" -> ""; + _ -> [" xml:lang='", Lang, "'"] + end, + Header = io_lib:format(?STREAM_HEADER, + [StateData#state.streamid, + Server, + VersionStr, + LangStr]), + send_text(StateData, Header). + +send_trailer(StateData) when StateData#state.xml_socket -> + (StateData#state.sockmod):send_xml( + StateData#state.socket, + {xmlstreamend, "stream:stream"}); +send_trailer(StateData) -> + send_text(StateData, ?STREAM_TRAILER). + new_id() -> randoms:get_string(). |