diff options
Diffstat (limited to 'src')
52 files changed, 3934 insertions, 1909 deletions
diff --git a/src/cyrsasl.erl b/src/cyrsasl.erl index f404a7afb..db2160ca9 100644 --- a/src/cyrsasl.erl +++ b/src/cyrsasl.erl @@ -93,9 +93,15 @@ start() -> ). register_mechanism(Mechanism, Module, PasswordType) -> - ets:insert(sasl_mechanism, - #sasl_mechanism{mechanism = Mechanism, module = Module, - password_type = PasswordType}). + case is_disabled(Mechanism) of + false -> + ets:insert(sasl_mechanism, + #sasl_mechanism{mechanism = Mechanism, module = Module, + password_type = PasswordType}); + true -> + ?DEBUG("SASL mechanism ~p is disabled", [Mechanism]), + true + end. %%% TODO: use callbacks %%-include("ejabberd.hrl"). @@ -215,3 +221,19 @@ filter_anonymous(Host, Mechs) -> true -> Mechs; false -> Mechs -- [<<"ANONYMOUS">>] end. + +-spec(is_disabled/1 :: +( + Mechanism :: mechanism()) + -> boolean() +). + +is_disabled(Mechanism) -> + Disabled = ejabberd_config:get_option( + disable_sasl_mechanisms, + fun(V) when is_list(V) -> + lists:map(fun(M) -> str:to_upper(M) end, V); + (V) -> + [str:to_upper(V)] + end, []), + lists:member(Mechanism, Disabled). diff --git a/src/ejabberd_app.erl b/src/ejabberd_app.erl index 8106b7b0d..379f728d6 100644 --- a/src/ejabberd_app.erl +++ b/src/ejabberd_app.erl @@ -57,6 +57,7 @@ start(normal, _Args) -> connect_nodes(), Sup = ejabberd_sup:start_link(), ejabberd_rdbms:start(), + ejabberd_riak_sup:start(), ejabberd_auth:start(), cyrsasl:start(), % Profiling @@ -107,6 +108,18 @@ loop() -> end. db_init() -> + MyNode = node(), + DbNodes = mnesia:system_info(db_nodes), + case lists:member(MyNode, DbNodes) of + true -> + ok; + false -> + ?CRITICAL_MSG("Node name mismatch: I'm [~s], " + "the database is owned by ~p", [MyNode, DbNodes]), + ?CRITICAL_MSG("Either set ERLANG_NODE in ejabberdctl.cfg " + "or change node name in Mnesia", []), + erlang:error(node_name_mismatch) + end, case mnesia:system_info(extra_db_nodes) of [] -> mnesia:create_schema([node()]); diff --git a/src/ejabberd_auth.erl b/src/ejabberd_auth.erl index 477926f6e..f716bbb35 100644 --- a/src/ejabberd_auth.erl +++ b/src/ejabberd_auth.erl @@ -445,5 +445,7 @@ import(Server) -> import(Server, mnesia, Passwd) -> ejabberd_auth_internal:import(Server, mnesia, Passwd); +import(Server, riak, Passwd) -> + ejabberd_auth_riak:import(Server, riak, Passwd); import(_, _, _) -> pass. diff --git a/src/ejabberd_auth_ldap.erl b/src/ejabberd_auth_ldap.erl index 77937d010..7eba6ef32 100644 --- a/src/ejabberd_auth_ldap.erl +++ b/src/ejabberd_auth_ldap.erl @@ -387,7 +387,7 @@ parse_options(Host) -> [{<<"%u">>, <<"*">>}]), {DNFilter, DNFilterAttrs} = eldap_utils:get_opt({ldap_dn_filter, Host}, [], - fun({DNF, DNFA}) -> + fun([{DNF, DNFA}]) -> NewDNFA = case DNFA of undefined -> []; diff --git a/src/ejabberd_auth_riak.erl b/src/ejabberd_auth_riak.erl new file mode 100644 index 000000000..e5d901cfc --- /dev/null +++ b/src/ejabberd_auth_riak.erl @@ -0,0 +1,296 @@ +%%%---------------------------------------------------------------------- +%%% File : ejabberd_auth_riak.erl +%%% Author : Evgeniy Khramtsov <ekhramtsov@process-one.net> +%%% Purpose : Authentification via Riak +%%% Created : 12 Nov 2012 by Evgeniy Khramtsov <ekhramtsov@process-one.net> +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2012 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License +%%% along with this program; if not, write to the Free Software +%%% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +%%% 02111-1307 USA +%%% +%%%---------------------------------------------------------------------- + +-module(ejabberd_auth_riak). + +-author('alexey@process-one.net'). + +-behaviour(ejabberd_auth). + +%% External exports +-export([start/1, set_password/3, check_password/3, + check_password/5, try_register/3, + dirty_get_registered_users/0, get_vh_registered_users/1, + get_vh_registered_users/2, + get_vh_registered_users_number/1, + get_vh_registered_users_number/2, get_password/2, + get_password_s/2, is_user_exists/2, remove_user/2, + remove_user/3, store_type/0, export/1, import/3, + plain_password_required/0]). +-export([passwd_schema/0]). + +-include("ejabberd.hrl"). + +-record(passwd, {us = {<<"">>, <<"">>} :: {binary(), binary()} | '$1', + password = <<"">> :: binary() | scram() | '_'}). + +-define(SALT_LENGTH, 16). + +start(_Host) -> + ok. + +plain_password_required() -> + case is_scrammed() of + false -> false; + true -> true + end. + +store_type() -> + case is_scrammed() of + false -> plain; %% allows: PLAIN DIGEST-MD5 SCRAM + true -> scram %% allows: PLAIN SCRAM + end. + +passwd_schema() -> + {record_info(fields, passwd), #passwd{}}. + +check_password(User, Server, Password) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + case ejabberd_riak:get(passwd, passwd_schema(), {LUser, LServer}) of + {ok, #passwd{password = Password}} when is_binary(Password) -> + Password /= <<"">>; + {ok, #passwd{password = Scram}} when is_record(Scram, scram) -> + is_password_scram_valid(Password, Scram); + _ -> + false + end. + +check_password(User, Server, Password, Digest, + DigestGen) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + case ejabberd_riak:get(passwd, passwd_schema(), {LUser, LServer}) of + {ok, #passwd{password = Passwd}} when is_binary(Passwd) -> + DigRes = if Digest /= <<"">> -> + Digest == DigestGen(Passwd); + true -> false + end, + if DigRes -> true; + true -> (Passwd == Password) and (Password /= <<"">>) + end; + {ok, #passwd{password = Scram}} + when is_record(Scram, scram) -> + Passwd = jlib:decode_base64(Scram#scram.storedkey), + DigRes = if Digest /= <<"">> -> + Digest == DigestGen(Passwd); + true -> false + end, + if DigRes -> true; + true -> (Passwd == Password) and (Password /= <<"">>) + end; + _ -> false + end. + +set_password(User, Server, Password) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + US = {LUser, LServer}, + if (LUser == error) or (LServer == error) -> + {error, invalid_jid}; + true -> + Password2 = case is_scrammed() and is_binary(Password) + of + true -> password_to_scram(Password); + false -> Password + end, + ok = ejabberd_riak:put(#passwd{us = US, password = Password2}, + passwd_schema(), + [{'2i', [{<<"host">>, LServer}]}]) + end. + +try_register(User, Server, PasswordList) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + Password = iolist_to_binary(PasswordList), + US = {LUser, LServer}, + if (LUser == error) or (LServer == error) -> + {error, invalid_jid}; + true -> + case ejabberd_riak:get(passwd, passwd_schema(), US) of + {error, notfound} -> + Password2 = case is_scrammed() and + is_binary(Password) + of + true -> password_to_scram(Password); + false -> Password + end, + {atomic, ejabberd_riak:put( + #passwd{us = US, + password = Password2}, + passwd_schema(), + [{'2i', [{<<"host">>, LServer}]}])}; + {ok, _} -> + exists; + Err -> + {atomic, Err} + end + end. + +dirty_get_registered_users() -> + lists:flatmap( + fun(Server) -> + get_vh_registered_users(Server) + end, ejabberd_config:get_vh_by_auth_method(riak)). + +get_vh_registered_users(Server) -> + LServer = jlib:nameprep(Server), + case ejabberd_riak:get_keys_by_index(passwd, <<"host">>, LServer) of + {ok, Users} -> + Users; + _ -> + [] + end. + +get_vh_registered_users(Server, _) -> + get_vh_registered_users(Server). + +get_vh_registered_users_number(Server) -> + LServer = jlib:nameprep(Server), + case ejabberd_riak:count_by_index(passwd, <<"host">>, LServer) of + {ok, N} -> + N; + _ -> + 0 + end. + +get_vh_registered_users_number(Server, _) -> + get_vh_registered_users_number(Server). + +get_password(User, Server) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + case ejabberd_riak:get(passwd, passwd_schema(), {LUser, LServer}) of + {ok, #passwd{password = Password}} + when is_binary(Password) -> + Password; + {ok, #passwd{password = Scram}} + when is_record(Scram, scram) -> + {jlib:decode_base64(Scram#scram.storedkey), + jlib:decode_base64(Scram#scram.serverkey), + jlib:decode_base64(Scram#scram.salt), + Scram#scram.iterationcount}; + _ -> false + end. + +get_password_s(User, Server) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + case ejabberd_riak:get(passwd, passwd_schema(), {LUser, LServer}) of + {ok, #passwd{password = Password}} + when is_binary(Password) -> + Password; + {ok, #passwd{password = Scram}} + when is_record(Scram, scram) -> + <<"">>; + _ -> <<"">> + end. + +is_user_exists(User, Server) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + case ejabberd_riak:get(passwd, passwd_schema(), {LUser, LServer}) of + {error, notfound} -> false; + {ok, _} -> true; + Err -> Err + end. + +remove_user(User, Server) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + ejabberd_riak:delete(passwd, {LUser, LServer}), + ok. + +remove_user(User, Server, Password) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + case ejabberd_riak:get(passwd, passwd_schema(), {LUser, LServer}) of + {ok, #passwd{password = Password}} + when is_binary(Password) -> + ejabberd_riak:delete(passwd, {LUser, LServer}), + ok; + {ok, #passwd{password = Scram}} + when is_record(Scram, scram) -> + case is_password_scram_valid(Password, Scram) of + true -> + ejabberd_riak:delete(passwd, {LUser, LServer}), + ok; + false -> not_allowed + end; + _ -> not_exists + end. + +%%% +%%% SCRAM +%%% + +is_scrammed() -> + scram == + ejabberd_config:get_local_option({auth_password_format, ?MYNAME}, + fun(V) -> V end). + +password_to_scram(Password) -> + password_to_scram(Password, + ?SCRAM_DEFAULT_ITERATION_COUNT). + +password_to_scram(Password, IterationCount) -> + Salt = crypto:rand_bytes(?SALT_LENGTH), + SaltedPassword = scram:salted_password(Password, Salt, + IterationCount), + StoredKey = + scram:stored_key(scram:client_key(SaltedPassword)), + ServerKey = scram:server_key(SaltedPassword), + #scram{storedkey = jlib:encode_base64(StoredKey), + serverkey = jlib:encode_base64(ServerKey), + salt = jlib:encode_base64(Salt), + iterationcount = IterationCount}. + +is_password_scram_valid(Password, Scram) -> + IterationCount = Scram#scram.iterationcount, + Salt = jlib:decode_base64(Scram#scram.salt), + SaltedPassword = scram:salted_password(Password, Salt, + IterationCount), + StoredKey = + scram:stored_key(scram:client_key(SaltedPassword)), + jlib:decode_base64(Scram#scram.storedkey) == StoredKey. + +export(_Server) -> + [{passwd, + fun(Host, #passwd{us = {LUser, LServer}, password = Password}) + when LServer == Host -> + Username = ejabberd_odbc:escape(LUser), + Pass = ejabberd_odbc:escape(Password), + [[<<"delete from users where username='">>, Username, <<"';">>], + [<<"insert into users(username, password) " + "values ('">>, Username, <<"', '">>, Pass, <<"');">>]]; + (_Host, _R) -> + [] + end}]. + +import(LServer, riak, #passwd{} = Passwd) -> + ejabberd_riak:put(Passwd, passwd_schema(), [{'2i', [{<<"host">>, LServer}]}]); +import(_, _, _) -> + pass. diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index 5d0cc9c08..c0b042ec6 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -45,6 +45,7 @@ set_aux_field/3, del_aux_field/2, get_subscription/2, + send_filtered/5, broadcast/4, get_subscribed/1, transform_listen_option/2]). @@ -94,20 +95,20 @@ tls_options = [], authenticated = false, jid, - user = "", server = <<"">>, resource = <<"">>, + user = <<"">>, server = <<"">>, resource = <<"">>, sid, pres_t = ?SETS:new(), pres_f = ?SETS:new(), pres_a = ?SETS:new(), - pres_i = ?SETS:new(), pres_last, pres_pri, pres_timestamp, - pres_invis = false, privacy_list = #userlist{}, conn = unknown, auth_module = unknown, ip, aux_fields = [], + csi_state = active, + csi_queue = [], mgmt_state, mgmt_xmlns, mgmt_queue, @@ -247,6 +248,9 @@ get_subscription(LFrom, StateData) -> true -> none end. +send_filtered(FsmRef, Feature, From, To, Packet) -> + FsmRef ! {send_filtered, Feature, From, To, Packet}. + broadcast(FsmRef, Type, From, Packet) -> FsmRef ! {broadcast, Type, From, Packet}. @@ -307,41 +311,37 @@ init([{SockMod, Socket}, Opts]) -> end, MaxAckQueue = case proplists:get_value(max_ack_queue, Opts) of Limit when is_integer(Limit), Limit > 0 -> Limit; + infinity -> infinity; _ -> 500 end, ResumeTimeout = case proplists:get_value(resume_timeout, Opts) of Timeout when is_integer(Timeout), Timeout >= 0 -> Timeout; _ -> 300 end, - ResendOnTimeout = proplists:get_bool(resend_on_timeout, Opts), + ResendOnTimeout = case proplists:get_value(resend_on_timeout, Opts) of + Resend when is_boolean(Resend) -> Resend; + if_offline -> if_offline; + _ -> false + end, IP = peerip(SockMod, Socket), - %% Check if IP is blacklisted: - case is_ip_blacklisted(IP) of - true -> - ?INFO_MSG("Connection attempt from blacklisted " - "IP: ~s (~w)", - [jlib:ip_to_list(IP), IP]), - {stop, normal}; - false -> - Socket1 = if TLSEnabled andalso - SockMod /= ejabberd_frontend_socket -> - SockMod:starttls(Socket, TLSOpts); - true -> Socket - end, - SocketMonitor = SockMod:monitor(Socket1), - StateData = #state{socket = Socket1, sockmod = SockMod, - socket_monitor = SocketMonitor, - xml_socket = XMLSocket, zlib = Zlib, tls = TLS, - tls_required = StartTLSRequired, - tls_enabled = TLSEnabled, tls_options = TLSOpts, - sid = {now(), self()}, streamid = new_id(), - access = Access, shaper = Shaper, ip = IP, - mgmt_state = StreamMgmtState, - mgmt_max_queue = MaxAckQueue, - mgmt_timeout = ResumeTimeout, - mgmt_resend = ResendOnTimeout}, - {ok, wait_for_stream, StateData, ?C2S_OPEN_TIMEOUT} - end. + Socket1 = if TLSEnabled andalso + SockMod /= ejabberd_frontend_socket -> + SockMod:starttls(Socket, TLSOpts); + true -> Socket + end, + SocketMonitor = SockMod:monitor(Socket1), + StateData = #state{socket = Socket1, sockmod = SockMod, + socket_monitor = SocketMonitor, + xml_socket = XMLSocket, zlib = Zlib, tls = TLS, + tls_required = StartTLSRequired, + tls_enabled = TLSEnabled, tls_options = TLSOpts, + sid = {now(), self()}, streamid = new_id(), + access = Access, shaper = Shaper, ip = IP, + mgmt_state = StreamMgmtState, + mgmt_max_queue = MaxAckQueue, + mgmt_timeout = ResumeTimeout, + mgmt_resend = ResendOnTimeout}, + {ok, wait_for_stream, StateData, ?C2S_OPEN_TIMEOUT}. %% Return list of all available resources of contacts, get_subscribed(FsmRef) -> @@ -365,27 +365,31 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> jlib:nameprep(xml:get_attr_s(<<"to">>, Attrs)); S -> S end, + Lang = case xml:get_attr_s(<<"xml:lang">>, Attrs) of + Lang1 when byte_size(Lang1) =< 35 -> + %% As stated in BCP47, 4.4.1: + %% Protocols or specifications that + %% specify limited buffer sizes for + %% language tags MUST allow for + %% language tags of at least 35 characters. + Lang1; + _ -> + %% Do not store long language tag to + %% avoid possible DoS/flood attacks + <<"">> + end, + IsBlacklistedIP = is_ip_blacklisted(StateData#state.ip, Lang), case lists:member(Server, ?MYHOSTS) of - true -> - Lang = case xml:get_attr_s(<<"xml:lang">>, Attrs) of - Lang1 when size(Lang1) =< 35 -> - %% As stated in BCP47, 4.4.1: - %% Protocols or specifications that - %% specify limited buffer sizes for - %% language tags MUST allow for - %% language tags of at least 35 characters. - Lang1; - _ -> - %% Do not store long language tag to - %% avoid possible DoS/flood attacks - <<"">> - end, + true when IsBlacklistedIP == false -> change_shaper(StateData, jlib:make_jid(<<"">>, Server, <<"">>)), case xml:get_attr_s(<<"version">>, Attrs) of <<"1.0">> -> send_header(StateData, Server, <<"1.0">>, DefaultLang), case StateData#state.authenticated of false -> + TLS = StateData#state.tls, + TLSEnabled = StateData#state.tls_enabled, + TLSRequired = StateData#state.tls_required, SASLState = cyrsasl:server_new( <<"jabber">>, Server, <<"">>, [], @@ -401,12 +405,21 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> ejabberd_auth:check_password_with_authmodule( U, Server, P, D, DG) end), - Mechs = lists:map(fun (S) -> - #xmlel{name = <<"mechanism">>, - attrs = [], - children = [{xmlcdata, S}]} - end, - cyrsasl:listmech(Server)), + Mechs = + case TLSEnabled or not TLSRequired of + true -> + Ms = lists:map(fun (S) -> + #xmlel{name = <<"mechanism">>, + attrs = [], + children = [{xmlcdata, S}]} + end, + cyrsasl:listmech(Server)), + [#xmlel{name = <<"mechanisms">>, + attrs = [{<<"xmlns">>, ?NS_SASL}], + children = Ms}]; + false -> + [] + end, SockMod = (StateData#state.sockmod):get_sockmod( StateData#state.socket), @@ -424,9 +437,6 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> _ -> [] end, - TLS = StateData#state.tls, - TLSEnabled = StateData#state.tls_enabled, - TLSRequired = StateData#state.tls_required, TLSFeature = case (TLS == true) andalso (TLSEnabled == false) andalso @@ -451,10 +461,7 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> #xmlel{name = <<"stream:features">>, attrs = [], children = - TLSFeature ++ CompressFeature ++ - [#xmlel{name = <<"mechanisms">>, - attrs = [{<<"xmlns">>, ?NS_SASL}], - children = Mechs}] + TLSFeature ++ CompressFeature ++ Mechs ++ ejabberd_hooks:run_fold(c2s_stream_features, Server, [], [Server])}), @@ -491,6 +498,8 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> ++ RosterVersioningFeature ++ StreamManagementFeature ++ + ejabberd_hooks:run_fold(c2s_post_auth_features, + Server, [], [Server]) ++ ejabberd_hooks:run_fold(c2s_stream_features, Server, [], [Server]), send_element(StateData, @@ -523,6 +532,15 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) -> lang = Lang}) end end; + true -> + IP = StateData#state.ip, + {true, LogReason, ReasonT} = IsBlacklistedIP, + ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s", + [jlib:ip_to_list(IP), LogReason]), + send_header(StateData, Server, <<"">>, DefaultLang), + send_element(StateData, ?POLICY_VIOLATION_ERR(Lang, ReasonT)), + send_trailer(StateData), + {stop, normal, StateData}; _ -> send_header(StateData, ?MYNAME, <<"">>, DefaultLang), send_element(StateData, ?HOST_UNKNOWN_ERR), @@ -622,9 +640,13 @@ wait_for_auth({xmlstreamelement, El}, StateData) -> P, D, DGen) of {true, AuthModule} -> - ?INFO_MSG("(~w) Accepted legacy authentication for ~s by ~p", - [StateData#state.socket, - jlib:jid_to_string(JID), AuthModule]), + ?INFO_MSG("(~w) Accepted legacy authentication for ~s by ~p from ~s", + [StateData#state.socket, + jlib:jid_to_string(JID), AuthModule, + jlib:ip_to_list(StateData#state.ip)]), + ejabberd_hooks:run(c2s_auth_result, StateData#state.server, + [true, U, StateData#state.server, + StateData#state.ip]), Conn = get_conn_type(StateData), Info = [{ip, StateData#state.ip}, {conn, Conn}, {auth_module, AuthModule}], @@ -659,12 +681,13 @@ wait_for_auth({xmlstreamelement, El}, StateData) -> privacy_list = PrivList}, fsm_next_state(session_established, NewStateData); _ -> - IP = peerip(StateData#state.sockmod, - StateData#state.socket), - ?INFO_MSG("(~w) Failed legacy authentication for " - "~s from IP ~s", - [StateData#state.socket, - jlib:jid_to_string(JID), jlib:ip_to_list(IP)]), + ?INFO_MSG("(~w) Failed legacy authentication for ~s from ~s", + [StateData#state.socket, + jlib:jid_to_string(JID), + jlib:ip_to_list(StateData#state.ip)]), + ejabberd_hooks:run(c2s_auth_result, StateData#state.server, + [false, U, StateData#state.server, + StateData#state.ip]), Err = jlib:make_error_reply(El, ?ERR_NOT_AUTHORIZED), send_element(StateData, Err), fsm_next_state(wait_for_auth, StateData) @@ -679,9 +702,13 @@ wait_for_auth({xmlstreamelement, El}, StateData) -> fsm_next_state(wait_for_auth, StateData); true -> ?INFO_MSG("(~w) Forbidden legacy authentication " - "for ~s", + "for ~s from ~s", [StateData#state.socket, - jlib:jid_to_string(JID)]), + jlib:jid_to_string(JID), + jlib:ip_to_list(StateData#state.ip)]), + ejabberd_hooks:run(c2s_auth_result, StateData#state.server, + [false, U, StateData#state.server, + StateData#state.ip]), Err = jlib:make_error_reply(El, ?ERR_NOT_ALLOWED), send_element(StateData, Err), fsm_next_state(wait_for_auth, StateData) @@ -718,7 +745,7 @@ wait_for_feature_request({xmlstreamelement, El}, (StateData#state.sockmod):get_sockmod(StateData#state.socket), case {xml:get_attr_s(<<"xmlns">>, Attrs), Name} of {?NS_SASL, <<"auth">>} - when not ((SockMod == gen_tcp) and TLSRequired) -> + when TLSEnabled or not TLSRequired -> Mech = xml:get_attr_s(<<"mechanism">>, Attrs), ClientIn = jlib:decode_base64(xml:get_cdata(Els)), case cyrsasl:server_start(StateData#state.sasl_state, @@ -731,8 +758,12 @@ wait_for_feature_request({xmlstreamelement, El}, %AuthModule = xml:get_attr_s(auth_module, Props), AuthModule = proplists:get_value(auth_module, Props, undefined), ?INFO_MSG("(~w) Accepted authentication for ~s " - "by ~p", - [StateData#state.socket, U, AuthModule]), + "by ~p from ~s", + [StateData#state.socket, U, AuthModule, + jlib:ip_to_list(StateData#state.ip)]), + ejabberd_hooks:run(c2s_auth_result, StateData#state.server, + [true, U, StateData#state.server, + StateData#state.ip]), send_element(StateData, #xmlel{name = <<"success">>, attrs = [{<<"xmlns">>, ?NS_SASL}], @@ -753,18 +784,20 @@ wait_for_feature_request({xmlstreamelement, El}, fsm_next_state(wait_for_sasl_response, StateData#state{sasl_state = NewSASLState}); {error, Error, Username} -> - IP = peerip(StateData#state.sockmod, StateData#state.socket), - ?INFO_MSG("(~w) Failed authentication for ~s@~s from IP ~s", - [StateData#state.socket, - Username, StateData#state.server, jlib:ip_to_list(IP)]), + ?INFO_MSG("(~w) Failed authentication for ~s@~s from ~s", + [StateData#state.socket, + Username, StateData#state.server, + jlib:ip_to_list(StateData#state.ip)]), + ejabberd_hooks:run(c2s_auth_result, StateData#state.server, + [false, Username, StateData#state.server, + StateData#state.ip]), send_element(StateData, #xmlel{name = <<"failure">>, attrs = [{<<"xmlns">>, ?NS_SASL}], children = [#xmlel{name = Error, attrs = [], children = []}]}), - {next_state, wait_for_feature_request, StateData, - ?C2S_OPEN_TIMEOUT}; + fsm_next_state(wait_for_feature_request, StateData); {error, Error} -> send_element(StateData, #xmlel{name = <<"failure">>, @@ -833,7 +866,7 @@ wait_for_feature_request({xmlstreamelement, El}, end end; _ -> - if (SockMod == gen_tcp) and TLSRequired -> + if TLSRequired and not TLSEnabled -> Lang = StateData#state.lang, send_element(StateData, ?POLICY_VIOLATION_ERR(Lang, @@ -878,8 +911,12 @@ wait_for_sasl_response({xmlstreamelement, El}, % AuthModule = xml:get_attr_s(auth_module, Props), AuthModule = proplists:get_value(auth_module, Props, <<>>), ?INFO_MSG("(~w) Accepted authentication for ~s " - "by ~p", - [StateData#state.socket, U, AuthModule]), + "by ~p from ~s", + [StateData#state.socket, U, AuthModule, + jlib:ip_to_list(StateData#state.ip)]), + ejabberd_hooks:run(c2s_auth_result, StateData#state.server, + [true, U, StateData#state.server, + StateData#state.ip]), send_element(StateData, #xmlel{name = <<"success">>, attrs = [{<<"xmlns">>, ?NS_SASL}], @@ -897,8 +934,12 @@ wait_for_sasl_response({xmlstreamelement, El}, % AuthModule = xml:get_attr_s(auth_module, Props), AuthModule = proplists:get_value(auth_module, Props, undefined), ?INFO_MSG("(~w) Accepted authentication for ~s " - "by ~p", - [StateData#state.socket, U, AuthModule]), + "by ~p from ~s", + [StateData#state.socket, U, AuthModule, + jlib:ip_to_list(StateData#state.ip)]), + ejabberd_hooks:run(c2s_auth_result, StateData#state.server, + [true, U, StateData#state.server, + StateData#state.ip]), send_element(StateData, #xmlel{name = <<"success">>, attrs = [{<<"xmlns">>, ?NS_SASL}], @@ -921,10 +962,13 @@ wait_for_sasl_response({xmlstreamelement, El}, fsm_next_state(wait_for_sasl_response, StateData#state{sasl_state = NewSASLState}); {error, Error, Username} -> - IP = peerip(StateData#state.sockmod, StateData#state.socket), - ?INFO_MSG("(~w) Failed authentication for ~s@~s from IP ~s", - [StateData#state.socket, - Username, StateData#state.server, jlib:ip_to_list(IP)]), + ?INFO_MSG("(~w) Failed authentication for ~s@~s from ~s", + [StateData#state.socket, + Username, StateData#state.server, + jlib:ip_to_list(StateData#state.ip)]), + ejabberd_hooks:run(c2s_auth_result, StateData#state.server, + [false, Username, StateData#state.server, + StateData#state.ip]), send_element(StateData, #xmlel{name = <<"failure">>, attrs = [{<<"xmlns">>, ?NS_SASL}], @@ -1137,6 +1181,17 @@ wait_for_session(closed, StateData) -> session_established({xmlstreamelement, #xmlel{name = Name} = El}, StateData) when ?IS_STREAM_MGMT_TAG(Name) -> fsm_next_state(session_established, dispatch_stream_mgmt(El, StateData)); +session_established({xmlstreamelement, + #xmlel{name = <<"active">>, + attrs = [{<<"xmlns">>, ?NS_CLIENT_STATE}]}}, + StateData) -> + NewStateData = csi_queue_flush(StateData), + fsm_next_state(session_established, NewStateData#state{csi_state = active}); +session_established({xmlstreamelement, + #xmlel{name = <<"inactive">>, + attrs = [{<<"xmlns">>, ?NS_CLIENT_STATE}]}}, + StateData) -> + fsm_next_state(session_established, StateData#state{csi_state = inactive}); session_established({xmlstreamelement, El}, StateData) -> FromJID = StateData#state.jid, @@ -1168,12 +1223,8 @@ session_established({xmlstreamerror, _}, StateData) -> send_element(StateData, ?INVALID_XML_ERR), send_trailer(StateData), {stop, normal, StateData}; -session_established(closed, StateData) - when StateData#state.mgmt_timeout > 0, - StateData#state.mgmt_state == active orelse - StateData#state.mgmt_state == pending -> - log_pending_state(StateData), - fsm_next_state(wait_for_resume, StateData#state{mgmt_state = pending}); +session_established(closed, #state{mgmt_state = active} = StateData) -> + fsm_next_state(wait_for_resume, StateData); session_established(closed, StateData) -> {stop, normal, StateData}. @@ -1312,8 +1363,8 @@ handle_sync_event(get_subscribed, _From, StateName, StateData) -> Subscribed = (?SETS):to_list(StateData#state.pres_f), {reply, Subscribed, StateName, StateData}; -handle_sync_event(resume_session, _From, _StateName, - StateData) -> +handle_sync_event({resume_session, Time}, _From, _StateName, + StateData) when element(1, StateData#state.sid) == Time -> %% The old session should be closed before the new one is opened, so we do %% this here instead of leaving it to the terminate callback ejabberd_sm:close_session(StateData#state.sid, @@ -1321,6 +1372,9 @@ handle_sync_event(resume_session, _From, _StateName, StateData#state.server, StateData#state.resource), {stop, normal, {ok, StateData}, StateData#state{mgmt_state = resumed}}; +handle_sync_event({resume_session, _Time}, _From, StateName, + StateData) -> + {reply, {error, <<"Previous session not found">>}, StateName, StateData}; handle_sync_event(_Event, _From, StateName, StateData) -> Reply = ok, fsm_reply(Reply, StateName, StateData). @@ -1373,7 +1427,6 @@ handle_info({route, _From, _To, {broadcast, Data}}, fsm_next_state(StateName, StateData); NewPL -> PrivPushIQ = #iq{type = set, - xmlns = ?NS_PRIVACY, id = <<"push", (randoms:get_string())/binary>>, sub_el = @@ -1654,12 +1707,9 @@ handle_info({route, From, To, handle_info({'DOWN', Monitor, _Type, _Object, _Info}, _StateName, StateData) when Monitor == StateData#state.socket_monitor -> - if StateData#state.mgmt_timeout > 0, - StateData#state.mgmt_state == active orelse + if StateData#state.mgmt_state == active; StateData#state.mgmt_state == pending -> - log_pending_state(StateData), - fsm_next_state(wait_for_resume, - StateData#state{mgmt_state = pending}); + fsm_next_state(wait_for_resume, StateData); true -> {stop, normal, StateData} end; @@ -1691,12 +1741,32 @@ handle_info({force_update_presence, LUser}, StateName, StateData2; _ -> StateData end, - {next_state, StateName, NewStateData}; + fsm_next_state(StateName, NewStateData); +handle_info({send_filtered, Feature, From, To, Packet}, StateName, StateData) -> + Drop = ejabberd_hooks:run_fold(c2s_filter_packet, StateData#state.server, + true, [StateData#state.server, StateData, + Feature, To, Packet]), + NewStateData = if Drop -> + ?DEBUG("Dropping packet from ~p to ~p", + [jlib:jid_to_string(From), + jlib:jid_to_string(To)]), + StateData; + true -> + FinalPacket = jlib:replace_from_to(From, To, Packet), + case StateData#state.jid of + To -> + send_packet(StateData, FinalPacket); + _ -> + ejabberd_router:route(From, To, FinalPacket), + StateData + end + end, + fsm_next_state(StateName, NewStateData); handle_info({broadcast, Type, From, Packet}, StateName, StateData) -> Recipients = ejabberd_hooks:run_fold( c2s_broadcast_recipients, StateData#state.server, [], - [StateData, Type, From, Packet]), + [StateData#state.server, StateData, Type, From, Packet]), lists:foreach( fun(USR) -> ejabberd_router:route( @@ -1713,11 +1783,10 @@ handle_info(Info, StateName, StateData) -> %% Purpose: Prepare the state to be printed on error log %% Returns: State to print %%---------------------------------------------------------------------- -print_state(State = #state{pres_t = T, pres_f = F, pres_a = A, pres_i = I}) -> +print_state(State = #state{pres_t = T, pres_f = F, pres_a = A}) -> State#state{pres_t = {pres_t, ?SETS:size(T)}, pres_f = {pres_f, ?SETS:size(F)}, - pres_a = {pres_a, ?SETS:size(A)}, - pres_i = {pres_i, ?SETS:size(I)} + pres_a = {pres_a, ?SETS:size(A)} }. %%---------------------------------------------------------------------- @@ -1753,8 +1822,6 @@ terminate(_Reason, StateName, StateData) -> <<"Replaced by new connection">>), presence_broadcast(StateData, From, StateData#state.pres_a, Packet), - presence_broadcast(StateData, From, - StateData#state.pres_i, Packet), handle_unacked_stanzas(StateData); _ -> ?INFO_MSG("(~w) Close session for ~s", @@ -1762,10 +1829,7 @@ terminate(_Reason, StateName, StateData) -> jlib:jid_to_string(StateData#state.jid)]), EmptySet = (?SETS):new(), case StateData of - #state{pres_last = undefined, - pres_a = EmptySet, - pres_i = EmptySet, - pres_invis = false} -> + #state{pres_last = undefined, pres_a = EmptySet} -> ejabberd_sm:close_session(StateData#state.sid, StateData#state.user, StateData#state.server, @@ -1781,9 +1845,7 @@ terminate(_Reason, StateName, StateData) -> StateData#state.resource, <<"">>), presence_broadcast(StateData, From, - StateData#state.pres_a, Packet), - presence_broadcast(StateData, From, - StateData#state.pres_i, Packet) + StateData#state.pres_a, Packet) end, handle_unacked_stanzas(StateData) end, @@ -1811,6 +1873,15 @@ send_text(StateData, Text) when StateData#state.xml_socket -> ?DEBUG("Send Text on stream = ~p", [Text]), (StateData#state.sockmod):send_xml(StateData#state.socket, {xmlstreamraw, Text}); +send_text(StateData, Text) when StateData#state.mgmt_state == active -> + ?DEBUG("Send XML on stream = ~p", [Text]), + case catch (StateData#state.sockmod):send(StateData#state.socket, Text) of + {'EXIT', _} -> + (StateData#state.sockmod):close(StateData#state.socket), + error; + _ -> + ok + end; send_text(StateData, Text) -> ?DEBUG("Send XML on stream = ~p", [Text]), (StateData#state.sockmod):send(StateData#state.socket, Text). @@ -1823,27 +1894,30 @@ send_element(StateData, El) when StateData#state.xml_socket -> send_element(StateData, El) -> send_text(StateData, xml:element_to_binary(El)). +send_stanza(StateData, Stanza) when StateData#state.csi_state == inactive -> + csi_filter_stanza(StateData, Stanza); send_stanza(StateData, Stanza) when StateData#state.mgmt_state == pending -> mgmt_queue_add(StateData, Stanza); send_stanza(StateData, Stanza) when StateData#state.mgmt_state == active -> - send_stanza_and_ack_req(StateData, Stanza), - mgmt_queue_add(StateData, Stanza); + NewStateData = case send_stanza_and_ack_req(StateData, Stanza) of + ok -> + StateData; + error -> + StateData#state{mgmt_state = pending} + end, + mgmt_queue_add(NewStateData, Stanza); send_stanza(StateData, Stanza) -> send_element(StateData, Stanza), StateData. -send_packet(StateData, Packet) when StateData#state.mgmt_state == active; - StateData#state.mgmt_state == pending -> +send_packet(StateData, Packet) -> case is_stanza(Packet) of true -> send_stanza(StateData, Packet); false -> send_element(StateData, Packet), StateData - end; -send_packet(StateData, Stanza) -> - send_element(StateData, Stanza), - StateData. + end. send_header(StateData, Server, Version, Lang) when StateData#state.xml_socket -> @@ -1952,21 +2026,12 @@ process_presence_probe(From, To, StateData) -> undefined -> ok; _ -> - Cond1 = (not StateData#state.pres_invis) - andalso (?SETS:is_element(LFrom, StateData#state.pres_f) - orelse - ((LFrom /= LBFrom) andalso - ?SETS:is_element(LBFrom, StateData#state.pres_f))) - andalso (not - (?SETS:is_element(LFrom, StateData#state.pres_i) - orelse - ((LFrom /= LBFrom) andalso - ?SETS:is_element(LBFrom, StateData#state.pres_i)))), - Cond2 = StateData#state.pres_invis - andalso ?SETS:is_element(LFrom, StateData#state.pres_f) - andalso ?SETS:is_element(LFrom, StateData#state.pres_a), + Cond = ?SETS:is_element(LFrom, StateData#state.pres_f) + orelse + ((LFrom /= LBFrom) andalso + ?SETS:is_element(LBFrom, StateData#state.pres_f)), if - Cond1 -> + Cond -> Timestamp = StateData#state.pres_timestamp, Packet = xml:append_subtags( StateData#state.pres_last, @@ -1988,11 +2053,6 @@ process_presence_probe(From, To, StateData) -> ok end end; - Cond2 -> - ejabberd_router:route(To, From, - #xmlel{name = <<"presence">>, - attrs = [], - children = []}); true -> ok end @@ -2168,7 +2228,7 @@ presence_broadcast_first(From, StateData, Packet) -> [], StateData#state.pres_t), PacketProbe = #xmlel{name = <<"presence">>, attrs = [{<<"type">>,<<"probe">>}], children = []}, - JIDs2Probe = format_and_check_privacy(From, StateData, Packet, JIDsProbe, out), + JIDs2Probe = format_and_check_privacy(From, StateData, PacketProbe, JIDsProbe, out), Server = StateData#state.server, send_multiple(StateData, From, JIDs2Probe, PacketProbe), {As, JIDs} = @@ -2424,13 +2484,28 @@ fsm_next_state_gc(StateName, PackedStateData) -> %% fsm_next_state: Generate the next_state FSM tuple with different %% timeout, depending on the future state +fsm_next_state(session_established, #state{mgmt_max_queue = exceeded} = + StateData) -> + ?WARNING_MSG("ACK queue too long, terminating session for ~s", + [jlib:jid_to_string(StateData#state.jid)]), + Err = ?SERRT_POLICY_VIOLATION(StateData#state.lang, + <<"Too many unacked stanzas">>), + send_element(StateData, Err), + send_trailer(StateData), + {stop, normal, StateData#state{mgmt_resend = false}}; +fsm_next_state(session_established, #state{mgmt_state = pending} = StateData) -> + fsm_next_state(wait_for_resume, StateData); fsm_next_state(session_established, StateData) -> {next_state, session_established, StateData, ?C2S_HIBERNATE_TIMEOUT}; +fsm_next_state(wait_for_resume, #state{mgmt_timeout = 0} = StateData) -> + {stop, normal, StateData}; fsm_next_state(wait_for_resume, #state{mgmt_pending_since = undefined} = StateData) -> + ?INFO_MSG("Waiting for resumption of stream for ~s", + [jlib:jid_to_string(StateData#state.jid)]), {next_state, wait_for_resume, - StateData#state{mgmt_pending_since = os:timestamp()}, + StateData#state{mgmt_state = pending, mgmt_pending_since = os:timestamp()}, StateData#state.mgmt_timeout}; fsm_next_state(wait_for_resume, StateData) -> Diff = timer:now_diff(os:timestamp(), StateData#state.mgmt_pending_since), @@ -2444,11 +2519,6 @@ fsm_next_state(StateName, StateData) -> fsm_reply(Reply, session_established, StateData) -> {reply, Reply, session_established, StateData, ?C2S_HIBERNATE_TIMEOUT}; -fsm_reply(Reply, wait_for_resume, #state{mgmt_pending_since = undefined} = - StateData) -> - {reply, Reply, wait_for_resume, - StateData#state{mgmt_pending_since = os:timestamp()}, - StateData#state.mgmt_timeout}; fsm_reply(Reply, wait_for_resume, StateData) -> Diff = timer:now_diff(os:timestamp(), StateData#state.mgmt_pending_since), Timeout = max(StateData#state.mgmt_timeout - Diff div 1000, 1), @@ -2457,9 +2527,9 @@ fsm_reply(Reply, StateName, StateData) -> {reply, Reply, StateName, StateData, ?C2S_OPEN_TIMEOUT}. %% Used by c2s blacklist plugins -is_ip_blacklisted(undefined) -> false; -is_ip_blacklisted({IP, _Port}) -> - ejabberd_hooks:run_fold(check_bl_c2s, false, [IP]). +is_ip_blacklisted(undefined, _Lang) -> false; +is_ip_blacklisted({IP, _Port}, Lang) -> + ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]). %% Check from attributes %% returns invalid-from|NewElement @@ -2541,8 +2611,7 @@ route_blocking(What, StateData) -> #xmlel{name = <<"unblock">>, attrs = [{<<"xmlns">>, ?NS_BLOCKING}], children = []} end, - PrivPushIQ = #iq{type = set, xmlns = ?NS_BLOCKING, - id = <<"push">>, sub_el = [SubEl]}, + PrivPushIQ = #iq{type = set, id = <<"push">>, sub_el = [SubEl]}, PrivPushEl = jlib:replace_from_to(jlib:jid_remove_resource(StateData#state.jid), StateData#state.jid, jlib:iq_to_xml(PrivPushIQ)), @@ -2725,9 +2794,11 @@ handle_resume(StateData, Attrs) -> #xmlel{name = <<"r">>, attrs = [{<<"xmlns">>, AttrXmlns}], children = []}), + FlushedState = csi_queue_flush(NewState), + NewStateData = FlushedState#state{csi_state = active}, ?INFO_MSG("Resumed session for ~s", - [jlib:jid_to_string(NewState#state.jid)]), - {ok, NewState}; + [jlib:jid_to_string(NewStateData#state.jid)]), + {ok, NewStateData}; {error, El, Msg} -> send_element(StateData, El), ?INFO_MSG("Cannot resume session for ~s@~s: ~s", @@ -2773,37 +2844,29 @@ mgmt_queue_add(StateData, El) -> Num -> Num + 1 end, - NewState = limit_queue_length(StateData), - NewQueue = queue:in({NewNum, El}, NewState#state.mgmt_queue), - NewState#state{mgmt_queue = NewQueue, mgmt_stanzas_out = NewNum}. + NewQueue = queue:in({NewNum, El}, StateData#state.mgmt_queue), + NewState = StateData#state{mgmt_queue = NewQueue, + mgmt_stanzas_out = NewNum}, + check_queue_length(NewState). mgmt_queue_drop(StateData, NumHandled) -> NewQueue = jlib:queue_drop_while(fun({N, _Stanza}) -> N =< NumHandled end, StateData#state.mgmt_queue), StateData#state{mgmt_queue = NewQueue}. -limit_queue_length(#state{mgmt_max_queue = Limit} = StateData) +check_queue_length(#state{mgmt_max_queue = Limit} = StateData) when Limit == infinity; - Limit == unlimited -> + Limit == exceeded -> StateData; -limit_queue_length(#state{jid = JID, - mgmt_queue = Queue, +check_queue_length(#state{mgmt_queue = Queue, mgmt_max_queue = Limit} = StateData) -> - case queue:len(Queue) >= Limit of + case queue:len(Queue) > Limit of true -> - ?WARNING_MSG("Dropping stanza from too long ACK queue for ~s", - [jlib:jid_to_string(JID)]), - limit_queue_length(StateData#state{mgmt_queue = queue:drop(Queue)}); + StateData#state{mgmt_max_queue = exceeded}; false -> StateData end. -log_pending_state(StateData) when StateData#state.mgmt_state /= pending -> - ?INFO_MSG("Waiting for resumption of stream for ~s", - [jlib:jid_to_string(StateData#state.jid)]); -log_pending_state(_StateData) -> - ok. - handle_unacked_stanzas(StateData, F) when StateData#state.mgmt_state == active; StateData#state.mgmt_state == pending -> @@ -2829,7 +2892,15 @@ handle_unacked_stanzas(_StateData, _F) -> handle_unacked_stanzas(StateData) when StateData#state.mgmt_state == active; StateData#state.mgmt_state == pending -> - ReRoute = case StateData#state.mgmt_resend of + ResendOnTimeout = + case StateData#state.mgmt_resend of + Resend when is_boolean(Resend) -> + Resend; + if_offline -> + ejabberd_sm:get_user_resources(StateData#state.user, + StateData#state.server) == [] + end, + ReRoute = case ResendOnTimeout of true -> fun ejabberd_router:route/3; false -> @@ -2888,14 +2959,14 @@ is_encapsulated_forward(_El) -> inherit_session_state(#state{user = U, server = S} = StateData, ResumeID) -> case jlib:base64_to_term(ResumeID) of - {term, {U, S, R, Time}} -> + {term, {R, Time}} -> case ejabberd_sm:get_session_pid(U, S, R) of none -> {error, <<"Previous session PID not found">>}; OldPID -> OldSID = {Time, OldPID}, - case catch resume_session(OldPID) of - {ok, #state{sid = OldSID} = OldStateData} -> + case catch resume_session(OldSID) of + {ok, OldStateData} -> NewSID = {Time, self()}, % Old time, new PID Priority = case OldStateData#state.pres_last of undefined -> @@ -2908,43 +2979,107 @@ inherit_session_state(#state{user = U, server = S} = StateData, ResumeID) -> {auth_module, StateData#state.auth_module}], ejabberd_sm:open_session(NewSID, U, S, R, Priority, Info), - {ok, StateData#state{sid = NewSID, + {ok, StateData#state{conn = Conn, + sid = NewSID, jid = OldStateData#state.jid, resource = OldStateData#state.resource, pres_t = OldStateData#state.pres_t, pres_f = OldStateData#state.pres_f, pres_a = OldStateData#state.pres_a, - pres_i = OldStateData#state.pres_i, pres_last = OldStateData#state.pres_last, pres_pri = OldStateData#state.pres_pri, pres_timestamp = OldStateData#state.pres_timestamp, - pres_invis = OldStateData#state.pres_invis, privacy_list = OldStateData#state.privacy_list, aux_fields = OldStateData#state.aux_fields, + csi_state = OldStateData#state.csi_state, + csi_queue = OldStateData#state.csi_queue, mgmt_xmlns = OldStateData#state.mgmt_xmlns, mgmt_queue = OldStateData#state.mgmt_queue, mgmt_timeout = OldStateData#state.mgmt_timeout, mgmt_stanzas_in = OldStateData#state.mgmt_stanzas_in, mgmt_stanzas_out = OldStateData#state.mgmt_stanzas_out, mgmt_state = active}}; + {error, Msg} -> + {error, Msg}; _ -> {error, <<"Cannot grab session state">>} end end; - error -> + _ -> {error, <<"Invalid 'previd' value">>} end. -resume_session(FsmRef) -> - (?GEN_FSM):sync_send_all_state_event(FsmRef, resume_session, 3000). +resume_session({Time, PID}) -> + (?GEN_FSM):sync_send_all_state_event(PID, {resume_session, Time}, 3000). make_resume_id(StateData) -> {Time, _} = StateData#state.sid, - ID = {StateData#state.user, - StateData#state.server, - StateData#state.resource, - Time}, - jlib:term_to_base64(ID). + jlib:term_to_base64({StateData#state.resource, Time}). + +%%%---------------------------------------------------------------------- +%%% XEP-0352 +%%%---------------------------------------------------------------------- + +csi_filter_stanza(#state{csi_state = CsiState, jid = JID} = StateData, + Stanza) -> + Action = ejabberd_hooks:run_fold(csi_filter_stanza, + StateData#state.server, + send, [Stanza]), + ?DEBUG("Going to ~p stanza for inactive client ~p", + [Action, jlib:jid_to_string(JID)]), + case Action of + queue -> csi_queue_add(StateData, Stanza); + drop -> StateData; + send -> + From = xml:get_tag_attr_s(<<"from">>, Stanza), + StateData1 = csi_queue_send(StateData, From), + StateData2 = send_stanza(StateData1#state{csi_state = active}, + Stanza), + StateData2#state{csi_state = CsiState} + end. + +csi_queue_add(#state{csi_queue = Queue, server = Host} = StateData, + #xmlel{children = Els} = Stanza) -> + From = xml:get_tag_attr_s(<<"from">>, Stanza), + Time = calendar:now_to_universal_time(os:timestamp()), + DelayTag = [jlib:timestamp_to_xml(Time, utc, + jlib:make_jid(<<"">>, Host, <<"">>), + <<"Client Inactive">>)], + NewStanza = Stanza#xmlel{children = Els ++ DelayTag}, + case length(StateData#state.csi_queue) >= csi_max_queue(StateData) of + true -> csi_queue_add(csi_queue_flush(StateData), NewStanza); + false -> + NewQueue = lists:keystore(From, 1, Queue, {From, NewStanza}), + StateData#state{csi_queue = NewQueue} + end. + +csi_queue_send(#state{csi_queue = Queue, csi_state = CsiState} = StateData, + From) -> + case lists:keytake(From, 1, Queue) of + {value, {From, Stanza}, NewQueue} -> + NewStateData = send_stanza(StateData#state{csi_state = active}, + Stanza), + NewStateData#state{csi_queue = NewQueue, csi_state = CsiState}; + false -> StateData + end. + +csi_queue_flush(#state{csi_queue = Queue, csi_state = CsiState, jid = JID} = + StateData) -> + ?DEBUG("Flushing CSI queue for ~s", [jlib:jid_to_string(JID)]), + NewStateData = + lists:foldl(fun({_From, Stanza}, AccState) -> + send_stanza(AccState, Stanza) + end, StateData#state{csi_state = active}, Queue), + NewStateData#state{csi_queue = [], csi_state = CsiState}. + +%% Make sure we won't push too many messages to the XEP-0198 queue when the +%% client becomes 'active' again. Otherwise, the client might not manage to +%% acknowledge the message flood in time. Also, don't let the queue grow to +%% more than 100 stanzas. +csi_max_queue(#state{mgmt_max_queue = infinity}) -> 100; +csi_max_queue(#state{mgmt_max_queue = Max}) when Max > 200 -> 100; +csi_max_queue(#state{mgmt_max_queue = Max}) when Max < 2 -> 1; +csi_max_queue(#state{mgmt_max_queue = Max}) -> Max div 2. %%%---------------------------------------------------------------------- %%% JID Set memory footprint reduction code diff --git a/src/ejabberd_config.erl b/src/ejabberd_config.erl index 06db8b569..dd1db765d 100644 --- a/src/ejabberd_config.erl +++ b/src/ejabberd_config.erl @@ -68,8 +68,15 @@ start() -> %% This start time is used by mod_last: {MegaSecs, Secs, _} = now(), UnixTime = MegaSecs*1000000 + Secs, + SharedKey = case erlang:get_cookie() of + nocookie -> + p1_sha:sha(randoms:get_string()); + Cookie -> + p1_sha:sha(jlib:atom_to_binary(Cookie)) + end, State1 = set_option({node_start, global}, UnixTime, State), - set_opts(State1). + State2 = set_option({shared_key, global}, SharedKey, State1), + set_opts(State2). %% @doc Get the filename of the ejabberd configuration file. %% The filename can be specified with: erl -config "/path/to/ejabberd.yml". @@ -179,7 +186,9 @@ consult(File) -> {ok, [Document|_]} -> {ok, Document}; {error, Err} -> - {error, p1_yaml:format_error(Err)} + Msg1 = "Cannot load " ++ File ++ ": ", + Msg2 = p1_yaml:format_error(Err), + {error, Msg1 ++ Msg2} end; _ -> case file:consult(File) of @@ -980,7 +989,7 @@ report_and_stop(Tab, Err) -> halt(string:substr(ErrTxt, 1, 199)). emit_deprecation_warning(Module, NewModule, DBType) -> - ?WARNING_MSG("Module ~s is deprecated, use {~s, [{db_type, ~s}, ...]}" + ?WARNING_MSG("Module ~s is deprecated, use ~s with 'db_type: ~s'" " instead", [Module, NewModule, DBType]). emit_deprecation_warning(Module, NewModule) -> diff --git a/src/ejabberd_hooks.erl b/src/ejabberd_hooks.erl index e1f99eb88..87c26c5ed 100644 --- a/src/ejabberd_hooks.erl +++ b/src/ejabberd_hooks.erl @@ -151,7 +151,7 @@ run(Hook, Host, Args) -> %% The arguments passed to the function are: [Val | Args]. %% The result of a call is used as Val for the next call. %% If a call returns 'stop', no more calls are performed and 'stopped' is returned. -%% If a call returns {stopped, NewVal}, no more calls are performed and NewVal is returned. +%% If a call returns {stop, NewVal}, no more calls are performed and NewVal is returned. run_fold(Hook, Val, Args) -> run_fold(Hook, global, Val, Args). diff --git a/src/ejabberd_http.erl b/src/ejabberd_http.erl index c5b5758ae..162d5ac73 100644 --- a/src/ejabberd_http.erl +++ b/src/ejabberd_http.erl @@ -65,6 +65,7 @@ request_tp, request_headers = [], end_of_request = false, + options = [], default_host, trail = <<>> }). @@ -133,6 +134,10 @@ init({SockMod, Socket}, Opts) -> true -> [{[<<"http-poll">>], ejabberd_http_poll}]; false -> [] end, + XMLRPC = case proplists:get_bool(xmlrpc, Opts) of + true -> [{[], ejabberd_xmlrpc}]; + false -> [] + end, DefinedHandlers = gen_mod:get_opt( request_handlers, Opts, fun(Hs) -> @@ -141,7 +146,7 @@ init({SockMod, Socket}, Opts) -> Mod} || {Path, Mod} <- Hs] end, []), RequestHandlers = DefinedHandlers ++ Captcha ++ Register ++ - Admin ++ Bind ++ Poll, + Admin ++ Bind ++ Poll ++ XMLRPC, ?DEBUG("S: ~p~n", [RequestHandlers]), DefaultHost = gen_mod:get_opt(default_host, Opts, fun(A) -> A end, undefined), @@ -150,6 +155,7 @@ init({SockMod, Socket}, Opts) -> State = #state{sockmod = SockMod1, socket = Socket1, default_host = DefaultHost, + options = Opts, request_handlers = RequestHandlers}, receive_headers(State). @@ -359,7 +365,7 @@ process(Handlers, Request) -> false -> process(HandlersLeft, Request) end. -process_request(#state{request_method = Method, +process_request(#state{request_method = Method, options = Options, request_path = {abs_path, Path}, request_auth = Auth, request_lang = Lang, request_handlers = RequestHandlers, request_host = Host, request_port = Port, @@ -389,6 +395,7 @@ process_request(#state{request_method = Method, IP = analyze_ip_xff(IPHere, XFF, Host), Request = #request{method = Method, path = LPath, + opts = Options, q = LQuery, auth = Auth, lang = Lang, @@ -413,7 +420,7 @@ process_request(#state{request_method = Method, make_text_output(State, Status, Headers, Output) end end; -process_request(#state{request_method = Method, +process_request(#state{request_method = Method, options = Options, request_path = {abs_path, Path}, request_auth = Auth, request_content_length = Len, request_lang = Lang, sockmod = SockMod, socket = Socket, request_host = Host, @@ -450,6 +457,7 @@ process_request(#state{request_method = Method, Request = #request{method = Method, path = LPath, q = LQuery, + opts = Options, auth = Auth, data = Data, lang = Lang, diff --git a/src/ejabberd_listener.erl b/src/ejabberd_listener.erl index 02a2f3fbd..515cf7348 100644 --- a/src/ejabberd_listener.erl +++ b/src/ejabberd_listener.erl @@ -201,11 +201,7 @@ listen_tcp(PortIP, Module, SockOpts, Port, IPS) -> catch _:_ -> [] end, - DeliverAs = case Module of - ejabberd_xmlrpc -> list; - _ -> binary - end, - Res = gen_tcp:listen(Port, [DeliverAs, + Res = gen_tcp:listen(Port, [binary, {packet, 0}, {active, false}, {reuseaddr, true}, @@ -595,7 +591,7 @@ transform_option({{Port, IP, Transport}, Mod, Opts}) -> try Mod:transform_listen_option(Opt, Acc) catch error:undef -> - Acc + [Opt|Acc] end end, [], Opts1), TransportOpt = if Transport == tcp -> []; diff --git a/src/ejabberd_logger.erl b/src/ejabberd_logger.erl index e1883f1db..65899c8f6 100644 --- a/src/ejabberd_logger.erl +++ b/src/ejabberd_logger.erl @@ -61,29 +61,66 @@ get_log_path() -> -ifdef(LAGER). +get_pos_integer_env(Name, Default) -> + case application:get_env(ejabberd, Name) of + {ok, I} when is_integer(I), I>0 -> + I; + undefined -> + Default; + {ok, Junk} -> + error_logger:error_msg("wrong value for ~s: ~p; " + "using ~p as a fallback~n", + [Name, Junk, Default]), + Default + end. +get_pos_string_env(Name, Default) -> + case application:get_env(ejabberd, Name) of + {ok, L} when is_list(L) -> + L; + undefined -> + Default; + {ok, Junk} -> + error_logger:error_msg("wrong value for ~s: ~p; " + "using ~p as a fallback~n", + [Name, Junk, Default]), + Default + end. + start() -> + application:load(sasl), + application:set_env(sasl, sasl_error_logger, false), application:load(lager), ConsoleLog = get_log_path(), Dir = filename:dirname(ConsoleLog), ErrorLog = filename:join([Dir, "error.log"]), CrashLog = filename:join([Dir, "crash.log"]), + LogRotateDate = get_pos_string_env(log_rotate_date, ""), + LogRotateSize = get_pos_integer_env(log_rotate_size, 10*1024*1024), + LogRotateCount = get_pos_integer_env(log_rotate_count, 1), + LogRateLimit = get_pos_integer_env(log_rate_limit, 100), + application:set_env(lager, error_logger_hwm, LogRateLimit), application:set_env( lager, handlers, [{lager_console_backend, info}, - {lager_file_backend, [{file, ConsoleLog}, {level, info}, {count, 1}]}, - {lager_file_backend, [{file, ErrorLog}, {level, error}, {count, 1}]}]), + {lager_file_backend, [{file, ConsoleLog}, {level, info}, {date, LogRotateDate}, + {count, LogRotateCount}, {size, LogRotateSize}]}, + {lager_file_backend, [{file, ErrorLog}, {level, error}, {date, LogRotateDate}, + {count, LogRotateCount}, {size, LogRotateSize}]}]), application:set_env(lager, crash_log, CrashLog), + application:set_env(lager, crash_log_date, LogRotateDate), + application:set_env(lager, crash_log_size, LogRotateSize), + application:set_env(lager, crash_log_count, LogRotateCount), ejabberd:start_app(lager), ok. reopen_log() -> + lager_crash_log ! rotate, lists:foreach( fun({lager_file_backend, File}) -> whereis(lager_event) ! {rotate, File}; (_) -> ok - end, gen_event:which_handlers(lager_event)), - reopen_sasl_log(). + end, gen_event:which_handlers(lager_event)). get() -> case lager:get_loglevel(lager_console_backend) of @@ -145,8 +182,6 @@ get() -> set(LogLevel) -> p1_loglevel:set(LogLevel). --endif. - %%%=================================================================== %%% Internal functions %%%=================================================================== @@ -179,3 +214,5 @@ get_sasl_error_logger_type () -> {ok, Bad} -> exit ({bad_config, {sasl, {errlog_type, Bad}}}); _ -> all end. + +-endif. diff --git a/src/ejabberd_odbc.erl b/src/ejabberd_odbc.erl index 2b852bfaa..78be623d4 100644 --- a/src/ejabberd_odbc.erl +++ b/src/ejabberd_odbc.erl @@ -204,7 +204,7 @@ decode_term(Bin) -> %%%---------------------------------------------------------------------- init([Host, StartInterval]) -> case ejabberd_config:get_option( - {keepalive_interval, Host}, + {odbc_keepalive_interval, Host}, fun(I) when is_integer(I), I>0 -> I end) of undefined -> ok; @@ -450,7 +450,7 @@ sql_query_internal(Query) -> ?DEBUG("MySQL, Send query~n~p~n", [Query]), %%squery to be able to specify result_type = binary %%[Query] because p1_mysql_conn expect query to be a list (elements can be binaries, or iolist) - %% but doesn't accept just a binary + %% but doesn't accept just a binary R = mysql_to_odbc(p1_mysql_conn:squery(State#state.db_ref, [Query], self(), [{timeout, (?TRANSACTION_TIMEOUT) - 1000}, @@ -553,10 +553,16 @@ mysql_to_odbc({data, MySQLRes}) -> mysql_item_to_odbc(p1_mysql:get_result_field_info(MySQLRes), p1_mysql:get_result_rows(MySQLRes)); mysql_to_odbc({error, MySQLRes}) - when is_binary(MySQLRes) -> + when is_binary(MySQLRes) -> {error, MySQLRes}; +mysql_to_odbc({error, MySQLRes}) + when is_list(MySQLRes) -> + {error, list_to_binary(MySQLRes)}; mysql_to_odbc({error, MySQLRes}) -> - {error, p1_mysql:get_result_reason(MySQLRes)}. + {error, p1_mysql:get_result_reason(MySQLRes)}; +mysql_to_odbc(ok) -> + ok. + %% When tabular data is returned, convert it to the ODBC formalism mysql_item_to_odbc(Columns, Recs) -> @@ -588,7 +594,7 @@ db_opts(Host) -> [odbc, Server]; _ -> Port = ejabberd_config:get_option( - {port, Host}, + {odbc_port, Host}, fun(P) when is_integer(P), P > 0, P < 65536 -> P end, case Type of mysql -> ?MYSQL_PORT; diff --git a/src/ejabberd_riak.erl b/src/ejabberd_riak.erl new file mode 100644 index 000000000..d80a77d3e --- /dev/null +++ b/src/ejabberd_riak.erl @@ -0,0 +1,554 @@ +%%%------------------------------------------------------------------- +%%% @author Alexey Shchepin <alexey@process-one.net> +%%% @doc +%%% Interface for Riak database +%%% @end +%%% Created : 29 Dec 2011 by Alexey Shchepin <alexey@process-one.net> +%%% @copyright (C) 2002-2014 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License +%%% along with this program; if not, write to the Free Software +%%% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +%%% 02111-1307 USA +%%% +%%%------------------------------------------------------------------- +-module(ejabberd_riak). + +-behaviour(gen_server). + +%% API +-export([start_link/4, get_proc/1, make_bucket/1, put/2, put/3, + get/2, get/3, get_by_index/4, delete/1, delete/2, + count_by_index/3, get_by_index_range/5, + get_keys/1, get_keys_by_index/3, is_connected/0, + count/1, delete_by_index/3]). +%% For debugging +-export([get_tables/0]). +%% map/reduce exports +-export([map_key/3]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-include("ejabberd.hrl"). +-include("logger.hrl"). + +-record(state, {pid = self() :: pid()}). + +-type index() :: {binary(), any()}. + +-type index_info() :: [{i, any()} | {'2i', [index()]}]. + +%% The `record_schema()' is just a tuple: +%% {record_info(fields, some_record), #some_record{}} + +-type record_schema() :: {[atom()], tuple()}. + +%% The `index_info()' is used in put/delete functions: +%% `i' defines a primary index, `` '2i' '' defines secondary indexes. +%% There must be only one primary index. If `i' is not specified, +%% the first element of the record is assumed as a primary index, +%% i.e. `i' = element(2, Record). + +-export_types([index_info/0]). + +%%%=================================================================== +%%% API +%%%=================================================================== +%% @private +start_link(Num, Server, Port, _StartInterval) -> + gen_server:start_link({local, get_proc(Num)}, ?MODULE, [Server, Port], []). + +%% @private +is_connected() -> + catch riakc_pb_socket:is_connected(get_random_pid()). + +%% @private +get_proc(I) -> + jlib:binary_to_atom( + iolist_to_binary( + [atom_to_list(?MODULE), $_, integer_to_list(I)])). + +-spec make_bucket(atom()) -> binary(). +%% @doc Makes a bucket from a table name +%% @private +make_bucket(Table) -> + erlang:atom_to_binary(Table, utf8). + +-spec put(tuple(), record_schema()) -> ok | {error, any()}. +%% @equiv put(Record, []) +put(Record, RecFields) -> + ?MODULE:put(Record, RecFields, []). + +-spec put(tuple(), record_schema(), index_info()) -> ok | {error, any()}. +%% @doc Stores a record `Rec' with indexes described in ``IndexInfo'' +put(Rec, RecSchema, IndexInfo) -> + Key = encode_key(proplists:get_value(i, IndexInfo, element(2, Rec))), + SecIdxs = [encode_index_key(K, V) || + {K, V} <- proplists:get_value('2i', IndexInfo, [])], + Table = element(1, Rec), + Value = encode_record(Rec, RecSchema), + case put_raw(Table, Key, Value, SecIdxs) of + ok -> + ok; + {error, _} = Error -> + log_error(Error, put, [{record, Rec}, + {index_info, IndexInfo}]), + Error + end. + +put_raw(Table, Key, Value, Indexes) -> + Bucket = make_bucket(Table), + Obj = riakc_obj:new(Bucket, Key, Value, "application/x-erlang-term"), + Obj1 = if Indexes /= [] -> + MetaData = dict:store(<<"index">>, Indexes, dict:new()), + riakc_obj:update_metadata(Obj, MetaData); + true -> + Obj + end, + catch riakc_pb_socket:put(get_random_pid(), Obj1). + +get_object_raw(Table, Key) -> + Bucket = make_bucket(Table), + catch riakc_pb_socket:get(get_random_pid(), Bucket, Key). + +-spec get(atom(), record_schema()) -> {ok, [any()]} | {error, any()}. +%% @doc Returns all objects from table `Table' +get(Table, RecSchema) -> + Bucket = make_bucket(Table), + case catch riakc_pb_socket:mapred( + get_random_pid(), + Bucket, + [{map, {modfun, riak_kv_mapreduce, map_object_value}, + none, true}]) of + {ok, [{_, Objs}]} -> + {ok, lists:flatmap( + fun(Obj) -> + case catch decode_record(Obj, RecSchema) of + {'EXIT', _} -> + Error = {error, make_invalid_object(Obj)}, + log_error(Error, get, + [{table, Table}]), + []; + Term -> + [Term] + end + end, Objs)}; + {ok, []} -> + {ok, []}; + {error, notfound} -> + {ok, []}; + {error, _} = Error -> + Error + end. + +-spec get(atom(), record_schema(), any()) -> {ok, any()} | {error, any()}. +%% @doc Reads record by `Key' from table `Table' +get(Table, RecSchema, Key) -> + case get_raw(Table, encode_key(Key)) of + {ok, Val} -> + case catch decode_record(Val, RecSchema) of + {'EXIT', _} -> + Error = {error, make_invalid_object(Val)}, + log_error(Error, get, [{table, Table}, {key, Key}]), + {error, notfound}; + Term -> + {ok, Term} + end; + {error, _} = Error -> + log_error(Error, get, [{table, Table}, + {key, Key}]), + Error + end. + +-spec get_by_index(atom(), record_schema(), binary(), any()) -> + {ok, [any()]} | {error, any()}. +%% @doc Reads records by `Index' and value `Key' from `Table' +get_by_index(Table, RecSchema, Index, Key) -> + {NewIndex, NewKey} = encode_index_key(Index, Key), + case get_by_index_raw(Table, NewIndex, NewKey) of + {ok, Vals} -> + {ok, lists:flatmap( + fun(Val) -> + case catch decode_record(Val, RecSchema) of + {'EXIT', _} -> + Error = {error, make_invalid_object(Val)}, + log_error(Error, get_by_index, + [{table, Table}, + {index, Index}, + {key, Key}]), + []; + Term -> + [Term] + end + end, Vals)}; + {error, notfound} -> + {ok, []}; + {error, _} = Error -> + log_error(Error, get_by_index, + [{table, Table}, + {index, Index}, + {key, Key}]), + Error + end. + +-spec get_by_index_range(atom(), record_schema(), binary(), any(), any()) -> + {ok, [any()]} | {error, any()}. +%% @doc Reads records by `Index' in the range `FromKey'..`ToKey' from `Table' +get_by_index_range(Table, RecSchema, Index, FromKey, ToKey) -> + {NewIndex, NewFromKey} = encode_index_key(Index, FromKey), + {NewIndex, NewToKey} = encode_index_key(Index, ToKey), + case get_by_index_range_raw(Table, NewIndex, NewFromKey, NewToKey) of + {ok, Vals} -> + {ok, lists:flatmap( + fun(Val) -> + case catch decode_record(Val, RecSchema) of + {'EXIT', _} -> + Error = {error, make_invalid_object(Val)}, + log_error(Error, get_by_index_range, + [{table, Table}, + {index, Index}, + {start_key, FromKey}, + {end_key, ToKey}]), + []; + Term -> + [Term] + end + end, Vals)}; + {error, notfound} -> + {ok, []}; + {error, _} = Error -> + log_error(Error, get_by_index_range, + [{table, Table}, {index, Index}, + {start_key, FromKey}, {end_key, ToKey}]), + Error + end. + +get_raw(Table, Key) -> + case get_object_raw(Table, Key) of + {ok, Obj} -> + {ok, riakc_obj:get_value(Obj)}; + {error, _} = Error -> + Error + end. + +-spec get_keys(atom()) -> {ok, [any()]} | {error, any()}. +%% @doc Returns a list of index values +get_keys(Table) -> + Bucket = make_bucket(Table), + case catch riakc_pb_socket:mapred( + get_random_pid(), + Bucket, + [{map, {modfun, ?MODULE, map_key}, none, true}]) of + {ok, [{_, Keys}]} -> + {ok, Keys}; + {ok, []} -> + {ok, []}; + {error, _} = Error -> + log_error(Error, get_keys, [{table, Table}]), + Error + end. + +-spec get_keys_by_index(atom(), binary(), + any()) -> {ok, [any()]} | {error, any()}. +%% @doc Returns a list of primary keys of objects indexed by `Key'. +get_keys_by_index(Table, Index, Key) -> + {NewIndex, NewKey} = encode_index_key(Index, Key), + Bucket = make_bucket(Table), + case catch riakc_pb_socket:mapred( + get_random_pid(), + {index, Bucket, NewIndex, NewKey}, + [{map, {modfun, ?MODULE, map_key}, none, true}]) of + {ok, [{_, Keys}]} -> + {ok, Keys}; + {ok, []} -> + {ok, []}; + {error, _} = Error -> + log_error(Error, get_keys_by_index, [{table, Table}, + {index, Index}, + {key, Key}]), + Error + end. + +%% @hidden +get_tables() -> + catch riakc_pb_socket:list_buckets(get_random_pid()). + +get_by_index_raw(Table, Index, Key) -> + Bucket = make_bucket(Table), + case riakc_pb_socket:mapred( + get_random_pid(), + {index, Bucket, Index, Key}, + [{map, {modfun, riak_kv_mapreduce, map_object_value}, + none, true}]) of + {ok, [{_, Objs}]} -> + {ok, Objs}; + {ok, []} -> + {ok, []}; + {error, _} = Error -> + Error + end. + +get_by_index_range_raw(Table, Index, FromKey, ToKey) -> + Bucket = make_bucket(Table), + case catch riakc_pb_socket:mapred( + get_random_pid(), + {index, Bucket, Index, FromKey, ToKey}, + [{map, {modfun, riak_kv_mapreduce, map_object_value}, + none, true}]) of + {ok, [{_, Objs}]} -> + {ok, Objs}; + {ok, []} -> + {ok, []}; + {error, _} = Error -> + Error + end. + +-spec count(atom()) -> {ok, non_neg_integer()} | {error, any()}. +%% @doc Returns the number of objects in the `Table' +count(Table) -> + Bucket = make_bucket(Table), + case catch riakc_pb_socket:mapred( + get_random_pid(), + Bucket, + [{reduce, {modfun, riak_kv_mapreduce, reduce_count_inputs}, + none, true}]) of + {ok, [{_, [Cnt]}]} -> + {ok, Cnt}; + {error, _} = Error -> + log_error(Error, count, [{table, Table}]), + Error + end. + +-spec count_by_index(atom(), binary(), any()) -> + {ok, non_neg_integer()} | {error, any()}. +%% @doc Returns the number of objects in the `Table' by index +count_by_index(Tab, Index, Key) -> + {NewIndex, NewKey} = encode_index_key(Index, Key), + case count_by_index_raw(Tab, NewIndex, NewKey) of + {ok, Cnt} -> + {ok, Cnt}; + {error, notfound} -> + {ok, 0}; + {error, _} = Error -> + log_error(Error, count_by_index, + [{table, Tab}, + {index, Index}, + {key, Key}]), + Error + end. + +count_by_index_raw(Table, Index, Key) -> + Bucket = make_bucket(Table), + case catch riakc_pb_socket:mapred( + get_random_pid(), + {index, Bucket, Index, Key}, + [{reduce, {modfun, riak_kv_mapreduce, reduce_count_inputs}, + none, true}]) of + {ok, [{_, [Cnt]}]} -> + {ok, Cnt}; + {error, _} = Error -> + Error + end. + +-spec delete(tuple() | atom()) -> ok | {error, any()}. +%% @doc Same as delete(T, []) when T is record. +%% Or deletes all elements from table if T is atom. +delete(Rec) when is_tuple(Rec) -> + delete(Rec, []); +delete(Table) when is_atom(Table) -> + try + {ok, Keys} = ?MODULE:get_keys(Table), + lists:foreach( + fun(K) -> + ok = delete(Table, K) + end, Keys) + catch _:{badmatch, Err} -> + Err + end. + +-spec delete(tuple() | atom(), index_info() | any()) -> ok | {error, any()}. +%% @doc Delete an object +delete(Rec, Opts) when is_tuple(Rec) -> + Table = element(1, Rec), + Key = proplists:get_value(i, Opts, element(2, Rec)), + delete(Table, Key); +delete(Table, Key) when is_atom(Table) -> + case delete_raw(Table, encode_key(Key)) of + ok -> + ok; + Err -> + log_error(Err, delete, [{table, Table}, {key, Key}]), + Err + end. + +delete_raw(Table, Key) -> + Bucket = make_bucket(Table), + catch riakc_pb_socket:delete(get_random_pid(), Bucket, Key). + +-spec delete_by_index(atom(), binary(), any()) -> ok | {error, any()}. +%% @doc Deletes objects by index +delete_by_index(Table, Index, Key) -> + try + {ok, Keys} = get_keys_by_index(Table, Index, Key), + lists:foreach( + fun(K) -> + ok = delete(Table, K) + end, Keys) + catch _:{badmatch, Err} -> + Err + end. + +%%%=================================================================== +%%% map/reduce functions +%%%=================================================================== +%% @private +map_key(Obj, _, _) -> + [case riak_object:key(Obj) of + <<"b_", B/binary>> -> + B; + <<"i_", B/binary>> -> + list_to_integer(binary_to_list(B)); + B -> + erlang:binary_to_term(B) + end]. + +%%%=================================================================== +%%% gen_server API +%%%=================================================================== +%% @private +init([Server, Port]) -> + case riakc_pb_socket:start( + Server, Port, + [auto_reconnect]) of + {ok, Pid} -> + erlang:monitor(process, Pid), + {ok, #state{pid = Pid}}; + Err -> + {stop, Err} + end. + +%% @private +handle_call(get_pid, _From, #state{pid = Pid} = State) -> + {reply, {ok, Pid}, State}; +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +%% @private +handle_cast(_Msg, State) -> + {noreply, State}. + +%% @private +handle_info({'DOWN', _MonitorRef, _Type, _Object, _Info}, State) -> + {stop, normal, State}; +handle_info(_Info, State) -> + ?ERROR_MSG("unexpected info: ~p", [_Info]), + {noreply, State}. + +%% @private +terminate(_Reason, _State) -> + ok. + +%% @private +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +encode_index_key(Idx, Key) when is_integer(Key) -> + {<<Idx/binary, "_int">>, Key}; +encode_index_key(Idx, Key) -> + {<<Idx/binary, "_bin">>, encode_key(Key)}. + +encode_key(Bin) when is_binary(Bin) -> + <<"b_", Bin/binary>>; +encode_key(Int) when is_integer(Int) -> + <<"i_", (list_to_binary(integer_to_list(Int)))/binary>>; +encode_key(Term) -> + erlang:term_to_binary(Term). + +log_error({error, notfound}, _, _) -> + ok; +log_error({error, Why} = Err, Function, Opts) -> + Txt = lists:map( + fun({table, Table}) -> + io_lib:fwrite("** Table: ~p~n", [Table]); + ({key, Key}) -> + io_lib:fwrite("** Key: ~p~n", [Key]); + ({index, Index}) -> + io_lib:fwrite("** Index = ~p~n", [Index]); + ({start_key, Key}) -> + io_lib:fwrite("** Start Key: ~p~n", [Key]); + ({end_key, Key}) -> + io_lib:fwrite("** End Key: ~p~n", [Key]); + ({record, Rec}) -> + io_lib:fwrite("** Record = ~p~n", [Rec]); + ({index_info, IdxInfo}) -> + io_lib:fwrite("** Index info = ~p~n", [IdxInfo]); + (_) -> + "" + end, Opts), + ErrTxt = if is_binary(Why) -> + io_lib:fwrite("** Error: ~s", [Why]); + true -> + io_lib:fwrite("** Error: ~p", [Err]) + end, + ?ERROR_MSG("database error:~n** Function: ~p~n~s~s", + [Function, Txt, ErrTxt]); +log_error(_, _, _) -> + ok. + +make_invalid_object(Val) -> + list_to_binary(io_lib:fwrite("Invalid object: ~p", [Val])). + +get_random_pid() -> + PoolPid = ejabberd_riak_sup:get_random_pid(), + case catch gen_server:call(PoolPid, get_pid) of + {ok, Pid} -> + Pid; + {'EXIT', {timeout, _}} -> + throw({error, timeout}); + {'EXIT', Err} -> + throw({error, Err}) + end. + +encode_record(Rec, {Fields, DefRec}) -> + term_to_binary(encode_record(Rec, Fields, DefRec, 2)). + +encode_record(Rec, [FieldName|Fields], DefRec, Pos) -> + Value = element(Pos, Rec), + DefValue = element(Pos, DefRec), + if Value == DefValue -> + encode_record(Rec, Fields, DefRec, Pos+1); + true -> + [{FieldName, Value}|encode_record(Rec, Fields, DefRec, Pos+1)] + end; +encode_record(_, [], _, _) -> + []. + +decode_record(Bin, {Fields, DefRec}) -> + decode_record(binary_to_term(Bin), Fields, DefRec, 2). + +decode_record(KeyVals, [FieldName|Fields], Rec, Pos) -> + case lists:keyfind(FieldName, 1, KeyVals) of + {_, Value} -> + NewRec = setelement(Pos, Rec, Value), + decode_record(KeyVals, Fields, NewRec, Pos+1); + false -> + decode_record(KeyVals, Fields, Rec, Pos+1) + end; +decode_record(_, [], Rec, _) -> + Rec. diff --git a/src/ejabberd_riak_sup.erl b/src/ejabberd_riak_sup.erl new file mode 100644 index 000000000..513ad785f --- /dev/null +++ b/src/ejabberd_riak_sup.erl @@ -0,0 +1,161 @@ +%%%---------------------------------------------------------------------- +%%% File : ejabberd_riak_sup.erl +%%% Author : Alexey Shchepin <alexey@process-one.net> +%%% Purpose : Riak connections supervisor +%%% Created : 29 Dec 2011 by Alexey Shchepin <alexey@process-one.net> +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2011 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License +%%% along with this program; if not, write to the Free Software +%%% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +%%% 02111-1307 USA +%%% +%%%---------------------------------------------------------------------- + +-module(ejabberd_riak_sup). +-author('alexey@process-one.net'). + +%% API +-export([start/0, + start_link/0, + init/1, + get_pids/0, + transform_options/1, + get_random_pid/0, + get_random_pid/1 + ]). + +-include("ejabberd.hrl"). +-include("logger.hrl"). + +-define(DEFAULT_POOL_SIZE, 10). +-define(DEFAULT_RIAK_START_INTERVAL, 30). % 30 seconds +-define(DEFAULT_RIAK_HOST, "127.0.0.1"). +-define(DEFAULT_RIAK_PORT, 8087). + +% time to wait for the supervisor to start its child before returning +% a timeout error to the request +-define(CONNECT_TIMEOUT, 500). % milliseconds + +start() -> + case lists:any( + fun(Host) -> + is_riak_configured(Host) + end, ?MYHOSTS) of + true -> + ejabberd:start_app(riakc), + do_start(); + false -> + ok + end. + +is_riak_configured(Host) -> + ServerConfigured = ejabberd_config:get_option( + {riak_server, Host}, + fun(_) -> true end, false), + PortConfigured = ejabberd_config:get_option( + {riak_port, Host}, + fun(_) -> true end, false), + AuthConfigured = lists:member( + ejabberd_auth_riak, + ejabberd_auth:auth_modules(Host)), + Modules = ejabberd_config:get_option( + {modules, Host}, + fun(L) when is_list(L) -> L end, []), + ModuleWithRiakDBConfigured = lists:any( + fun({_Module, Opts}) -> + gen_mod:db_type(Opts) == riak + end, Modules), + ServerConfigured or PortConfigured + or AuthConfigured or ModuleWithRiakDBConfigured. + +do_start() -> + SupervisorName = ?MODULE, + ChildSpec = + {SupervisorName, + {?MODULE, start_link, []}, + transient, + infinity, + supervisor, + [?MODULE]}, + case supervisor:start_child(ejabberd_sup, ChildSpec) of + {ok, _PID} -> + ok; + _Error -> + ?ERROR_MSG("Start of supervisor ~p failed:~n~p~nRetrying...~n", + [SupervisorName, _Error]), + timer:sleep(5000), + start() + end. + +start_link() -> + supervisor:start_link({local, ?MODULE}, ?MODULE, []). + +init([]) -> + PoolSize = get_pool_size(), + StartInterval = get_start_interval(), + Server = get_riak_server(), + Port = get_riak_port(), + {ok, {{one_for_one, PoolSize*10, 1}, + lists:map( + fun(I) -> + {ejabberd_riak:get_proc(I), + {ejabberd_riak, start_link, + [I, Server, Port, StartInterval*1000]}, + transient, 2000, worker, [?MODULE]} + end, lists:seq(1, PoolSize))}}. + +get_start_interval() -> + ejabberd_config:get_option( + riak_start_interval, + fun(N) when is_integer(N), N >= 1 -> N end, + ?DEFAULT_RIAK_START_INTERVAL). + +get_pool_size() -> + ejabberd_config:get_option( + riak_pool_size, + fun(N) when is_integer(N), N >= 1 -> N end, + ?DEFAULT_POOL_SIZE). + +get_riak_server() -> + ejabberd_config:get_option( + riak_server, + fun(S) -> + binary_to_list(iolist_to_binary(S)) + end, ?DEFAULT_RIAK_HOST). + +get_riak_port() -> + ejabberd_config:get_option( + riak_port, + fun(P) when is_integer(P), P > 0, P < 65536 -> P end, + ?DEFAULT_RIAK_PORT). + +get_pids() -> + [ejabberd_riak:get_proc(I) || I <- lists:seq(1, get_pool_size())]. + +get_random_pid() -> + get_random_pid(now()). + +get_random_pid(Term) -> + I = erlang:phash2(Term, get_pool_size()) + 1, + ejabberd_riak:get_proc(I). + +transform_options(Opts) -> + lists:foldl(fun transform_options/2, [], Opts). + +transform_options({riak_server, {S, P}}, Opts) -> + [{riak_server, S}, {riak_port, P}|Opts]; +transform_options(Opt, Opts) -> + [Opt|Opts]. diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index 4fde814fe..6c594185a 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -374,8 +374,8 @@ wait_for_feature_request({xmlstreamelement, El}, #xmlel{name = <<"success">>, attrs = [{<<"xmlns">>, ?NS_SASL}], children = []}), - ?DEBUG("(~w) Accepted s2s authentication for ~s", - [StateData#state.socket, AuthDomain]), + ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)", + [AuthDomain, StateData#state.tls_enabled]), change_shaper(StateData, <<"">>, jlib:make_jid(<<"">>, AuthDomain, <<"">>)), {next_state, wait_for_stream, @@ -515,6 +515,8 @@ stream_established({valid, From, To}, StateData) -> [{<<"from">>, To}, {<<"to">>, From}, {<<"type">>, <<"valid">>}], children = []}), + ?INFO_MSG("Accepted s2s dialback authentication for ~s (TLS=~p)", + [From, StateData#state.tls_enabled]), LFrom = jlib:nameprep(From), LTo = jlib:nameprep(To), NSD = StateData#state{connections = diff --git a/src/ejabberd_sm.erl b/src/ejabberd_sm.erl index 3e37f4692..12f2a7708 100644 --- a/src/ejabberd_sm.erl +++ b/src/ejabberd_sm.erl @@ -849,7 +849,7 @@ kick_user(User, Server) -> lists:foreach( fun(Resource) -> PID = get_session_pid(User, Server, Resource), - PID ! disconnect + PID ! kick end, Resources), length(Resources). diff --git a/src/ejabberd_xmlrpc.erl b/src/ejabberd_xmlrpc.erl index ff89d2858..b59001819 100644 --- a/src/ejabberd_xmlrpc.erl +++ b/src/ejabberd_xmlrpc.erl @@ -17,11 +17,12 @@ -author('badlop@process-one.net'). --export([start/2, handler/2, socket_type/0, transform_listen_option/2]). +-export([start/2, handler/2, process/2, socket_type/0, + transform_listen_option/2]). -include("ejabberd.hrl"). -include("logger.hrl"). - +-include("ejabberd_http.hrl"). -include("mod_roster.hrl"). -include("jlib.hrl"). @@ -170,12 +171,14 @@ %% ----------------------------- start({gen_tcp = _SockMod, Socket}, Opts) -> - %MaxSessions = gen_mod:get_opt(maxsessions, Opts, - % fun(I) when is_integer(I), I>0 -> I end, - % 10), - Timeout = gen_mod:get_opt(timeout, Opts, - fun(I) when is_integer(I), I>0 -> I end, - 5000), + ejabberd_http:start({gen_tcp, Socket}, [{xmlrpc, true}|Opts]). + +socket_type() -> raw. + +%% ----------------------------- +%% HTTP interface +%% ----------------------------- +process(_, #request{method = 'POST', data = Data, opts = Opts}) -> AccessCommandsOpts = gen_mod:get_opt(access_commands, Opts, fun(L) when is_list(L) -> L end, []), @@ -201,19 +204,36 @@ start({gen_tcp = _SockMod, Socket}, Opts) -> [?MODULE, Wrong]), [] end, AccessCommandsOpts), - GetAuth = case [ACom - || {Ac, _, _} = ACom <- AccessCommands, Ac /= all] - of - [] -> false; - _ -> true + GetAuth = case [ACom || {Ac, _, _} = ACom <- AccessCommands, Ac /= all] of + [] -> false; + _ -> true end, - Handler = {?MODULE, handler}, - State = #state{access_commands = AccessCommands, - get_auth = GetAuth}, - Pid = proc_lib:spawn(xmlrpc_http, handler, [Socket, Timeout, Handler, State]), - {ok, Pid}. - -socket_type() -> raw. + State = #state{access_commands = AccessCommands, get_auth = GetAuth}, + case xml_stream:parse_element(Data) of + {error, _} -> + {400, [], + #xmlel{name = <<"h1">>, attrs = [], + children = [{xmlcdata, <<"Malformed XML">>}]}}; + El -> + case p1_xmlrpc:decode(El) of + {error, _} = Err -> + ?ERROR_MSG("XML-RPC request ~s failed with reason: ~p", + [Data, Err]), + {400, [], + #xmlel{name = <<"h1">>, attrs = [], + children = [{xmlcdata, <<"Malformed Request">>}]}}; + {ok, RPC} -> + ?DEBUG("got XML-RPC request: ~p", [RPC]), + {false, Result} = handler(State, RPC), + XML = xml:element_to_binary(p1_xmlrpc:encode(Result)), + {200, [{<<"Content-Type">>, <<"text/xml">>}], + <<"<?xml version=\"1.0\"?>", XML/binary>>} + end + end; +process(_, _) -> + {400, [], + #xmlel{name = <<"h1">>, attrs = [], + children = [{xmlcdata, <<"400 Bad Request">>}]}}. %% ----------------------------- %% Access verification @@ -428,8 +448,8 @@ format_arg({array, Elements}, {list, ElementsDef}) format_arg(Arg, integer) when is_integer(Arg) -> Arg; format_arg(Arg, binary) when is_list(Arg) -> list_to_binary(Arg); format_arg(Arg, binary) when is_binary(Arg) -> Arg; -format_arg(Arg, string) when is_list(Arg) -> list_to_binary(Arg); -format_arg(Arg, string) when is_binary(Arg) -> Arg; +format_arg(Arg, string) when is_list(Arg) -> Arg; +format_arg(Arg, string) when is_binary(Arg) -> binary_to_list(Arg); format_arg(Arg, Format) -> ?ERROR_MSG("don't know how to format Arg ~p for format ~p", [Arg, Format]), throw({error_formatting_argument, Arg, Format}). @@ -450,6 +470,10 @@ format_result(String, {Name, string}) when is_list(String) -> {struct, [{Name, lists:flatten(String)}]}; format_result(Binary, {Name, string}) when is_binary(Binary) -> {struct, [{Name, binary_to_list(Binary)}]}; +format_result(String, {Name, binary}) when is_list(String) -> + {struct, [{Name, lists:flatten(String)}]}; +format_result(Binary, {Name, binary}) when is_binary(Binary) -> + {struct, [{Name, binary_to_list(Binary)}]}; format_result(Code, {Name, rescode}) -> {struct, [{Name, make_status(Code)}]}; format_result({Code, Text}, {Name, restuple}) -> diff --git a/src/ejd2odbc.erl b/src/ejd2odbc.erl index 892b8df62..a5c10560d 100644 --- a/src/ejd2odbc.erl +++ b/src/ejd2odbc.erl @@ -48,7 +48,6 @@ modules() -> [ejabberd_auth, mod_announce, - mod_caps, mod_irc, mod_last, mod_muc, diff --git a/src/eldap_filter_yecc.yrl b/src/eldap_filter_yecc.yrl index a70ea3e74..fe2a075c0 100644 --- a/src/eldap_filter_yecc.yrl +++ b/src/eldap_filter_yecc.yrl @@ -38,9 +38,9 @@ any -> '$empty': []. initial -> value: initial('$1'). final -> value: final('$1'). -extensible -> xattr ':dn' ':' matchingrule ':=' value: extensible('$6', ['$1', '$4']). +extensible -> xattr ':dn' ':' matchingrule ':=' value: extensible('$6', ['$1', '$4', {dnAttributes, true}]). extensible -> xattr ':' matchingrule ':=' value: extensible('$5', ['$1', '$3']). -extensible -> xattr ':dn' ':=' value: extensible('$4', ['$1']). +extensible -> xattr ':dn' ':=' value: extensible('$4', ['$1', {dnAttributes, true}]). extensible -> xattr ':=' value: extensible('$3', ['$1']). extensible -> ':dn' ':' matchingrule ':=' value: extensible('$5', ['$3']). extensible -> ':' matchingrule ':=' value: extensible('$4', ['$2']). diff --git a/src/eldap_utils.erl b/src/eldap_utils.erl index a87023437..e6e874a63 100644 --- a/src/eldap_utils.erl +++ b/src/eldap_utils.erl @@ -228,13 +228,28 @@ get_config(Host, Opts) -> Base = get_opt({ldap_base, Host}, Opts, fun iolist_to_binary/1, <<"">>), - DerefAliases = get_opt({deref_aliases, Host}, Opts, - fun(never) -> never; - (searching) -> searching; - (finding) -> finding; - (always) -> always - end, never), - #eldap_config{servers = Servers, + OldDerefAliases = get_opt({deref_aliases, Host}, Opts, + fun(never) -> never; + (searching) -> searching; + (finding) -> finding; + (always) -> always + end, unspecified), + DerefAliases = + if OldDerefAliases == unspecified -> + get_opt({ldap_deref_aliases, Host}, Opts, + fun(never) -> never; + (searching) -> searching; + (finding) -> finding; + (always) -> always + end, never); + true -> + ?WARNING_MSG("Option 'deref_aliases' is deprecated. " + "The option is still supported " + "but it is better to fix your config: " + "use 'ldap_deref_aliases' instead.", []), + OldDerefAliases + end, + #eldap_config{servers = Servers, backups = Backups, tls_options = [{encrypt, Encrypt}, {tls_verify, TLSVerify}, diff --git a/src/gen_mod.erl b/src/gen_mod.erl index 261e6c6dd..00a716746 100644 --- a/src/gen_mod.erl +++ b/src/gen_mod.erl @@ -28,7 +28,7 @@ -author('alexey@process-one.net'). --export([start/0, start_module/3, stop_module/2, +-export([start/0, start_module/2, start_module/3, stop_module/2, stop_module_keep_config/2, get_opt/3, get_opt/4, get_opt_host/3, db_type/1, db_type/2, get_module_opt/5, get_module_opt_host/3, loaded_modules/1, @@ -60,6 +60,19 @@ start() -> {keypos, #ejabberd_module.module_host}]), ok. +-spec start_module(binary(), atom()) -> any(). + +start_module(Host, Module) -> + Modules = ejabberd_config:get_option( + {modules, Host}, + fun(L) when is_list(L) -> L end, []), + case lists:keyfind(Module, 1, Modules) of + {_, Opts} -> + start_module(Host, Module, Opts); + false -> + {error, not_found_in_config} + end. + -spec start_module(binary(), atom(), opts()) -> any(). start_module(Host, Module, Opts) -> @@ -196,22 +209,26 @@ get_opt_host(Host, Opts, Default) -> Val = get_opt(host, Opts, fun iolist_to_binary/1, Default), ejabberd_regexp:greplace(Val, <<"@HOST@">>, Host). --spec db_type(opts()) -> odbc | mnesia. +-spec db_type(opts()) -> odbc | mnesia | riak. db_type(Opts) -> get_opt(db_type, Opts, fun(odbc) -> odbc; (internal) -> mnesia; - (mnesia) -> mnesia end, + (mnesia) -> mnesia; + (riak) -> riak + end, mnesia). --spec db_type(binary(), atom()) -> odbc | mnesia. +-spec db_type(binary(), atom()) -> odbc | mnesia | riak. db_type(Host, Module) -> get_module_opt(Host, Module, db_type, fun(odbc) -> odbc; (internal) -> mnesia; - (mnesia) -> mnesia end, + (mnesia) -> mnesia; + (riak) -> riak + end, mnesia). -spec loaded_modules(binary()) -> [atom()]. diff --git a/src/jlib.erl b/src/jlib.erl index 7735d7dbc..be1da3fd0 100644 --- a/src/jlib.erl +++ b/src/jlib.erl @@ -798,7 +798,12 @@ base64_to_term(Base64) -> -spec decode_base64(binary()) -> binary(). decode_base64(S) -> - decode_base64_bin(S, <<>>). + case catch binary:last(S) of + C when C == $\n; C == $\s -> + decode_base64(binary:part(S, 0, byte_size(S) - 1)); + _ -> + decode_base64_bin(S, <<>>) + end. take_without_spaces(Bin, Count) -> take_without_spaces(Bin, Count, <<>>). diff --git a/src/mod_announce.erl b/src/mod_announce.erl index fba6d3b81..40204da80 100644 --- a/src/mod_announce.erl +++ b/src/mod_announce.erl @@ -792,6 +792,18 @@ announce_motd(Host, Packet) -> end, Sessions) end, mnesia:transaction(F); + riak -> + try + lists:foreach( + fun({U, S, _R}) -> + ok = ejabberd_riak:put(#motd_users{us = {U, S}}, + motd_users_schema(), + [{'2i', [{<<"server">>, S}]}]) + end, Sessions), + {atomic, ok} + catch _:{badmatch, Err} -> + {atomic, Err} + end; odbc -> F = fun() -> lists:foreach( @@ -837,6 +849,10 @@ announce_motd_update(LServer, Packet) -> mnesia:write(#motd{server = LServer, packet = Packet}) end, mnesia:transaction(F); + riak -> + {atomic, ejabberd_riak:put(#motd{server = LServer, + packet = Packet}, + motd_schema())}; odbc -> XML = ejabberd_odbc:escape(xml:element_to_binary(Packet)), F = fun() -> @@ -887,6 +903,16 @@ announce_motd_delete(LServer) -> end, Users) end, mnesia:transaction(F); + riak -> + try + ok = ejabberd_riak:delete(motd, LServer), + ok = ejabberd_riak:delete_by_index(motd_users, + <<"server">>, + LServer), + {atomic, ok} + catch _:{badmatch, Err} -> + {atomic, Err} + end; odbc -> F = fun() -> ejabberd_odbc:sql_query_t([<<"delete from motd;">>]) @@ -915,6 +941,23 @@ send_motd(#jid{luser = LUser, lserver = LServer} = JID, mnesia) -> _ -> ok end; +send_motd(#jid{luser = LUser, lserver = LServer} = JID, riak) -> + case catch ejabberd_riak:get(motd, motd_schema(), LServer) of + {ok, #motd{packet = Packet}} -> + US = {LUser, LServer}, + case ejabberd_riak:get(motd_users, motd_users_schema(), US) of + {ok, #motd_users{}} -> + ok; + _ -> + Local = jlib:make_jid(<<>>, LServer, <<>>), + ejabberd_router:route(Local, JID, Packet), + {atomic, ejabberd_riak:put( + #motd_users{us = US}, motd_users_schema(), + [{'2i', [{<<"server">>, LServer}]}])} + end; + _ -> + ok + end; send_motd(#jid{luser = LUser, lserver = LServer} = JID, odbc) when LUser /= <<>> -> case catch ejabberd_odbc:sql_query( LServer, [<<"select xml from motd where username='';">>]) of @@ -965,6 +1008,13 @@ get_stored_motd_packet(LServer, mnesia) -> _ -> error end; +get_stored_motd_packet(LServer, riak) -> + case ejabberd_riak:get(motd, motd_schema(), LServer) of + {ok, #motd{packet = Packet}} -> + {ok, Packet}; + _ -> + error + end; get_stored_motd_packet(LServer, odbc) -> case catch ejabberd_odbc:sql_query( LServer, [<<"select xml from motd where username='';">>]) of @@ -1052,6 +1102,12 @@ update_motd_users_table() -> mnesia:transform_table(motd_users, ignore, Fields) end. +motd_schema() -> + {record_info(fields, motd), #motd{}}. + +motd_users_schema() -> + {record_info(fields, motd_users), #motd_users{}}. + export(_Server) -> [{motd, fun(Host, #motd{server = LServer, packet = El}) @@ -1089,5 +1145,10 @@ import(_LServer, mnesia, #motd{} = Motd) -> mnesia:dirty_write(Motd); import(_LServer, mnesia, #motd_users{} = Users) -> mnesia:dirty_write(Users); +import(_LServer, riak, #motd{} = Motd) -> + ejabberd_riak:put(Motd, motd_schema()); +import(_LServer, riak, #motd_users{us = {_, S}} = Users) -> + ejabberd_riak:put(Users, motd_users_schema(), + [{'2i', [{<<"server">>, S}]}]); import(_, _, _) -> pass. diff --git a/src/mod_blocking.erl b/src/mod_blocking.erl index 797b7573b..0f57ce723 100644 --- a/src/mod_blocking.erl +++ b/src/mod_blocking.erl @@ -181,6 +181,39 @@ process_blocklist_block(LUser, LServer, Filter, {ok, NewDefault, NewList} end, mnesia:transaction(F); +process_blocklist_block(LUser, LServer, Filter, + riak) -> + {atomic, + begin + case ejabberd_riak:get(privacy, mod_privacy:privacy_schema(), + {LUser, LServer}) of + {ok, #privacy{default = Default, lists = Lists} = P} -> + case lists:keysearch(Default, 1, Lists) of + {value, {_, List}} -> + NewDefault = Default, + NewLists1 = lists:keydelete(Default, 1, Lists); + false -> + NewDefault = <<"Blocked contacts">>, + NewLists1 = Lists, + List = [] + end; + {error, _} -> + P = #privacy{us = {LUser, LServer}}, + NewDefault = <<"Blocked contacts">>, + NewLists1 = [], + List = [] + end, + NewList = Filter(List), + NewLists = [{NewDefault, NewList} | NewLists1], + case ejabberd_riak:put(P#privacy{default = NewDefault, + lists = NewLists}, + mod_privacy:privacy_schema()) of + ok -> + {ok, NewDefault, NewList}; + Err -> + Err + end + end}; process_blocklist_block(LUser, LServer, Filter, odbc) -> F = fun () -> Default = case @@ -257,6 +290,31 @@ process_blocklist_unblock_all(LUser, LServer, Filter, end, mnesia:transaction(F); process_blocklist_unblock_all(LUser, LServer, Filter, + riak) -> + {atomic, + case ejabberd_riak:get(privacy, {LUser, LServer}) of + {ok, #privacy{default = Default, lists = Lists} = P} -> + case lists:keysearch(Default, 1, Lists) of + {value, {_, List}} -> + NewList = Filter(List), + NewLists1 = lists:keydelete(Default, 1, Lists), + NewLists = [{Default, NewList} | NewLists1], + case ejabberd_riak:put(P#privacy{lists = NewLists}, + mod_privacy:privacy_schema()) of + ok -> + {ok, Default, NewList}; + Err -> + Err + end; + false -> + %% No default list, nothing to unblock + ok + end; + {error, _} -> + %% No lists, nothing to unblock + ok + end}; +process_blocklist_unblock_all(LUser, LServer, Filter, odbc) -> F = fun () -> case mod_privacy:sql_get_default_privacy_list_t(LUser) @@ -332,6 +390,32 @@ process_blocklist_unblock(LUser, LServer, Filter, end, mnesia:transaction(F); process_blocklist_unblock(LUser, LServer, Filter, + riak) -> + {atomic, + case ejabberd_riak:get(privacy, mod_privacy:privacy_schema(), + {LUser, LServer}) of + {error, _} -> + %% No lists, nothing to unblock + ok; + {ok, #privacy{default = Default, lists = Lists} = P} -> + case lists:keysearch(Default, 1, Lists) of + {value, {_, List}} -> + NewList = Filter(List), + NewLists1 = lists:keydelete(Default, 1, Lists), + NewLists = [{Default, NewList} | NewLists1], + case ejabberd_riak:put(P#privacy{lists = NewLists}, + mod_privacy:privacy_schema()) of + ok -> + {ok, Default, NewList}; + Err -> + Err + end; + false -> + %% No default list, nothing to unblock + ok + end + end}; +process_blocklist_unblock(LUser, LServer, Filter, odbc) -> F = fun () -> case mod_privacy:sql_get_default_privacy_list_t(LUser) @@ -409,6 +493,19 @@ process_blocklist_get(LUser, LServer, mnesia) -> _ -> [] end end; +process_blocklist_get(LUser, LServer, riak) -> + case ejabberd_riak:get(privacy, mod_privacy:privacy_schema(), + {LUser, LServer}) of + {ok, #privacy{default = Default, lists = Lists}} -> + case lists:keysearch(Default, 1, Lists) of + {value, {_, List}} -> List; + _ -> [] + end; + {error, notfound} -> + []; + {error, _} -> + error + end; process_blocklist_get(LUser, LServer, odbc) -> case catch mod_privacy:sql_get_default_privacy_list(LUser, LServer) diff --git a/src/mod_caps.erl b/src/mod_caps.erl index 5f529bd28..1002df444 100644 --- a/src/mod_caps.erl +++ b/src/mod_caps.erl @@ -17,9 +17,10 @@ %%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU %%% General Public License for more details. %%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% You should have received a copy of the GNU General Public License +%%% along with this program; if not, write to the Free Software +%%% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +%%% 02111-1307 USA %%% %%% 2009, improvements from ProcessOne to support correct PEP handling %%% through s2s, use less memory, and speedup global caps handling @@ -35,7 +36,8 @@ -export([read_caps/1, caps_stream_features/2, disco_features/5, disco_identity/5, disco_info/5, - get_features/1]). + get_features/2, export/1, import_info/0, import/5, + import_start/2, import_stop/2]). %% gen_mod callbacks -export([start/2, start_link/2, stop/1]). @@ -45,10 +47,9 @@ handle_cast/2, terminate/2, code_change/3]). %% hook handlers --export([user_send_packet/3, - user_receive_packet/4, - c2s_presence_in/2, - c2s_broadcast_recipients/5]). +-export([user_send_packet/3, user_receive_packet/4, + c2s_presence_in/2, c2s_filter_packet/6, + c2s_broadcast_recipients/6]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -79,9 +80,6 @@ -record(state, {host = <<"">> :: binary()}). -%%==================================================================== -%% API -%%==================================================================== start_link(Host, Opts) -> Proc = gen_mod:get_module_proc(Host, ?PROCNAME), gen_server:start_link({local, Proc}, ?MODULE, @@ -99,20 +97,14 @@ stop(Host) -> supervisor:terminate_child(ejabberd_sup, Proc), supervisor:delete_child(ejabberd_sup, Proc). -%% get_features returns a list of features implied by the given caps -%% record (as extracted by read_caps) or 'unknown' if features are -%% not completely collected at the moment. -get_features(nothing) -> []; -get_features(#caps{node = Node, version = Version, exts = Exts}) -> +get_features(_Host, nothing) -> []; +get_features(Host, #caps{node = Node, version = Version, + exts = Exts}) -> SubNodes = [Version | Exts], -%% read_caps takes a list of XML elements (the child elements of a -%% <presence/> stanza) and returns an opaque value representing the -%% Entity Capabilities contained therein, or the atom nothing if no -%% capabilities are advertised. lists:foldl(fun (SubNode, Acc) -> NodePair = {Node, SubNode}, case cache_tab:lookup(caps_features, NodePair, - caps_read_fun(NodePair)) + caps_read_fun(Host, NodePair)) of {ok, Features} when is_list(Features) -> Features ++ Acc; @@ -121,6 +113,8 @@ get_features(#caps{node = Node, version = Version, exts = Exts}) -> end, [], SubNodes). +-spec read_caps([xmlel()]) -> nothing | caps(). + read_caps(Els) -> read_caps(Els, nothing). read_caps([#xmlel{name = <<"c">>, attrs = Attrs} @@ -149,13 +143,11 @@ read_caps([_ | Tail], Result) -> read_caps(Tail, Result); read_caps([], Result) -> Result. -%%==================================================================== -%% Hooks -%%==================================================================== -user_send_packet( - #jid{luser = User, lserver = Server} = From, - #jid{luser = User, lserver = Server, lresource = <<"">>}, - #xmlel{name = <<"presence">>, attrs = Attrs, children = Els}) -> +user_send_packet(#jid{luser = User, lserver = Server} = From, + #jid{luser = User, lserver = Server, + lresource = <<"">>}, + #xmlel{name = <<"presence">>, attrs = Attrs, + children = Els} = Pkt) -> Type = xml:get_attr_s(<<"type">>, Attrs), if Type == <<"">>; Type == <<"available">> -> case read_caps(Els) of @@ -164,12 +156,15 @@ user_send_packet( feature_request(Server, From, Caps, [Version | Exts]) end; true -> ok - end; -user_send_packet(_From, _To, _Packet) -> ok. + end, + Pkt; +user_send_packet( _From, _To, Pkt) -> + Pkt. -user_receive_packet(#jid{lserver = Server}, From, _To, +user_receive_packet(#jid{lserver = Server}, + From, _To, #xmlel{name = <<"presence">>, attrs = Attrs, - children = Els}) -> + children = Els} = Pkt) -> Type = xml:get_attr_s(<<"type">>, Attrs), IsRemote = not lists:member(From#jid.lserver, ?MYHOSTS), if IsRemote and @@ -180,9 +175,12 @@ user_receive_packet(#jid{lserver = Server}, From, _To, feature_request(Server, From, Caps, [Version | Exts]) end; true -> ok - end; -user_receive_packet(_JID, _From, _To, _Packet) -> - ok. + end, + Pkt; +user_receive_packet( _JID, _From, _To, Pkt) -> + Pkt. + +-spec caps_stream_features([xmlel()], binary()) -> [xmlel()]. caps_stream_features(Acc, MyHost) -> case make_my_disco_hash(MyHost) of @@ -260,7 +258,8 @@ c2s_presence_in(C2SState, end, if CapsUpdated -> ejabberd_hooks:run(caps_update, To#jid.lserver, - [From, To, get_features(Caps)]); + [From, To, + get_features(To#jid.lserver, Caps)]); true -> ok end, ejabberd_c2s:set_aux_field(caps_resources, NewRs, @@ -268,63 +267,90 @@ c2s_presence_in(C2SState, true -> C2SState end. -c2s_broadcast_recipients(InAcc, C2SState, {pep_message, Feature}, - _From, _Packet) -> +c2s_filter_packet(InAcc, Host, C2SState, {pep_message, Feature}, To, _Packet) -> case ejabberd_c2s:get_aux_field(caps_resources, C2SState) of - {ok, Rs} -> - gb_trees_fold( - fun(USR, Caps, Acc) -> - case lists:member(Feature, get_features(Caps)) of - true -> - [USR|Acc]; - false -> - Acc - end - end, InAcc, Rs); - _ -> - InAcc + {ok, Rs} -> + LTo = jlib:jid_tolower(To), + case gb_trees:lookup(LTo, Rs) of + {value, Caps} -> + Drop = not lists:member(Feature, get_features(Host, Caps)), + {stop, Drop}; + none -> + {stop, true} + end; + _ -> InAcc end; -c2s_broadcast_recipients(Acc, _, _, _, _) -> - Acc. +c2s_filter_packet(Acc, _, _, _, _, _) -> Acc. -%%==================================================================== -%% gen_server callbacks -%%==================================================================== -init([Host, Opts]) -> +c2s_broadcast_recipients(InAcc, Host, C2SState, + {pep_message, Feature}, _From, _Packet) -> + case ejabberd_c2s:get_aux_field(caps_resources, + C2SState) + of + {ok, Rs} -> + gb_trees_fold(fun (USR, Caps, Acc) -> + case lists:member(Feature, + get_features(Host, Caps)) + of + true -> [USR | Acc]; + false -> Acc + end + end, + InAcc, Rs); + _ -> InAcc + end; +c2s_broadcast_recipients(Acc, _, _, _, _, _) -> Acc. + +init_db(mnesia, _Host) -> case catch mnesia:table_info(caps_features, storage_type) of - {'EXIT', _} -> - ok; - disc_only_copies -> - ok; - _ -> - mnesia:delete_table(caps_features) + {'EXIT', _} -> + ok; + disc_only_copies -> + ok; + _ -> + mnesia:delete_table(caps_features) end, mnesia:create_table(caps_features, - [{disc_only_copies, [node()]}, - {local_content, true}, - {attributes, record_info(fields, caps_features)}]), - mnesia:add_table_copy(caps_features, node(), disc_only_copies), - MaxSize = gen_mod:get_opt(cache_size, Opts, fun(CS) when is_integer(CS) -> CS end, 1000), - LifeTime = gen_mod:get_opt(cache_life_time, Opts, fun(CL) when is_integer(CL) -> CL end, timer:hours(24) div 1000), - cache_tab:new(caps_features, [{max_size, MaxSize}, {life_time, LifeTime}]), - ejabberd_hooks:add(c2s_presence_in, Host, - ?MODULE, c2s_presence_in, 75), + [{disc_only_copies, [node()]}, + {local_content, true}, + {attributes, + record_info(fields, caps_features)}]), + update_table(), + mnesia:add_table_copy(caps_features, node(), + disc_only_copies); +init_db(_, _) -> + ok. + +init([Host, Opts]) -> + init_db(gen_mod:db_type(Opts), Host), + MaxSize = gen_mod:get_opt(cache_size, Opts, + fun(I) when is_integer(I), I>0 -> I end, + 1000), + LifeTime = gen_mod:get_opt(cache_life_time, Opts, + fun(I) when is_integer(I), I>0 -> I end, + timer:hours(24) div 1000), + cache_tab:new(caps_features, + [{max_size, MaxSize}, {life_time, LifeTime}]), + ejabberd_hooks:add(c2s_presence_in, Host, ?MODULE, + c2s_presence_in, 75), + ejabberd_hooks:add(c2s_filter_packet, Host, ?MODULE, + c2s_filter_packet, 75), ejabberd_hooks:add(c2s_broadcast_recipients, Host, ?MODULE, c2s_broadcast_recipients, 75), - ejabberd_hooks:add(user_send_packet, Host, - ?MODULE, user_send_packet, 75), - ejabberd_hooks:add(user_receive_packet, Host, - ?MODULE, user_receive_packet, 75), - ejabberd_hooks:add(c2s_stream_features, Host, - ?MODULE, caps_stream_features, 75), - ejabberd_hooks:add(s2s_stream_features, Host, - ?MODULE, caps_stream_features, 75), - ejabberd_hooks:add(disco_local_features, Host, - ?MODULE, disco_features, 75), - ejabberd_hooks:add(disco_local_identity, Host, - ?MODULE, disco_identity, 75), - ejabberd_hooks:add(disco_info, Host, - ?MODULE, disco_info, 75), + ejabberd_hooks:add(user_send_packet, Host, ?MODULE, + user_send_packet, 75), + ejabberd_hooks:add(user_receive_packet, Host, ?MODULE, + user_receive_packet, 75), + ejabberd_hooks:add(c2s_stream_features, Host, ?MODULE, + caps_stream_features, 75), + ejabberd_hooks:add(s2s_stream_features, Host, ?MODULE, + caps_stream_features, 75), + ejabberd_hooks:add(disco_local_features, Host, ?MODULE, + disco_features, 75), + ejabberd_hooks:add(disco_local_identity, Host, ?MODULE, + disco_identity, 75), + ejabberd_hooks:add(disco_info, Host, ?MODULE, + disco_info, 75), {ok, #state{host = Host}}. handle_call(stop, _From, State) -> @@ -340,6 +366,8 @@ terminate(_Reason, State) -> Host = State#state.host, ejabberd_hooks:delete(c2s_presence_in, Host, ?MODULE, c2s_presence_in, 75), + ejabberd_hooks:delete(c2s_filter_packet, Host, ?MODULE, + c2s_filter_packet, 75), ejabberd_hooks:delete(c2s_broadcast_recipients, Host, ?MODULE, c2s_broadcast_recipients, 75), ejabberd_hooks:delete(user_send_packet, Host, ?MODULE, @@ -360,15 +388,12 @@ terminate(_Reason, State) -> code_change(_OldVsn, State, _Extra) -> {ok, State}. -%%==================================================================== -%% Aux functions -%%==================================================================== feature_request(Host, From, Caps, [SubNode | Tail] = SubNodes) -> Node = Caps#caps.node, NodePair = {Node, SubNode}, case cache_tab:lookup(caps_features, NodePair, - caps_read_fun(NodePair)) + caps_read_fun(Host, NodePair)) of {ok, Fs} when is_list(Fs) -> feature_request(Host, From, Caps, Tail); @@ -388,7 +413,7 @@ feature_request(Host, From, Caps, SubNode/binary>>}], children = []}]}, cache_tab:insert(caps_features, NodePair, now_ts(), - caps_write_fun(NodePair, now_ts())), + caps_write_fun(Host, NodePair, now_ts())), F = fun (IQReply) -> feature_response(IQReply, Host, From, Caps, SubNodes) @@ -416,7 +441,7 @@ feature_response(#iq{type = result, Els), cache_tab:insert(caps_features, NodePair, Features, - caps_write_fun(NodePair, Features)); + caps_write_fun(Host, NodePair, Features)); false -> ok end, feature_request(Host, From, Caps, SubNodes); @@ -424,18 +449,66 @@ feature_response(_IQResult, Host, From, Caps, [_SubNode | SubNodes]) -> feature_request(Host, From, Caps, SubNodes). -caps_read_fun(Node) -> +caps_read_fun(Host, Node) -> + LServer = jlib:nameprep(Host), + DBType = gen_mod:db_type(LServer, ?MODULE), + caps_read_fun(LServer, Node, DBType). + +caps_read_fun(_LServer, Node, mnesia) -> fun () -> case mnesia:dirty_read({caps_features, Node}) of [#caps_features{features = Features}] -> {ok, Features}; _ -> error end + end; +caps_read_fun(_LServer, Node, riak) -> + fun() -> + case ejabberd_riak:get(caps_features, caps_features_schema(), Node) of + {ok, #caps_features{features = Features}} -> {ok, Features}; + _ -> error + end + end; +caps_read_fun(LServer, {Node, SubNode}, odbc) -> + fun() -> + SNode = ejabberd_odbc:escape(Node), + SSubNode = ejabberd_odbc:escape(SubNode), + case ejabberd_odbc:sql_query( + LServer, [<<"select feature from caps_features where ">>, + <<"node='">>, SNode, <<"' and subnode='">>, + SSubNode, <<"';">>]) of + {selected, [<<"feature">>], [[H]|_] = Fs} -> + case catch jlib:binary_to_integer(H) of + Int when is_integer(Int), Int>=0 -> + {ok, Int}; + _ -> + {ok, lists:flatten(Fs)} + end; + _ -> + error + end end. -caps_write_fun(Node, Features) -> +caps_write_fun(Host, Node, Features) -> + LServer = jlib:nameprep(Host), + DBType = gen_mod:db_type(LServer, ?MODULE), + caps_write_fun(LServer, Node, Features, DBType). + +caps_write_fun(_LServer, Node, Features, mnesia) -> fun () -> mnesia:dirty_write(#caps_features{node_pair = Node, features = Features}) + end; +caps_write_fun(_LServer, Node, Features, riak) -> + fun () -> + ejabberd_riak:put(#caps_features{node_pair = Node, + features = Features}, + caps_features_schema()) + end; +caps_write_fun(LServer, NodePair, Features, odbc) -> + fun () -> + ejabberd_odbc:sql_transaction( + LServer, + sql_write_features_t(NodePair, Features)) end. make_my_disco_hash(Host) -> @@ -585,3 +658,98 @@ is_valid_node(Node) -> _ -> false end. + +update_table() -> + Fields = record_info(fields, caps_features), + case mnesia:table_info(caps_features, attributes) of + Fields -> + ejabberd_config:convert_table_to_binary( + caps_features, Fields, set, + fun(#caps_features{node_pair = {N, _}}) -> N end, + fun(#caps_features{node_pair = {N, P}, + features = Fs} = R) -> + NewFs = if is_integer(Fs) -> + Fs; + true -> + [iolist_to_binary(F) || F <- Fs] + end, + R#caps_features{node_pair = {iolist_to_binary(N), + iolist_to_binary(P)}, + features = NewFs} + end); + _ -> + ?INFO_MSG("Recreating caps_features table", []), + mnesia:transform_table(caps_features, ignore, Fields) + end. + +sql_write_features_t({Node, SubNode}, Features) -> + SNode = ejabberd_odbc:escape(Node), + SSubNode = ejabberd_odbc:escape(SubNode), + NewFeatures = if is_integer(Features) -> + [jlib:integer_to_binary(Features)]; + true -> + Features + end, + [[<<"delete from caps_features where node='">>, + SNode, <<"' and subnode='">>, SSubNode, <<"';">>]| + [[<<"insert into caps_features(node, subnode, feature) ">>, + <<"values ('">>, SNode, <<"', '">>, SSubNode, <<"', '">>, + ejabberd_odbc:escape(F), <<"');">>] || F <- NewFeatures]]. + +caps_features_schema() -> + {record_info(fields, caps_features), #caps_features{}}. + +export(_Server) -> + [{caps_features, + fun(_Host, #caps_features{node_pair = NodePair, + features = Features}) -> + sql_write_features_t(NodePair, Features); + (_Host, _R) -> + [] + end}]. + +import_info() -> + [{<<"caps_features">>, 4}]. + +import_start(LServer, DBType) -> + ets:new(caps_features_tmp, [private, named_table, bag]), + init_db(DBType, LServer), + ok. + +import(_LServer, {odbc, _}, _DBType, <<"caps_features">>, + [Node, SubNode, Feature, _TimeStamp]) -> + Feature1 = case catch jlib:binary_to_integer(Feature) of + I when is_integer(I), I>0 -> I; + _ -> Feature + end, + ets:insert(caps_features_tmp, {{Node, SubNode}, Feature1}), + ok. + +import_stop(LServer, DBType) -> + import_next(LServer, DBType, ets:first(caps_features_tmp)), + ets:delete(caps_features_tmp), + ok. + +import_next(_LServer, _DBType, '$end_of_table') -> + ok; +import_next(LServer, DBType, NodePair) -> + Features = [F || {_, F} <- ets:lookup(caps_features_tmp, NodePair)], + case Features of + [I] when is_integer(I), DBType == mnesia -> + mnesia:dirty_write( + #caps_features{node_pair = NodePair, features = I}); + [I] when is_integer(I), DBType == riak -> + ejabberd_riak:put( + #caps_features{node_pair = NodePair, features = I}, + caps_features_schema()); + _ when DBType == mnesia -> + mnesia:dirty_write( + #caps_features{node_pair = NodePair, features = Features}); + _ when DBType == riak -> + ejabberd_riak:put( + #caps_features{node_pair = NodePair, features = Features}, + caps_features_schema()); + _ when DBType == odbc -> + ok + end, + import_next(LServer, DBType, ets:next(caps_features_tmp, NodePair)). diff --git a/src/mod_carboncopy.erl b/src/mod_carboncopy.erl index 6f3101fcd..1313e341a 100644 --- a/src/mod_carboncopy.erl +++ b/src/mod_carboncopy.erl @@ -41,10 +41,6 @@ remove_connection/4, is_carbon_copy/1]). --define(NS_CC_2, <<"urn:xmpp:carbons:2">>). --define(NS_CC_1, <<"urn:xmpp:carbons:1">>). --define(NS_FORWARD, <<"urn:xmpp:forward:0">>). - -include("ejabberd.hrl"). -include("logger.hrl"). -include("jlib.hrl"). @@ -57,20 +53,24 @@ version :: binary() | matchspec_atom()}). is_carbon_copy(Packet) -> - case xml:get_subtag(Packet, <<"sent">>) of - #xmlel{name= <<"sent">>, attrs = AAttrs} -> - case xml:get_attr_s(<<"xmlns">>, AAttrs) of - ?NS_CC_2 -> true; - ?NS_CC_1 -> true; - _ -> false - end; + is_carbon_copy(Packet, <<"sent">>) orelse + is_carbon_copy(Packet, <<"received">>). + +is_carbon_copy(Packet, Direction) -> + case xml:get_subtag(Packet, Direction) of + #xmlel{name = Direction, attrs = Attrs} -> + case xml:get_attr_s(<<"xmlns">>, Attrs) of + ?NS_CARBONS_2 -> true; + ?NS_CARBONS_1 -> true; _ -> false - end. + end; + _ -> false + end. start(Host, Opts) -> IQDisc = gen_mod:get_opt(iqdisc, Opts,fun gen_iq_handler:check_type/1, one_queue), - mod_disco:register_feature(Host, ?NS_CC_1), - mod_disco:register_feature(Host, ?NS_CC_2), + mod_disco:register_feature(Host, ?NS_CARBONS_1), + mod_disco:register_feature(Host, ?NS_CARBONS_2), Fields = record_info(fields, ?TABLE), try mnesia:table_info(?TABLE, attributes) of Fields -> ok; @@ -86,26 +86,26 @@ start(Host, Opts) -> %% why priority 89: to define clearly that we must run BEFORE mod_logdb hook (90) ejabberd_hooks:add(user_send_packet,Host, ?MODULE, user_send_packet, 89), ejabberd_hooks:add(user_receive_packet,Host, ?MODULE, user_receive_packet, 89), - gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_CC_2, ?MODULE, iq_handler2, IQDisc), - gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_CC_1, ?MODULE, iq_handler1, IQDisc). + gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_CARBONS_2, ?MODULE, iq_handler2, IQDisc), + gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_CARBONS_1, ?MODULE, iq_handler1, IQDisc). stop(Host) -> - gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_CC_1), - gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_CC_2), - mod_disco:unregister_feature(Host, ?NS_CC_2), - mod_disco:unregister_feature(Host, ?NS_CC_1), + gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_CARBONS_1), + gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_CARBONS_2), + mod_disco:unregister_feature(Host, ?NS_CARBONS_2), + mod_disco:unregister_feature(Host, ?NS_CARBONS_1), %% why priority 89: to define clearly that we must run BEFORE mod_logdb hook (90) ejabberd_hooks:delete(user_send_packet,Host, ?MODULE, user_send_packet, 89), ejabberd_hooks:delete(user_receive_packet,Host, ?MODULE, user_receive_packet, 89), ejabberd_hooks:delete(unset_presence_hook,Host, ?MODULE, remove_connection, 10). iq_handler2(From, To, IQ) -> - iq_handler(From, To, IQ, ?NS_CC_2). + iq_handler(From, To, IQ, ?NS_CARBONS_2). iq_handler1(From, To, IQ) -> - iq_handler(From, To, IQ, ?NS_CC_1). + iq_handler(From, To, IQ, ?NS_CARBONS_1). iq_handler(From, _To, #iq{type=set, sub_el = #xmlel{name = Operation, children = []}} = IQ, CC)-> - ?INFO_MSG("carbons IQ received: ~p", [IQ]), + ?DEBUG("carbons IQ received: ~p", [IQ]), {U, S, R} = jlib:jid_tolower(From), Result = case Operation of <<"enable">>-> @@ -117,10 +117,10 @@ iq_handler(From, _To, #iq{type=set, sub_el = #xmlel{name = Operation, children end, case Result of ok -> - ?INFO_MSG("carbons IQ result: ok", []), + ?DEBUG("carbons IQ result: ok", []), IQ#iq{type=result, sub_el=[]}; {error,_Error} -> - ?INFO_MSG("Error enabling / disabling carbons: ~p", [Result]), + ?WARNING_MSG("Error enabling / disabling carbons: ~p", [Result]), IQ#iq{type=error,sub_el = [?ERR_BAD_REQUEST]} end; @@ -139,39 +139,20 @@ user_receive_packet(JID, _From, To, Packet) -> % - do not support "private" message mode, and do not modify the original packet in any way % - we also replicate "read" notifications check_and_forward(JID, To, #xmlel{name = <<"message">>, attrs = Attrs} = Packet, Direction)-> - case xml:get_attr_s(<<"type">>, Attrs) of - <<"chat">> -> - case xml:get_subtag(Packet, <<"private">>) of - false -> - case xml:get_subtag(Packet, <<"no-copy">>) of - false -> - case xml:get_subtag(Packet,<<"received">>) of - false -> - %% We must check if a packet contains "<sent><forwarded></sent></forwarded>" - %% tags in order to avoid receiving message back to original sender. - SubTag = xml:get_subtag(Packet,<<"sent">>), - if SubTag == false -> - send_copies(JID, To, Packet, Direction); - true -> - case xml:get_subtag(SubTag,<<"forwarded">>) of - false-> - send_copies(JID, To, Packet, Direction); - _ -> - stop - end - end; - _ -> - %% stop the hook chain, we don't want mod_logdb to register this message (duplicate) - stop - end; - _ -> - ok - end; - _ -> - ok - end; - _ -> - ok + case xml:get_attr_s(<<"type">>, Attrs) == <<"chat">> andalso + xml:get_subtag(Packet, <<"private">>) == false andalso + xml:get_subtag(Packet, <<"no-copy">>) == false of + true -> + case is_carbon_copy(Packet) of + false -> + send_copies(JID, To, Packet, Direction); + true -> + %% stop the hook chain, we don't want mod_logdb to register + %% this message (duplicate) + stop + end; + _ -> + ok end; check_and_forward(_JID, _To, _Packet, _)-> ok. @@ -186,6 +167,10 @@ remove_connection(User, Server, Resource, _Status)-> send_copies(JID, To, Packet, Direction)-> {U, S, R} = jlib:jid_tolower(JID), PrioRes = ejabberd_sm:get_user_present_resources(U, S), + {MaxPrio, MaxRes} = case catch lists:max(PrioRes) of + {Prio, Res} -> {Prio, Res}; + _ -> {0, undefined} + end, IsBareTo = case {Direction, To} of {received, #jid{lresource = <<>>}} -> true; @@ -199,15 +184,19 @@ send_copies(JID, To, Packet, Direction)-> end, %% list of JIDs that should receive a carbon copy of this message (excluding the %% receiver(s) of the original message - TargetJIDs = if IsBareTo -> - MaxPrio = case catch lists:max(PrioRes) of - {Prio, _Res} -> Prio; - _ -> 0 - end, + TargetJIDs = case {IsBareTo, R} of + {true, MaxRes} -> OrigTo = fun(Res) -> lists:member({MaxPrio, Res}, PrioRes) end, [ {jlib:make_jid({U, S, CCRes}), CC_Version} || {CCRes, CC_Version} <- list(U, S), not OrigTo(CCRes) ]; - true -> + {true, _} -> + %% The message was sent to our bare JID, and we currently have + %% multiple resources with the same highest priority, so the session + %% manager routes the message to each of them. We create carbon + %% copies only from one of those resources (the one where R equals + %% MaxRes) in order to avoid duplicates. + []; + {false, _} -> [ {jlib:make_jid({U, S, CCRes}), CC_Version} || {CCRes, CC_Version} <- list(U, S), CCRes /= R ] %TargetJIDs = lists:delete(JID, [ jlib:make_jid({U, S, CCRes}) || CCRes <- list(U, S) ]), @@ -223,7 +212,7 @@ send_copies(JID, To, Packet, Direction)-> end, TargetJIDs), ok. -build_forward_packet(JID, Packet, Sender, Dest, Direction, ?NS_CC_2) -> +build_forward_packet(JID, Packet, Sender, Dest, Direction, ?NS_CARBONS_2) -> #xmlel{name = <<"message">>, attrs = [{<<"xmlns">>, <<"jabber:client">>}, {<<"type">>, <<"chat">>}, @@ -231,7 +220,7 @@ build_forward_packet(JID, Packet, Sender, Dest, Direction, ?NS_CC_2) -> {<<"to">>, jlib:jid_to_string(Dest)}], children = [ #xmlel{name = list_to_binary(atom_to_list(Direction)), - attrs = [{<<"xmlns">>, ?NS_CC_2}], + attrs = [{<<"xmlns">>, ?NS_CARBONS_2}], children = [ #xmlel{name = <<"forwarded">>, attrs = [{<<"xmlns">>, ?NS_FORWARD}], @@ -239,7 +228,7 @@ build_forward_packet(JID, Packet, Sender, Dest, Direction, ?NS_CC_2) -> complete_packet(JID, Packet, Direction)]} ]} ]}; -build_forward_packet(JID, Packet, Sender, Dest, Direction, ?NS_CC_1) -> +build_forward_packet(JID, Packet, Sender, Dest, Direction, ?NS_CARBONS_1) -> #xmlel{name = <<"message">>, attrs = [{<<"xmlns">>, <<"jabber:client">>}, {<<"type">>, <<"chat">>}, @@ -247,7 +236,7 @@ build_forward_packet(JID, Packet, Sender, Dest, Direction, ?NS_CC_1) -> {<<"to">>, jlib:jid_to_string(Dest)}], children = [ #xmlel{name = list_to_binary(atom_to_list(Direction)), - attrs = [{<<"xmlns">>, ?NS_CC_1}]}, + attrs = [{<<"xmlns">>, ?NS_CARBONS_1}]}, #xmlel{name = <<"forwarded">>, attrs = [{<<"xmlns">>, ?NS_FORWARD}], children = [complete_packet(JID, Packet, Direction)]} diff --git a/src/mod_client_state.erl b/src/mod_client_state.erl new file mode 100644 index 000000000..b43683bb7 --- /dev/null +++ b/src/mod_client_state.erl @@ -0,0 +1,105 @@ +%%%---------------------------------------------------------------------- +%%% File : mod_client_state.erl +%%% Author : Holger Weiss +%%% Purpose : Filter stanzas sent to inactive clients (XEP-0352) +%%% Created : 11 Sep 2014 by Holger Weiss +%%% +%%% +%%% ejabberd, Copyright (C) 2014 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%---------------------------------------------------------------------- + +-module(mod_client_state). +-author('holger@zedat.fu-berlin.de'). + +-behavior(gen_mod). + +-export([start/2, stop/1, add_stream_feature/2, filter_presence/2, + filter_chat_states/2]). + +-include("ejabberd.hrl"). +-include("logger.hrl"). +-include("jlib.hrl"). + +start(Host, Opts) -> + QueuePresence = gen_mod:get_opt(queue_presence, Opts, + fun(true) -> true end, false), + DropChatStates = gen_mod:get_opt(drop_chat_states, Opts, + fun(true) -> true end, false), + if QueuePresence; DropChatStates -> + ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE, + add_stream_feature, 50), + if QueuePresence -> + ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, + filter_presence, 50); + true -> ok + end, + if DropChatStates -> + ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE, + filter_chat_states, 50); + true -> ok + end; + true -> ok + end, + ok. + +stop(Host) -> + ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, + filter_presence, 50), + ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE, + filter_chat_states, 50), + ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE, + add_stream_feature, 50), + ok. + +add_stream_feature(Features, _Host) -> + Feature = #xmlel{name = <<"csi">>, + attrs = [{<<"xmlns">>, ?NS_CLIENT_STATE}], + children = []}, + [Feature | Features]. + +filter_presence(_Action, #xmlel{name = <<"presence">>, attrs = Attrs}) -> + case xml:get_attr(<<"type">>, Attrs) of + {value, Type} when Type /= <<"unavailable">> -> + ?DEBUG("Got important presence stanza", []), + {stop, send}; + _ -> + ?DEBUG("Got availability presence stanza", []), + {stop, queue} + end; +filter_presence(Action, _Stanza) -> Action. + +filter_chat_states(_Action, #xmlel{name = <<"message">>} = Stanza) -> + %% All XEP-0085 chat states except for <gone/>: + ChatStates = [<<"active">>, <<"inactive">>, <<"composing">>, <<"paused">>], + Stripped = + lists:foldl(fun(ChatState, AccStanza) -> + xml:remove_subtags(AccStanza, ChatState, + {<<"xmlns">>, ?NS_CHATSTATES}) + end, Stanza, ChatStates), + case Stripped of + #xmlel{children = [#xmlel{name = <<"thread">>}]} -> + ?DEBUG("Got standalone chat state notification", []), + {stop, drop}; + #xmlel{children = []} -> + ?DEBUG("Got standalone chat state notification", []), + {stop, drop}; + _ -> + ?DEBUG("Got message with chat state notification", []), + {stop, send} + end; +filter_chat_states(Action, _Stanza) -> Action. diff --git a/src/mod_fail2ban.erl b/src/mod_fail2ban.erl new file mode 100644 index 000000000..b246e402c --- /dev/null +++ b/src/mod_fail2ban.erl @@ -0,0 +1,161 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net> +%%% @copyright (C) 2014, Evgeny Khramtsov +%%% @doc +%%% +%%% @end +%%% Created : 15 Aug 2014 by Evgeny Khramtsov <ekhramtsov@process-one.net> +%%%------------------------------------------------------------------- +-module(mod_fail2ban). + +-behaviour(gen_mod). +-behaviour(gen_server). + +%% API +-export([start_link/2, start/2, stop/1, c2s_auth_result/4, check_bl_c2s/3]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-include_lib("stdlib/include/ms_transform.hrl"). +-include("ejabberd.hrl"). +-include("logger.hrl"). + +-define(C2S_AUTH_BAN_LIFETIME, 3600). %% 1 hour +-define(C2S_MAX_AUTH_FAILURES, 20). +-define(CLEAN_INTERVAL, timer:minutes(10)). + +-record(state, {host = <<"">> :: binary()}). + +%%%=================================================================== +%%% API +%%%=================================================================== +start_link(Host, Opts) -> + Proc = gen_mod:get_module_proc(Host, ?MODULE), + gen_server:start_link({local, Proc}, ?MODULE, [Host, Opts], []). + +c2s_auth_result(false, _User, LServer, {Addr, _Port}) -> + BanLifetime = gen_mod:get_module_opt( + LServer, ?MODULE, c2s_auth_ban_lifetime, + fun(T) when is_integer(T), T > 0 -> T end, + ?C2S_AUTH_BAN_LIFETIME), + MaxFailures = gen_mod:get_module_opt( + LServer, ?MODULE, c2s_max_auth_failures, + fun(I) when is_integer(I), I > 0 -> I end, + ?C2S_MAX_AUTH_FAILURES), + UnbanTS = unban_timestamp(BanLifetime), + case ets:lookup(failed_auth, Addr) of + [{Addr, N, _, _}] -> + ets:insert(failed_auth, {Addr, N+1, UnbanTS, MaxFailures}); + [] -> + ets:insert(failed_auth, {Addr, 1, UnbanTS, MaxFailures}) + end; +c2s_auth_result(true, _User, _Server, _AddrPort) -> + ok. + +check_bl_c2s(_Acc, Addr, Lang) -> + case ets:lookup(failed_auth, Addr) of + [{Addr, N, TS, MaxFailures}] when N >= MaxFailures -> + case TS > now() of + true -> + IP = jlib:ip_to_list(Addr), + UnbanDate = format_date( + calendar:now_to_universal_time(TS)), + LogReason = io_lib:fwrite( + "Too many (~p) failed authentications " + "from this IP address (~s). The address " + "will be unblocked at ~s UTC", + [N, IP, UnbanDate]), + ReasonT = io_lib:fwrite( + translate:translate( + Lang, + <<"Too many (~p) failed authentications " + "from this IP address (~s). The address " + "will be unblocked at ~s UTC">>), + [N, IP, UnbanDate]), + {stop, {true, LogReason, ReasonT}}; + false -> + ets:delete(failed_auth, Addr), + false + end; + _ -> + false + end. + +%%==================================================================== +%% gen_mod callbacks +%%==================================================================== +start(Host, Opts) -> + catch ets:new(failed_auth, [named_table, public]), + Proc = gen_mod:get_module_proc(Host, ?MODULE), + ChildSpec = {Proc, {?MODULE, start_link, [Host, Opts]}, + transient, 1000, worker, [?MODULE]}, + supervisor:start_child(ejabberd_sup, ChildSpec). + +stop(Host) -> + Proc = gen_mod:get_module_proc(Host, ?MODULE), + supervisor:terminate_child(ejabberd_sup, Proc), + supervisor:delete_child(ejabberd_sup, Proc). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([Host, _Opts]) -> + ejabberd_hooks:add(c2s_auth_result, Host, ?MODULE, c2s_auth_result, 100), + ejabberd_hooks:add(check_bl_c2s, ?MODULE, check_bl_c2s, 100), + erlang:send_after(?CLEAN_INTERVAL, self(), clean), + {ok, #state{host = Host}}. + +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +handle_cast(_Msg, State) -> + ?ERROR_MSG("got unexpected cast = ~p", [_Msg]), + {noreply, State}. + +handle_info(clean, State) -> + ?DEBUG("cleaning ~p ETS table", [failed_auth]), + Now = now(), + ets:select_delete( + failed_auth, + ets:fun2ms(fun({_, _, UnbanTS, _}) -> UnbanTS =< Now end)), + erlang:send_after(?CLEAN_INTERVAL, self(), clean), + {noreply, State}; +handle_info(_Info, State) -> + ?ERROR_MSG("got unexpected info = ~p", [_Info]), + {noreply, State}. + +terminate(_Reason, #state{host = Host}) -> + ejabberd_hooks:delete(c2s_auth_result, Host, ?MODULE, c2s_auth_result, 100), + case is_loaded_at_other_hosts(Host) of + true -> + ok; + false -> + ejabberd_hooks:delete(check_bl_c2s, ?MODULE, check_bl_c2s, 100), + ets:delete(failed_auth) + end. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +unban_timestamp(BanLifetime) -> + {MegaSecs, MSecs, USecs} = now(), + UnbanSecs = MegaSecs * 1000000 + MSecs + BanLifetime, + {UnbanSecs div 1000000, UnbanSecs rem 1000000, USecs}. + +is_loaded_at_other_hosts(Host) -> + lists:any( + fun(VHost) when VHost == Host -> + false; + (VHost) -> + gen_mod:is_loaded(VHost, ?MODULE) + end, ?MYHOSTS). + +format_date({{Year, Month, Day}, {Hour, Minute, Second}}) -> + io_lib:format("~2..0w:~2..0w:~2..0w ~2..0w.~2..0w.~4..0w", + [Hour, Minute, Second, Day, Month, Year]). diff --git a/src/mod_http_fileserver.erl b/src/mod_http_fileserver.erl index 8807f24bf..1011dd07f 100644 --- a/src/mod_http_fileserver.erl +++ b/src/mod_http_fileserver.erl @@ -48,17 +48,12 @@ -include("ejabberd.hrl"). -include("logger.hrl"). +-include("ejabberd_http.hrl"). -include("jlib.hrl"). -include_lib("kernel/include/file.hrl"). -%%-include("ejabberd_http.hrl"). -%% TODO: When ejabberd-modules SVN gets the new ejabberd_http.hrl, delete this code: --record(request, - {method, path, q = [], us, auth, lang = <<"">>, - data = <<"">>, ip, host, port, tp, headers}). - -record(state, {host, docroot, accesslog, accesslogfd, directory_indices, custom_headers, default_content_type, diff --git a/src/mod_ip_blacklist.erl b/src/mod_ip_blacklist.erl index f0feb6551..1dd641ce5 100644 --- a/src/mod_ip_blacklist.erl +++ b/src/mod_ip_blacklist.erl @@ -37,7 +37,7 @@ -export([update_bl_c2s/0]). %% Hooks: --export([is_ip_in_c2s_blacklist/2]). +-export([is_ip_in_c2s_blacklist/3]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -107,14 +107,23 @@ update_bl_c2s() -> %% Return: false: IP not blacklisted %% true: IP is blacklisted %% IPV4 IP tuple: -is_ip_in_c2s_blacklist(_Val, IP) when is_tuple(IP) -> +is_ip_in_c2s_blacklist(_Val, IP, Lang) when is_tuple(IP) -> BinaryIP = jlib:ip_to_list(IP), case ets:lookup(bl_c2s, BinaryIP) of [] -> %% Not in blacklist false; - [_] -> {stop, true} + [_] -> + LogReason = io_lib:fwrite( + "This IP address is blacklisted in ~s", + [?BLC2S]), + ReasonT = io_lib:fwrite( + translate:translate( + Lang, + <<"This IP address is blacklisted in ~s">>), + [?BLC2S]), + {stop, {true, LogReason, ReasonT}} end; -is_ip_in_c2s_blacklist(_Val, _IP) -> false. +is_ip_in_c2s_blacklist(_Val, _IP, _Lang) -> false. %% TODO: %% - For now, we do not kick user already logged on a given IP after diff --git a/src/mod_irc.erl b/src/mod_irc.erl index 88e0f1dce..f6e7bb774 100644 --- a/src/mod_irc.erl +++ b/src/mod_irc.erl @@ -56,7 +56,8 @@ -type conn_param() :: {binary(), binary(), inet:port_number(), binary()} | {binary(), binary(), inet:port_number()} | - {binary(), binary()}. + {binary(), binary()} | + {binary()}. -record(irc_connection, {jid_server_host = {#jid{}, <<"">>, <<"">>} :: {jid(), binary(), binary()}, @@ -590,6 +591,17 @@ get_data(_LServer, Host, From, mnesia) -> [] -> empty; [#irc_custom{data = Data}] -> Data end; +get_data(LServer, Host, From, riak) -> + #jid{luser = LUser, lserver = LServer} = From, + US = {LUser, LServer}, + case ejabberd_riak:get(irc_custom, irc_custom_schema(), {US, Host}) of + {ok, #irc_custom{data = Data}} -> + Data; + {error, notfound} -> + empty; + _Err -> + error + end; get_data(LServer, Host, From, odbc) -> SJID = ejabberd_odbc:escape(jlib:jid_to_string(jlib:jid_tolower(jlib:jid_remove_resource(From)))), @@ -600,7 +612,7 @@ get_data(LServer, Host, From, odbc) -> <<"';">>]) of {selected, [<<"data">>], [[SData]]} -> - data_to_binary(ejabberd_odbc:decode_term(SData)); + data_to_binary(From, ejabberd_odbc:decode_term(SData)); {'EXIT', _} -> error; {selected, _, _} -> empty end. @@ -711,7 +723,7 @@ get_form(_ServerHost, _Host, _, _, _Lang) -> set_data(ServerHost, Host, From, Data) -> LServer = jlib:nameprep(ServerHost), - set_data(LServer, Host, From, data_to_binary(Data), + set_data(LServer, Host, From, data_to_binary(From, Data), gen_mod:db_type(LServer, ?MODULE)). set_data(_LServer, Host, From, Data, mnesia) -> @@ -722,6 +734,12 @@ set_data(_LServer, Host, From, Data, mnesia) -> data = Data}) end, mnesia:transaction(F); +set_data(LServer, Host, From, Data, riak) -> + {LUser, LServer, _} = jlib:jid_tolower(From), + US = {LUser, LServer}, + {atomic, ejabberd_riak:put(#irc_custom{us_host = {US, Host}, + data = Data}, + irc_custom_schema())}; set_data(LServer, Host, From, Data, odbc) -> SJID = ejabberd_odbc:escape(jlib:jid_to_string(jlib:jid_tolower(jlib:jid_remove_resource(From)))), @@ -1217,28 +1235,48 @@ get_username_and_connection_params(Data) -> end, {Username, ConnParams}. -data_to_binary(Data) -> +data_to_binary(JID, Data) -> lists:map( fun({username, U}) -> {username, iolist_to_binary(U)}; ({connections_params, Params}) -> - {connections_params, - lists:map( - fun({S, E}) -> - {iolist_to_binary(S), iolist_to_binary(E)}; - ({S, E, Port}) -> - {iolist_to_binary(S), iolist_to_binary(E), Port}; - ({S, E, Port, P}) -> - {iolist_to_binary(S), iolist_to_binary(E), - Port, iolist_to_binary(P)} - end, Params)}; + {connections_params, + lists:flatmap( + fun(Param) -> + try + [conn_param_to_binary(Param)] + catch _:_ -> + if JID /= error -> + ?ERROR_MSG("failed to convert " + "parameter ~p for user ~s", + [Param, + jlib:jid_to_string(JID)]); + true -> + ?ERROR_MSG("failed to convert " + "parameter ~p", + [Param]) + end, + [] + end + end, Params)}; (Opt) -> Opt end, Data). +conn_param_to_binary({S}) -> + {iolist_to_binary(S)}; +conn_param_to_binary({S, E}) -> + {iolist_to_binary(S), iolist_to_binary(E)}; +conn_param_to_binary({S, E, Port}) when is_integer(Port) -> + {iolist_to_binary(S), iolist_to_binary(E), Port}; +conn_param_to_binary({S, E, Port, P}) when is_integer(Port) -> + {iolist_to_binary(S), iolist_to_binary(E), Port, iolist_to_binary(P)}. + conn_params_to_list(Params) -> lists:map( - fun({S, E}) -> + fun({S}) -> + {binary_to_list(S)}; + ({S, E}) -> {binary_to_list(S), binary_to_list(E)}; ({S, E, Port}) -> {binary_to_list(S), binary_to_list(E), Port}; @@ -1247,6 +1285,9 @@ conn_params_to_list(Params) -> Port, binary_to_list(P)} end, Params). +irc_custom_schema() -> + {record_info(fields, irc_custom), #irc_custom{}}. + update_table() -> Fields = record_info(fields, irc_custom), case mnesia:table_info(irc_custom, attributes) of @@ -1256,10 +1297,11 @@ update_table() -> fun(#irc_custom{us_host = {_, H}}) -> H end, fun(#irc_custom{us_host = {{U, S}, H}, data = Data} = R) -> + JID = jlib:make_jid(U, S, <<"">>), R#irc_custom{us_host = {{iolist_to_binary(U), iolist_to_binary(S)}, iolist_to_binary(H)}, - data = data_to_binary(Data)} + data = data_to_binary(JID, Data)} end); _ -> ?INFO_MSG("Recreating irc_custom table", []), @@ -1299,5 +1341,7 @@ import(_LServer) -> import(_LServer, mnesia, #irc_custom{} = R) -> mnesia:dirty_write(R); +import(_LServer, riak, #irc_custom{} = R) -> + ejabberd_riak:put(R, irc_custom_schema()); import(_, _, _) -> pass. diff --git a/src/mod_last.erl b/src/mod_last.erl index 6b7a06bed..a20da3130 100644 --- a/src/mod_last.erl +++ b/src/mod_last.erl @@ -168,6 +168,17 @@ get_last(LUser, LServer, mnesia) -> status = Status}] -> {ok, TimeStamp, Status} end; +get_last(LUser, LServer, riak) -> + case ejabberd_riak:get(last_activity, last_activity_schema(), + {LUser, LServer}) of + {ok, #last_activity{timestamp = TimeStamp, + status = Status}} -> + {ok, TimeStamp, Status}; + {error, notfound} -> + not_found; + Err -> + Err + end; get_last(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), case catch odbc_queries:get_last(LServer, Username) of @@ -236,6 +247,13 @@ store_last_info(LUser, LServer, TimeStamp, Status, end, mnesia:transaction(F); store_last_info(LUser, LServer, TimeStamp, Status, + riak) -> + US = {LUser, LServer}, + {atomic, ejabberd_riak:put(#last_activity{us = US, + timestamp = TimeStamp, + status = Status}, + last_activity_schema())}; +store_last_info(LUser, LServer, TimeStamp, Status, odbc) -> Username = ejabberd_odbc:escape(LUser), Seconds = @@ -264,7 +282,9 @@ remove_user(LUser, LServer, mnesia) -> mnesia:transaction(F); remove_user(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), - odbc_queries:del_last(LServer, Username). + odbc_queries:del_last(LServer, Username); +remove_user(LUser, LServer, riak) -> + {atomic, ejabberd_riak:delete(last_activity, {LUser, LServer})}. update_table() -> Fields = record_info(fields, last_activity), @@ -283,6 +303,9 @@ update_table() -> mnesia:transform_table(last_activity, ignore, Fields) end. +last_activity_schema() -> + {record_info(fields, last_activity), #last_activity{}}. + export(_Server) -> [{last_activity, fun(Host, #last_activity{us = {LUser, LServer}, @@ -312,6 +335,8 @@ import(LServer) -> import(_LServer, mnesia, #last_activity{} = LA) -> mnesia:dirty_write(LA); +import(_LServer, riak, #last_activity{} = LA) -> + ejabberd_riak:put(LA, last_activity_schema()); import(_, _, _) -> pass. diff --git a/src/mod_muc.erl b/src/mod_muc.erl index 160b9009d..b844a01e0 100644 --- a/src/mod_muc.erl +++ b/src/mod_muc.erl @@ -147,6 +147,10 @@ store_room(_LServer, Host, Name, Opts, mnesia) -> opts = Opts}) end, mnesia:transaction(F); +store_room(_LServer, Host, Name, Opts, riak) -> + {atomic, ejabberd_riak:put(#muc_room{name_host = {Name, Host}, + opts = Opts}, + muc_room_schema())}; store_room(LServer, Host, Name, Opts, odbc) -> SName = ejabberd_odbc:escape(Name), SHost = ejabberd_odbc:escape(Host), @@ -170,6 +174,11 @@ restore_room(_LServer, Host, Name, mnesia) -> [#muc_room{opts = Opts}] -> Opts; _ -> error end; +restore_room(_LServer, Host, Name, riak) -> + case ejabberd_riak:get(muc_room, muc_room_schema(), {Name, Host}) of + {ok, #muc_room{opts = Opts}} -> Opts; + _ -> error + end; restore_room(LServer, Host, Name, odbc) -> SName = ejabberd_odbc:escape(Name), SHost = ejabberd_odbc:escape(Host), @@ -192,6 +201,8 @@ forget_room(_LServer, Host, Name, mnesia) -> F = fun () -> mnesia:delete({muc_room, {Name, Host}}) end, mnesia:transaction(F); +forget_room(_LServer, Host, Name, riak) -> + {atomic, ejabberd_riak:delete(muc_room, {Name, Host})}; forget_room(LServer, Host, Name, odbc) -> SName = ejabberd_odbc:escape(Name), SHost = ejabberd_odbc:escape(Host), @@ -231,6 +242,19 @@ can_use_nick(_LServer, Host, JID, Nick, mnesia) -> [] -> true; [#muc_registered{us_host = {U, _Host}}] -> U == LUS end; +can_use_nick(LServer, Host, JID, Nick, riak) -> + {LUser, LServer, _} = jlib:jid_tolower(JID), + LUS = {LUser, LServer}, + case ejabberd_riak:get_by_index(muc_registered, + muc_registered_schema(), + <<"nick_host">>, {Nick, Host}) of + {ok, []} -> + true; + {ok, [#muc_registered{us_host = {U, _Host}}]} -> + U == LUS; + {error, _} -> + true + end; can_use_nick(LServer, Host, JID, Nick, odbc) -> SJID = jlib:jid_to_string(jlib:jid_tolower(jlib:jid_remove_resource(JID))), @@ -617,6 +641,16 @@ get_rooms(_LServer, Host, mnesia) -> {'EXIT', Reason} -> ?ERROR_MSG("~p", [Reason]), []; Rs -> Rs end; +get_rooms(_LServer, Host, riak) -> + case ejabberd_riak:get(muc_room, muc_room_schema()) of + {ok, Rs} -> + lists:filter( + fun(#muc_room{name_host = {_, H}}) -> + Host == H + end, Rs); + _Err -> + [] + end; get_rooms(LServer, Host, odbc) -> SHost = ejabberd_odbc:escape(Host), case catch ejabberd_odbc:sql_query(LServer, @@ -839,6 +873,15 @@ get_nick(_LServer, Host, From, mnesia) -> [] -> error; [#muc_registered{nick = Nick}] -> Nick end; +get_nick(LServer, Host, From, riak) -> + {LUser, LServer, _} = jlib:jid_tolower(From), + US = {LUser, LServer}, + case ejabberd_riak:get(muc_registered, + muc_registered_schema(), + {US, Host}) of + {ok, #muc_registered{nick = Nick}} -> Nick; + {error, _} -> error + end; get_nick(LServer, Host, From, odbc) -> SJID = ejabberd_odbc:escape(jlib:jid_to_string(jlib:jid_tolower(jlib:jid_remove_resource(From)))), @@ -871,7 +914,8 @@ iq_get_register_info(ServerHost, Host, From, Lang) -> <<"You need a client that supports x:data " "to register the nickname">>)}]}, #xmlel{name = <<"x">>, - attrs = [{<<"xmlns">>, ?NS_XDATA}], + attrs = [{<<"xmlns">>, ?NS_XDATA}, + {<<"type">>, <<"form">>}], children = [#xmlel{name = <<"title">>, attrs = [], children = @@ -922,6 +966,35 @@ set_nick(_LServer, Host, From, Nick, mnesia) -> end end, mnesia:transaction(F); +set_nick(LServer, Host, From, Nick, riak) -> + {LUser, LServer, _} = jlib:jid_tolower(From), + LUS = {LUser, LServer}, + {atomic, + case Nick of + <<"">> -> + ejabberd_riak:delete(muc_registered, {LUS, Host}); + _ -> + Allow = case ejabberd_riak:get_by_index( + muc_registered, + muc_registered_schema(), + <<"nick_host">>, {Nick, Host}) of + {ok, []} -> + true; + {ok, [#muc_registered{us_host = {U, _Host}}]} -> + U == LUS; + {error, _} -> + false + end, + if Allow -> + ejabberd_riak:put(#muc_registered{us_host = {LUS, Host}, + nick = Nick}, + muc_registered_schema(), + [{'2i', [{<<"nick_host">>, + {Nick, Host}}]}]); + true -> + false + end + end}; set_nick(LServer, Host, From, Nick, odbc) -> JID = jlib:jid_to_string(jlib:jid_tolower(jlib:jid_remove_resource(From))), @@ -1107,6 +1180,12 @@ update_tables(Host) -> update_muc_room_table(Host), update_muc_registered_table(Host). +muc_room_schema() -> + {record_info(fields, muc_room), #muc_room{}}. + +muc_registered_schema() -> + {record_info(fields, muc_registered), #muc_registered{}}. + update_muc_room_table(_Host) -> Fields = record_info(fields, muc_room), case mnesia:table_info(muc_room, attributes) of @@ -1202,5 +1281,11 @@ import(_LServer, mnesia, #muc_room{} = R) -> mnesia:dirty_write(R); import(_LServer, mnesia, #muc_registered{} = R) -> mnesia:dirty_write(R); +import(_LServer, riak, #muc_room{} = R) -> + ejabberd_riak:put(R, muc_room_schema()); +import(_LServer, riak, + #muc_registered{us_host = {_, Host}, nick = Nick} = R) -> + ejabberd_riak:put(R, muc_registered_schema(), + [{'2i', [{<<"nick_host">>, {Nick, Host}}]}]); import(_, _, _) -> pass. diff --git a/src/mod_muc_log.erl b/src/mod_muc_log.erl index ac6bea4fa..bdaafd197 100644 --- a/src/mod_muc_log.erl +++ b/src/mod_muc_log.erl @@ -571,16 +571,7 @@ get_dateweek(Date, Lang) -> end). make_dir_rec(Dir) -> - DirS = binary_to_list(Dir), - case file:read_file_info(DirS) of - {ok, _} -> ok; - {error, enoent} -> - DirL = [list_to_binary(F) || F <- filename:split(DirS)], - DirR = lists:sublist(DirL, length(DirL) - 1), - make_dir_rec(fjoin(DirR)), - file:make_dir(DirS), - file:change_mode(DirS, 8#00755) % -rwxr-xr-x - end. + filelib:ensure_dir(<<Dir/binary, $/>>). %% {ok, F1}=file:open("valid-xhtml10.png", [read]). %% {ok, F1b}=file:read(F1, 1000000). diff --git a/src/mod_muc_room.erl b/src/mod_muc_room.erl index 3842fde40..0974950b7 100644 --- a/src/mod_muc_room.erl +++ b/src/mod_muc_room.erl @@ -127,6 +127,13 @@ init([Host, ServerHost, Access, Room, HistorySize, RoomShaper, Creator, _Nick, D just_created = true, room_shaper = Shaper}), State1 = set_opts(DefRoomOpts, State), + if (State1#state.config)#config.persistent -> + mod_muc:store_room(State1#state.server_host, + State1#state.host, + State1#state.room, + make_opts(State1)); + true -> ok + end, ?INFO_MSG("Created MUC room ~s@~s by ~s", [Room, Host, jlib:jid_to_string(Creator)]), add_to_log(room_existence, created, State1), @@ -167,7 +174,7 @@ normal_state({route, From, <<"">>, Now = now_to_usec(now()), MinMessageInterval = trunc(gen_mod:get_module_opt(StateData#state.server_host, - mod_muc, min_message_interval, fun(MMI) when is_integer(MMI) -> MMI end, 0) + mod_muc, min_message_interval, fun(MMI) when is_number(MMI) -> MMI end, 0) * 1000000), Size = element_size(Packet), {MessageShaper, MessageShaperInterval} = @@ -1510,15 +1517,17 @@ get_user_activity(JID, StateData) -> store_user_activity(JID, UserActivity, StateData) -> MinMessageInterval = - gen_mod:get_module_opt(StateData#state.server_host, - mod_muc, min_message_interval, - fun(I) when is_integer(I), I>=0 -> I end, - 0), + trunc(gen_mod:get_module_opt(StateData#state.server_host, + mod_muc, min_message_interval, + fun(I) when is_number(I), I>=0 -> I end, + 0) + * 1000), MinPresenceInterval = - gen_mod:get_module_opt(StateData#state.server_host, - mod_muc, min_presence_interval, - fun(I) when is_integer(I), I>=0 -> I end, - 0), + trunc(gen_mod:get_module_opt(StateData#state.server_host, + mod_muc, min_presence_interval, + fun(I) when is_number(I), I>=0 -> I end, + 0) + * 1000), Key = jlib:jid_tolower(JID), Now = now_to_usec(now()), Activity1 = clean_treap(StateData#state.activity, @@ -1549,8 +1558,8 @@ store_user_activity(JID, UserActivity, StateData) -> 100000), Delay = lists:max([MessageShaperInterval, PresenceShaperInterval, - MinMessageInterval * 1000, - MinPresenceInterval * 1000]) + MinMessageInterval, + MinPresenceInterval]) * 1000, Priority = {1, -(Now + Delay)}, StateData#state{activity = diff --git a/src/mod_offline.erl b/src/mod_offline.erl index f27d35830..91d31a75d 100644 --- a/src/mod_offline.erl +++ b/src/mod_offline.erl @@ -26,13 +26,15 @@ -module(mod_offline). -author('alexey@process-one.net'). +-define(GEN_SERVER, p1_server). +-behaviour(?GEN_SERVER). -behaviour(gen_mod). -export([count_offline_messages/2]). -export([start/2, - loop/2, + start_link/2, stop/1, store_packet/3, resend_offline_messages/2, @@ -50,6 +52,10 @@ webadmin_user/4, webadmin_user_parse_query/5]). +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). + -include("ejabberd.hrl"). -include("logger.hrl"). @@ -67,6 +73,10 @@ to = #jid{} :: jid() | '_', packet = #xmlel{} :: xmlel() | '_'}). +-record(state, + {host = <<"">> :: binary(), + access_max_offline_messages}). + -define(PROCNAME, ejabberd_offline). -define(OFFLINE_TABLE_LOCK_THRESHOLD, 1000). @@ -74,7 +84,29 @@ %% default value for the maximum number of user messages -define(MAX_USER_MESSAGES, infinity). +start_link(Host, Opts) -> + Proc = gen_mod:get_module_proc(Host, ?PROCNAME), + ?GEN_SERVER:start_link({local, Proc}, ?MODULE, + [Host, Opts], []). + start(Host, Opts) -> + Proc = gen_mod:get_module_proc(Host, ?PROCNAME), + ChildSpec = {Proc, {?MODULE, start_link, [Host, Opts]}, + temporary, 1000, worker, [?MODULE]}, + supervisor:start_child(ejabberd_sup, ChildSpec). + +stop(Host) -> + Proc = gen_mod:get_module_proc(Host, ?PROCNAME), + ?GEN_SERVER:call(Proc, stop), + supervisor:delete_child(ejabberd_sup, Proc), + ok. + + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +init([Host, Opts]) -> case gen_mod:db_type(Opts) of mnesia -> mnesia:create_table(offline_msg, @@ -102,31 +134,63 @@ start(Host, Opts) -> ejabberd_hooks:add(webadmin_user_parse_query, Host, ?MODULE, webadmin_user_parse_query, 50), AccessMaxOfflineMsgs = gen_mod:get_opt(access_max_user_messages, Opts, fun(A) -> A end, max_user_offline_messages), - register(gen_mod:get_module_proc(Host, ?PROCNAME), - spawn(?MODULE, loop, [Host, AccessMaxOfflineMsgs])). + {ok, + #state{host = Host, + access_max_offline_messages = AccessMaxOfflineMsgs}}. + + +handle_call(stop, _From, State) -> + {stop, normal, ok, State}. + + +handle_cast(_Msg, State) -> {noreply, State}. + + +handle_info(#offline_msg{us = UserServer} = Msg, State) -> + #state{host = Host, + access_max_offline_messages = AccessMaxOfflineMsgs} = State, + DBType = gen_mod:db_type(Host, ?MODULE), + Msgs = receive_all(UserServer, [Msg], DBType), + Len = length(Msgs), + MaxOfflineMsgs = get_max_user_messages(AccessMaxOfflineMsgs, + UserServer, Host), + store_offline_msg(Host, UserServer, Msgs, Len, MaxOfflineMsgs, DBType), + {noreply, State}; + +handle_info(_Info, State) -> + ?ERROR_MSG("got unexpected info: ~p", [_Info]), + {noreply, State}. + + +terminate(_Reason, State) -> + Host = State#state.host, + ejabberd_hooks:delete(offline_message_hook, Host, + ?MODULE, store_packet, 50), + ejabberd_hooks:delete(resend_offline_messages_hook, + Host, ?MODULE, pop_offline_messages, 50), + ejabberd_hooks:delete(remove_user, Host, ?MODULE, + remove_user, 50), + ejabberd_hooks:delete(anonymous_purge_hook, Host, + ?MODULE, remove_user, 50), + ejabberd_hooks:delete(disco_sm_features, Host, ?MODULE, get_sm_features, 50), + ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, get_sm_features, 50), + ejabberd_hooks:delete(webadmin_page_host, Host, + ?MODULE, webadmin_page, 50), + ejabberd_hooks:delete(webadmin_user, Host, + ?MODULE, webadmin_user, 50), + ejabberd_hooks:delete(webadmin_user_parse_query, Host, + ?MODULE, webadmin_user_parse_query, 50), + ok. + + +code_change(_OldVsn, State, _Extra) -> {ok, State}. -loop(Host, AccessMaxOfflineMsgs) -> - receive - #offline_msg{us = UserServer} = Msg -> - DBType = gen_mod:db_type(Host, ?MODULE), - Msgs = receive_all(UserServer, [Msg], DBType), - Len = length(Msgs), - MaxOfflineMsgs = get_max_user_messages(AccessMaxOfflineMsgs, - UserServer, Host), - store_offline_msg(Host, UserServer, Msgs, Len, MaxOfflineMsgs, DBType), - loop(Host, AccessMaxOfflineMsgs); - _ -> - loop(Host, AccessMaxOfflineMsgs) - end. store_offline_msg(_Host, US, Msgs, Len, MaxOfflineMsgs, mnesia) -> F = fun () -> Count = if MaxOfflineMsgs =/= infinity -> - Len + - p1_mnesia:count_records(offline_msg, - #offline_msg{us = US, - _ = '_'}); + Len + count_mnesia_records(US); true -> 0 end, if Count > MaxOfflineMsgs -> discard_warn_sender(Msgs); @@ -175,6 +239,23 @@ store_offline_msg(Host, {User, _Server}, Msgs, Len, MaxOfflineMsgs, odbc) -> end, Msgs), odbc_queries:add_spool(Host, Query) + end; +store_offline_msg(Host, {User, _}, Msgs, Len, MaxOfflineMsgs, + riak) -> + Count = if MaxOfflineMsgs =/= infinity -> + Len + count_offline_messages(User, Host); + true -> 0 + end, + if + Count > MaxOfflineMsgs -> + discard_warn_sender(Msgs); + true -> + lists:foreach( + fun(#offline_msg{us = US, + timestamp = TS} = M) -> + ejabberd_riak:put(M, offline_msg_schema(), + [{i, TS}, {'2i', [{<<"us">>, US}]}]) + end, Msgs) end. %% Function copied from ejabberd_sm.erl: @@ -193,32 +274,12 @@ receive_all(US, Msgs, DBType) -> after 0 -> case DBType of mnesia -> Msgs; - odbc -> lists:reverse(Msgs) + odbc -> lists:reverse(Msgs); + riak -> Msgs end end. -stop(Host) -> - ejabberd_hooks:delete(offline_message_hook, Host, - ?MODULE, store_packet, 50), - ejabberd_hooks:delete(resend_offline_messages_hook, - Host, ?MODULE, pop_offline_messages, 50), - ejabberd_hooks:delete(remove_user, Host, ?MODULE, - remove_user, 50), - ejabberd_hooks:delete(anonymous_purge_hook, Host, - ?MODULE, remove_user, 50), - ejabberd_hooks:delete(disco_sm_features, Host, ?MODULE, get_sm_features, 50), - ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, get_sm_features, 50), - ejabberd_hooks:delete(webadmin_page_host, Host, - ?MODULE, webadmin_page, 50), - ejabberd_hooks:delete(webadmin_user, Host, - ?MODULE, webadmin_user, 50), - ejabberd_hooks:delete(webadmin_user_parse_query, Host, - ?MODULE, webadmin_user_parse_query, 50), - Proc = gen_mod:get_module_proc(Host, ?PROCNAME), - exit(whereis(Proc), stop), - {wait, Proc}. - -get_sm_features(Acc, _From, _To, "", _Lang) -> +get_sm_features(Acc, _From, _To, <<"">>, _Lang) -> Feats = case Acc of {result, I} -> I; _ -> [] @@ -232,11 +293,26 @@ get_sm_features(_Acc, _From, _To, ?NS_FEATURE_MSGOFFLINE, _Lang) -> get_sm_features(Acc, _From, _To, _Node, _Lang) -> Acc. - -store_packet(From, To, Packet) -> +need_to_store(LServer, Packet) -> Type = xml:get_tag_attr_s(<<"type">>, Packet), if (Type /= <<"error">>) and (Type /= <<"groupchat">>) - and (Type /= <<"headline">>) -> + and (Type /= <<"headline">>) -> + case gen_mod:get_module_opt( + LServer, ?MODULE, store_empty_body, + fun(V) when is_boolean(V) -> V end, + true) of + false -> + xml:get_subtag(Packet, <<"body">>) /= false; + true -> + true + end; + true -> + false + end. + +store_packet(From, To, Packet) -> + case need_to_store(To#jid.lserver, Packet) of + true -> case has_no_storage_hint(Packet) of false -> case check_event(From, To, Packet) of @@ -254,7 +330,7 @@ store_packet(From, To, Packet) -> end; _ -> ok end; - true -> ok + false -> ok end. has_no_storage_hint(Packet) -> @@ -421,6 +497,34 @@ pop_offline_messages(Ls, LUser, LServer, odbc) -> end, Rs); _ -> Ls + end; +pop_offline_messages(Ls, LUser, LServer, riak) -> + case ejabberd_riak:get_by_index(offline_msg, offline_msg_schema(), + <<"us">>, {LUser, LServer}) of + {ok, Rs} -> + try + lists:foreach( + fun(#offline_msg{timestamp = T}) -> + ok = ejabberd_riak:delete(offline_msg, T) + end, Rs), + TS = now(), + Ls ++ lists:map( + fun (R) -> + offline_msg_to_route(LServer, R) + end, + lists:filter( + fun(R) -> + case R#offline_msg.expire of + never -> true; + TimeStamp -> TS < TimeStamp + end + end, + lists:keysort(#offline_msg.timestamp, Rs))) + catch _:{badmatch, _} -> + Ls + end; + _ -> + Ls end. remove_expired_messages(Server) -> @@ -445,7 +549,8 @@ remove_expired_messages(_LServer, mnesia) -> ok, offline_msg) end, mnesia:transaction(F); -remove_expired_messages(_LServer, odbc) -> {atomic, ok}. +remove_expired_messages(_LServer, odbc) -> {atomic, ok}; +remove_expired_messages(_LServer, riak) -> {atomic, ok}. remove_old_messages(Days, Server) -> LServer = jlib:nameprep(Server), @@ -470,6 +575,8 @@ remove_old_messages(Days, _LServer, mnesia) -> end, mnesia:transaction(F); remove_old_messages(_Days, _LServer, odbc) -> + {atomic, ok}; +remove_old_messages(_Days, _LServer, riak) -> {atomic, ok}. remove_user(User, Server) -> @@ -484,7 +591,10 @@ remove_user(LUser, LServer, mnesia) -> mnesia:transaction(F); remove_user(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), - odbc_queries:del_spool_msg(LServer, Username). + odbc_queries:del_spool_msg(LServer, Username); +remove_user(LUser, LServer, riak) -> + {atomic, ejabberd_riak:delete_by_index(offline_msg, + <<"us">>, {LUser, LServer})}. jid_to_binary(#jid{user = U, server = S, resource = R, luser = LU, lserver = LS, lresource = LR}) -> @@ -543,8 +653,9 @@ webadmin_page(Acc, _, _) -> Acc. get_offline_els(LUser, LServer) -> get_offline_els(LUser, LServer, gen_mod:db_type(LServer, ?MODULE)). -get_offline_els(LUser, LServer, mnesia) -> - Msgs = read_all_msgs(LUser, LServer, mnesia), +get_offline_els(LUser, LServer, DBType) + when DBType == mnesia; DBType == riak -> + Msgs = read_all_msgs(LUser, LServer, DBType), lists:map( fun(Msg) -> {route, From, To, Packet} = offline_msg_to_route(LServer, Msg), @@ -553,8 +664,8 @@ get_offline_els(LUser, LServer, mnesia) -> get_offline_els(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), case catch ejabberd_odbc:sql_query(LServer, - [<<"select xml from spool where username='">>, - Username, <<"' order by seq;">>]) of + [<<"select xml from spool where username='">>, + Username, <<"' order by seq;">>]) of {selected, [<<"xml">>], Rs} -> lists:flatmap( fun([XML]) -> @@ -601,6 +712,15 @@ read_all_msgs(LUser, LServer, mnesia) -> US = {LUser, LServer}, lists:keysort(#offline_msg.timestamp, mnesia:dirty_read({offline_msg, US})); +read_all_msgs(LUser, LServer, riak) -> + case ejabberd_riak:get_by_index( + offline_msg, offline_msg_schema(), + <<"us">>, {LUser, LServer}) of + {ok, Rs} -> + lists:keysort(#offline_msg.timestamp, Rs); + _Err -> + [] + end; read_all_msgs(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), case catch ejabberd_odbc:sql_query(LServer, @@ -618,7 +738,7 @@ read_all_msgs(LUser, LServer, odbc) -> _ -> [] end. -format_user_queue(Msgs, mnesia) -> +format_user_queue(Msgs, DBType) when DBType == mnesia; DBType == riak -> lists:map(fun (#offline_msg{timestamp = TimeStamp, from = From, to = To, packet = @@ -726,6 +846,26 @@ user_queue_parse_query(LUser, LServer, Query, mnesia) -> ok; false -> nothing end; +user_queue_parse_query(LUser, LServer, Query, riak) -> + case lists:keysearch(<<"delete">>, 1, Query) of + {value, _} -> + Msgs = read_all_msgs(LUser, LServer, riak), + lists:foreach( + fun (Msg) -> + ID = jlib:encode_base64((term_to_binary(Msg))), + case lists:member({<<"selected">>, ID}, Query) of + true -> + ejabberd_riak:delete(offline_msg, + Msg#offline_msg.timestamp); + false -> + ok + end + end, + Msgs), + ok; + false -> + nothing + end; user_queue_parse_query(LUser, LServer, Query, odbc) -> Username = ejabberd_odbc:escape(LUser), case lists:keysearch(<<"delete">>, 1, Query) of @@ -784,6 +924,14 @@ get_queue_length(LUser, LServer) -> get_queue_length(LUser, LServer, mnesia) -> length(mnesia:dirty_read({offline_msg, {LUser, LServer}})); +get_queue_length(LUser, LServer, riak) -> + case ejabberd_riak:count_by_index(offline_msg, + <<"us">>, {LUser, LServer}) of + {ok, N} -> + N; + _ -> + 0 + end; get_queue_length(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), case catch ejabberd_odbc:sql_query(LServer, @@ -812,7 +960,8 @@ get_messages_subset(User, Host, MsgsAll, DBType) -> get_messages_subset2(Max, Length, MsgsAll, _DBType) when Length =< Max * 2 -> MsgsAll; -get_messages_subset2(Max, Length, MsgsAll, mnesia) -> +get_messages_subset2(Max, Length, MsgsAll, DBType) + when DBType == mnesia; DBType == riak -> FirstN = Max, {MsgsFirstN, Msgs2} = lists:split(FirstN, MsgsAll), MsgsLastN = lists:nthtail(Length - FirstN - FirstN, @@ -860,6 +1009,10 @@ delete_all_msgs(LUser, LServer, mnesia) -> mnesia:dirty_read({offline_msg, US})) end, mnesia:transaction(F); +delete_all_msgs(LUser, LServer, riak) -> + Res = ejabberd_riak:delete_by_index(offline_msg, + <<"us">>, {LUser, LServer}), + {atomic, Res}; delete_all_msgs(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), odbc_queries:del_spool_msg(LServer, Username), @@ -882,17 +1035,73 @@ webadmin_user_parse_query(Acc, _Action, _User, _Server, Acc. %% Returns as integer the number of offline messages for a given user -count_offline_messages(LUser, LServer) -> +count_offline_messages(User, Server) -> + LUser = jlib:nodeprep(User), + LServer = jlib:nameprep(Server), + DBType = gen_mod:db_type(LServer, ?MODULE), + count_offline_messages(LUser, LServer, DBType). + +count_offline_messages(LUser, LServer, mnesia) -> + US = {LUser, LServer}, + F = fun () -> + count_mnesia_records(US) + end, + case catch mnesia:async_dirty(F) of + I when is_integer(I) -> I; + _ -> 0 + end; +count_offline_messages(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), - case catch odbc_queries:count_records_where( - LServer, "spool", - <<"where username='", Username/binary, "'">>) of - {selected, [_], [[Res]]} -> - jlib:binary_to_integer(Res); + case catch odbc_queries:count_records_where(LServer, + <<"spool">>, + <<"where username='", + Username/binary, "'">>) + of + {selected, [_], [[Res]]} -> + jlib:binary_to_integer(Res); + _ -> 0 + end; +count_offline_messages(LUser, LServer, riak) -> + case ejabberd_riak:count_by_index( + offline_msg, <<"us">>, {LUser, LServer}) of + {ok, Res} -> + Res; _ -> 0 + end; +count_offline_messages(_Acc, User, Server) -> + N = count_offline_messages(User, Server), + {stop, N}. + +%% Return the number of records matching a given match expression. +%% This function is intended to be used inside a Mnesia transaction. +%% The count has been written to use the fewest possible memory by +%% getting the record by small increment and by using continuation. +-define(BATCHSIZE, 100). + +count_mnesia_records(US) -> + MatchExpression = #offline_msg{us = US, _ = '_'}, + case mnesia:select(offline_msg, [{MatchExpression, [], [[]]}], + ?BATCHSIZE, read) of + {Result, Cont} -> + Count = length(Result), + count_records_cont(Cont, Count); + '$end_of_table' -> + 0 end. +count_records_cont(Cont, Count) -> + case mnesia:select(Cont) of + {Result, Cont} -> + NewCount = Count + length(Result), + count_records_cont(Cont, NewCount); + '$end_of_table' -> + Count + end. + +offline_msg_schema() -> + {record_info(fields, offline_msg), #offline_msg{}}. + export(_Server) -> [{offline_msg, fun(Host, #offline_msg{us = {LUser, LServer}, @@ -951,5 +1160,8 @@ import(LServer) -> import(_LServer, mnesia, #offline_msg{} = Msg) -> mnesia:dirty_write(Msg); +import(_LServer, riak, #offline_msg{us = US, timestamp = TS} = M) -> + ejabberd_riak:put(M, offline_msg_schema(), + [{i, TS}, {'2i', [{<<"us">>, US}]}]); import(_, _, _) -> pass. diff --git a/src/mod_privacy.erl b/src/mod_privacy.erl index 6b852bb47..9c9ec919f 100644 --- a/src/mod_privacy.erl +++ b/src/mod_privacy.erl @@ -43,7 +43,7 @@ sql_get_privacy_list_data_by_id_t/1, sql_get_privacy_list_id_t/2, sql_set_default_privacy_list/2, - sql_set_privacy_list/2]). + sql_set_privacy_list/2, privacy_schema/0]). -include("ejabberd.hrl"). -include("logger.hrl"). @@ -52,6 +52,9 @@ -include("mod_privacy.hrl"). +privacy_schema() -> + {record_info(fields, privacy), #privacy{}}. + start(Host, Opts) -> IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1, one_queue), @@ -160,6 +163,21 @@ process_lists_get(LUser, LServer, _Active, mnesia) -> Lists), {Default, LItems} end; +process_lists_get(LUser, LServer, _Active, riak) -> + case ejabberd_riak:get(privacy, privacy_schema(), {LUser, LServer}) of + {ok, #privacy{default = Default, lists = Lists}} -> + LItems = lists:map(fun ({N, _}) -> + #xmlel{name = <<"list">>, + attrs = [{<<"name">>, N}], + children = []} + end, + Lists), + {Default, LItems}; + {error, notfound} -> + {none, []}; + {error, _} -> + error + end; process_lists_get(LUser, LServer, _Active, odbc) -> Default = case catch sql_get_default_privacy_list(LUser, LServer) @@ -209,6 +227,18 @@ process_list_get(LUser, LServer, Name, mnesia) -> _ -> not_found end end; +process_list_get(LUser, LServer, Name, riak) -> + case ejabberd_riak:get(privacy, privacy_schema(), {LUser, LServer}) of + {ok, #privacy{lists = Lists}} -> + case lists:keysearch(Name, 1, Lists) of + {value, {_, List}} -> List; + _ -> not_found + end; + {error, notfound} -> + not_found; + {error, _} -> + error + end; process_list_get(LUser, LServer, Name, odbc) -> case catch sql_get_privacy_list_id(LUser, LServer, Name) of @@ -354,6 +384,21 @@ process_default_set(LUser, LServer, {value, Name}, end end, mnesia:transaction(F); +process_default_set(LUser, LServer, {value, Name}, riak) -> + {atomic, + case ejabberd_riak:get(privacy, privacy_schema(), {LUser, LServer}) of + {ok, #privacy{lists = Lists} = P} -> + case lists:keymember(Name, 1, Lists) of + true -> + ejabberd_riak:put(P#privacy{default = Name, + lists = Lists}, + privacy_schema()); + false -> + not_found + end; + {error, _} -> + not_found + end}; process_default_set(LUser, LServer, {value, Name}, odbc) -> F = fun () -> @@ -375,6 +420,14 @@ process_default_set(LUser, LServer, false, mnesia) -> end end, mnesia:transaction(F); +process_default_set(LUser, LServer, false, riak) -> + {atomic, + case ejabberd_riak:get(privacy, privacy_schema(), {LUser, LServer}) of + {ok, R} -> + ejabberd_riak:put(R#privacy{default = none}, privacy_schema()); + {error, _} -> + ok + end}; process_default_set(LUser, LServer, false, odbc) -> case catch sql_unset_default_privacy_list(LUser, LServer) @@ -407,6 +460,16 @@ process_active_set(LUser, LServer, Name, mnesia) -> false -> error end end; +process_active_set(LUser, LServer, Name, riak) -> + case ejabberd_riak:get(privacy, privacy_schema(), {LUser, LServer}) of + {ok, #privacy{lists = Lists}} -> + case lists:keysearch(Name, 1, Lists) of + {value, {_, List}} -> List; + false -> error + end; + {error, _} -> + error + end; process_active_set(LUser, LServer, Name, odbc) -> case catch sql_get_privacy_list_id(LUser, LServer, Name) of @@ -438,6 +501,20 @@ remove_privacy_list(LUser, LServer, Name, mnesia) -> end end, mnesia:transaction(F); +remove_privacy_list(LUser, LServer, Name, riak) -> + {atomic, + case ejabberd_riak:get(privacy, privacy_schema(), {LUser, LServer}) of + {ok, #privacy{default = Default, lists = Lists} = P} -> + if Name == Default -> + conflict; + true -> + NewLists = lists:keydelete(Name, 1, Lists), + ejabberd_riak:put(P#privacy{lists = NewLists}, + privacy_schema()) + end; + {error, _} -> + ok + end}; remove_privacy_list(LUser, LServer, Name, odbc) -> F = fun () -> case sql_get_default_privacy_list_t(LUser) of @@ -465,6 +542,19 @@ set_privacy_list(LUser, LServer, Name, List, mnesia) -> end end, mnesia:transaction(F); +set_privacy_list(LUser, LServer, Name, List, riak) -> + {atomic, + case ejabberd_riak:get(privacy, privacy_schema(), {LUser, LServer}) of + {ok, #privacy{lists = Lists} = P} -> + NewLists1 = lists:keydelete(Name, 1, Lists), + NewLists = [{Name, List} | NewLists1], + ejabberd_riak:put(P#privacy{lists = NewLists}, privacy_schema()); + {error, _} -> + NewLists = [{Name, List}], + ejabberd_riak:put(#privacy{us = {LUser, LServer}, + lists = NewLists}, + privacy_schema()) + end}; set_privacy_list(LUser, LServer, Name, List, odbc) -> RItems = lists:map(fun item_to_raw/1, List), F = fun () -> @@ -649,6 +739,20 @@ get_user_list(_, LUser, LServer, mnesia) -> end; _ -> {none, []} end; +get_user_list(_, LUser, LServer, riak) -> + case ejabberd_riak:get(privacy, privacy_schema(), {LUser, LServer}) of + {ok, #privacy{default = Default, lists = Lists}} -> + case Default of + none -> {none, []}; + _ -> + case lists:keysearch(Default, 1, Lists) of + {value, {_, List}} -> {Default, List}; + _ -> {none, []} + end + end; + {error, _} -> + {none, []} + end; get_user_list(_, LUser, LServer, odbc) -> case catch sql_get_default_privacy_list(LUser, LServer) of @@ -680,6 +784,13 @@ get_user_lists(LUser, LServer, mnesia) -> _ -> error end; +get_user_lists(LUser, LServer, riak) -> + case ejabberd_riak:get(privacy, privacy_schema(), {LUser, LServer}) of + {ok, #privacy{} = P} -> + {ok, P}; + {error, _} -> + error + end; get_user_lists(LUser, LServer, odbc) -> Default = case catch sql_get_default_privacy_list(LUser, LServer) of {selected, [<<"name">>], []} -> @@ -843,6 +954,8 @@ remove_user(LUser, LServer, mnesia) -> F = fun () -> mnesia:delete({privacy, {LUser, LServer}}) end, mnesia:transaction(F); +remove_user(LUser, LServer, riak) -> + {atomic, ejabberd_riak:delete(privacy, {LUser, LServer})}; remove_user(LUser, LServer, odbc) -> sql_del_privacy_lists(LUser, LServer). @@ -1134,5 +1247,7 @@ import(LServer) -> import(_LServer, mnesia, #privacy{} = P) -> mnesia:dirty_write(P); +import(_LServer, riak, #privacy{} = P) -> + ejabberd_riak:put(P, privacy_schema()); import(_, _, _) -> pass. diff --git a/src/mod_private.erl b/src/mod_private.erl index 9fa74d9b7..9fdf09dd8 100644 --- a/src/mod_private.erl +++ b/src/mod_private.erl @@ -89,7 +89,8 @@ process_sm_iq(#jid{luser = LUser, lserver = LServer}, end, case DBType of odbc -> ejabberd_odbc:sql_transaction(LServer, F); - mnesia -> mnesia:transaction(F) + mnesia -> mnesia:transaction(F); + riak -> F() end, IQ#iq{type = result, sub_el = []} end; @@ -149,7 +150,12 @@ set_data(LUser, LServer, {XMLNS, El}, odbc) -> LXMLNS = ejabberd_odbc:escape(XMLNS), SData = ejabberd_odbc:escape(xml:element_to_binary(El)), odbc_queries:set_private_data(LServer, Username, LXMLNS, - SData). + SData); +set_data(LUser, LServer, {XMLNS, El}, riak) -> + ejabberd_riak:put(#private_storage{usns = {LUser, LServer, XMLNS}, + xml = El}, + private_storage_schema(), + [{'2i', [{<<"us">>, {LUser, LServer}}]}]). get_data(LUser, LServer, Data) -> get_data(LUser, LServer, @@ -182,13 +188,18 @@ get_data(LUser, LServer, odbc, [{XMLNS, El} | Els], Data when is_record(Data, xmlel) -> get_data(LUser, LServer, odbc, Els, [Data | Res]) end; - %% MREMOND: I wonder when the query could return a vcard ? - {selected, [<<"vcard">>], []} -> - get_data(LUser, LServer, odbc, Els, [El | Res]); _ -> get_data(LUser, LServer, odbc, Els, [El | Res]) + end; +get_data(LUser, LServer, riak, [{XMLNS, El} | Els], + Res) -> + case ejabberd_riak:get(private_storage, private_storage_schema(), + {LUser, LServer, XMLNS}) of + {ok, #private_storage{xml = NewEl}} -> + get_data(LUser, LServer, riak, Els, [NewEl|Res]); + _ -> + get_data(LUser, LServer, riak, Els, [El|Res]) end. - get_data(LUser, LServer) -> get_all_data(LUser, LServer, gen_mod:db_type(LServer, ?MODULE)). @@ -214,8 +225,20 @@ get_all_data(LUser, LServer, odbc) -> end, Res); _ -> [] + end; +get_all_data(LUser, LServer, riak) -> + case ejabberd_riak:get_by_index( + private_storage, private_storage_schema(), + <<"us">>, {LUser, LServer}) of + {ok, Res} -> + [El || #private_storage{xml = El} <- Res]; + _ -> + [] end. +private_storage_schema() -> + {record_info(fields, private_storage), #private_storage{}}. + remove_user(User, Server) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), @@ -242,7 +265,10 @@ remove_user(LUser, LServer, mnesia) -> remove_user(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), odbc_queries:del_user_private_storage(LServer, - Username). + Username); +remove_user(LUser, LServer, riak) -> + {atomic, ejabberd_riak:delete_by_index(private_storage, + <<"us">>, {LUser, LServer})}. update_table() -> Fields = record_info(fields, private_storage), @@ -287,5 +313,9 @@ import(LServer) -> import(_LServer, mnesia, #private_storage{} = PS) -> mnesia:dirty_write(PS); + +import(_LServer, riak, #private_storage{usns = {LUser, LServer, _}} = PS) -> + ejabberd_riak:put(PS, private_storage_schema(), + [{'2i', [{<<"us">>, {LUser, LServer}}]}]); import(_, _, _) -> pass. diff --git a/src/mod_pubsub.erl b/src/mod_pubsub.erl index 168169a95..e6437199b 100644 --- a/src/mod_pubsub.erl +++ b/src/mod_pubsub.erl @@ -387,7 +387,7 @@ init_send_loop(ServerHost, State) -> init_plugins(Host, ServerHost, Opts) -> TreePlugin = jlib:binary_to_atom(<<(?TREE_PREFIX)/binary, - (gen_mod:get_opt(nodetree, Opts, fun(A) when is_list(A) -> A end, + (gen_mod:get_opt(nodetree, Opts, fun(A) when is_binary(A) -> A end, ?STDTREE))/binary>>), ?DEBUG("** tree plugin is ~p", [TreePlugin]), TreePlugin:init(Host, ServerHost, Opts), @@ -690,9 +690,9 @@ update_node_database(Host, ServerHost) -> end, mnesia:transaction(fun () -> case catch mnesia:first(pubsub_node) of - {_, L} when is_binary(L) -> + {_, L} when is_list(L) -> lists:foreach(fun ({H, N}) - when is_binary(N) -> + when is_list(N) -> [Node] = mnesia:read({pubsub_node, {H, @@ -3312,7 +3312,7 @@ get_allowed_items_call(Host, NodeIdx, From, Type, Options, Owners) -> %% Number = last | integer() %% @doc <p>Resend the items of a node to the user.</p> %% @todo use cache-last-item feature -send_items(Host, Node, NodeId, Type, {U, S, R} = LJID, last) -> +send_items(Host, Node, NodeId, Type, LJID, last) -> case get_cached_item(Host, NodeId) of undefined -> send_items(Host, Node, NodeId, Type, LJID, 1); @@ -3325,24 +3325,9 @@ send_items(Host, Node, NodeId, Type, {U, S, R} = LJID, last) -> children = itemsEls([LastItem])}], ModifNow, ModifUSR), - case is_tuple(Host) of - false -> - ejabberd_router:route(service_jid(Host), - jlib:make_jid(LJID), Stanza); - true -> - case ejabberd_sm:get_session_pid(U, S, R) of - C2SPid when is_pid(C2SPid) -> - ejabberd_c2s:broadcast(C2SPid, - {pep_message, - <<((Node))/binary, "+notify">>}, - _Sender = service_jid(Host), - Stanza); - _ -> ok - end - end + dispatch_items(Host, LJID, Node, Stanza) end; -send_items(Host, Node, NodeId, Type, {U, S, R} = LJID, - Number) -> +send_items(Host, Node, NodeId, Type, LJID, Number) -> ToSend = case node_action(Host, Type, get_items, [NodeId, LJID]) of @@ -3370,22 +3355,38 @@ send_items(Host, Node, NodeId, Type, {U, S, R} = LJID, attrs = nodeAttr(Node), children = itemsEls(ToSend)}]) end, - case {is_tuple(Host), Stanza} of - {_, undefined} -> - ok; - {false, _} -> - ejabberd_router:route(service_jid(Host), - jlib:make_jid(LJID), Stanza); - {true, _} -> - case ejabberd_sm:get_session_pid(U, S, R) of - C2SPid when is_pid(C2SPid) -> - ejabberd_c2s:broadcast(C2SPid, - {pep_message, - <<((Node))/binary, "+notify">>}, - _Sender = service_jid(Host), Stanza); - _ -> ok - end - end. + dispatch_items(Host, LJID, Node, Stanza). + +-spec(dispatch_items/4 :: +( + From :: mod_pubsub:host(), + To :: jid(), + Node :: mod_pubsub:nodeId(), + Stanza :: xmlel() | undefined) + -> any() +). + +dispatch_items(_From, _To, _Node, _Stanza = undefined) -> ok; +dispatch_items({FromU, FromS, FromR} = From, {ToU, ToS, ToR} = To, Node, + Stanza) -> + C2SPid = case ejabberd_sm:get_session_pid(ToU, ToS, ToR) of + ToPid when is_pid(ToPid) -> ToPid; + _ -> + R = user_resource(FromU, FromS, FromR), + case ejabberd_sm:get_session_pid(FromU, FromS, R) of + FromPid when is_pid(FromPid) -> FromPid; + _ -> undefined + end + end, + if C2SPid == undefined -> ok; + true -> + ejabberd_c2s:send_filtered(C2SPid, + {pep_message, <<Node/binary, "+notify">>}, + service_jid(From), jlib:make_jid(To), + Stanza) + end; +dispatch_items(From, To, _Node, Stanza) -> + ejabberd_router:route(service_jid(From), jlib:make_jid(To), Stanza). %% @spec (Host, JID, Plugins) -> {error, Reason} | {result, Response} %% Host = host() diff --git a/src/mod_pubsub_odbc.erl b/src/mod_pubsub_odbc.erl index 45c30a11b..e2b357f03 100644 --- a/src/mod_pubsub_odbc.erl +++ b/src/mod_pubsub_odbc.erl @@ -385,7 +385,7 @@ init_send_loop(ServerHost, State) -> init_plugins(Host, ServerHost, Opts) -> TreePlugin = jlib:binary_to_atom(<<(?TREE_PREFIX)/binary, - (gen_mod:get_opt(nodetree, Opts, fun(A) when is_list(A) -> A end, + (gen_mod:get_opt(nodetree, Opts, fun(A) when is_binary(A) -> A end, ?STDTREE))/binary, (?ODBC_SUFFIX)/binary>>), ?DEBUG("** tree plugin is ~p", [TreePlugin]), @@ -2315,7 +2315,7 @@ create_node(Host, ServerHost, Node, Owner, GivenType, Access, Configuration) -> {result, Reply}; {result, {NodeId, _SubsByDepth, Result}} -> ejabberd_hooks:run(pubsub_create_node, ServerHost, [ServerHost, Host, Node, NodeId, NodeOptions]), - {result, Result}; + {result, Reply}; Error -> %% in case we change transaction to sync_dirty... %% node_call(Type, delete_node, [Host, Node]), @@ -3011,8 +3011,8 @@ send_items(Host, Node, NodeId, Type, LJID, last) -> itemsEls([LastItem])}], ModifNow, ModifUSR) end, - ejabberd_router:route(service_jid(Host), jlib:make_jid(LJID), Stanza); -send_items(Host, Node, NodeId, Type, {U, S, R} = LJID, Number) -> + dispatch_items(Host, LJID, Node, Stanza); +send_items(Host, Node, NodeId, Type, LJID, Number) -> ToSend = case node_action(Host, Type, get_items, [NodeId, LJID]) of @@ -3040,22 +3040,38 @@ send_items(Host, Node, NodeId, Type, {U, S, R} = LJID, Number) -> attrs = nodeAttr(Node), children = itemsEls(ToSend)}]) end, - case {is_tuple(Host), Stanza} of - {_, undefined} -> - ok; - {false, _} -> - ejabberd_router:route(service_jid(Host), - jlib:make_jid(LJID), Stanza); - {true, _} -> - case ejabberd_sm:get_session_pid(U, S, R) of - C2SPid when is_pid(C2SPid) -> - ejabberd_c2s:broadcast(C2SPid, - {pep_message, - <<((Node))/binary, "+notify">>}, - _Sender = service_jid(Host), Stanza); - _ -> ok - end - end. + dispatch_items(Host, LJID, Node, Stanza). + +-spec(dispatch_items/4 :: +( + From :: mod_pubsub:host(), + To :: jid(), + Node :: mod_pubsub:nodeId(), + Stanza :: xmlel() | undefined) + -> any() +). + +dispatch_items(_From, _To, _Node, _Stanza = undefined) -> ok; +dispatch_items({FromU, FromS, FromR} = From, {ToU, ToS, ToR} = To, Node, + Stanza) -> + C2SPid = case ejabberd_sm:get_session_pid(ToU, ToS, ToR) of + ToPid when is_pid(ToPid) -> ToPid; + _ -> + R = user_resource(FromU, FromS, FromR), + case ejabberd_sm:get_session_pid(FromU, FromS, R) of + FromPid when is_pid(FromPid) -> FromPid; + _ -> undefined + end + end, + if C2SPid == undefined -> ok; + true -> + ejabberd_c2s:send_filtered(C2SPid, + {pep_message, <<Node/binary, "+notify">>}, + service_jid(From), jlib:make_jid(To), + Stanza) + end; +dispatch_items(From, To, _Node, Stanza) -> + ejabberd_router:route(service_jid(From), jlib:make_jid(To), Stanza). %% @spec (Host, JID, Plugins) -> {error, Reason} | {result, Response} %% Host = host() diff --git a/src/mod_roster.erl b/src/mod_roster.erl index 4ab8239b5..7bd171ffb 100644 --- a/src/mod_roster.erl +++ b/src/mod_roster.erl @@ -204,6 +204,12 @@ read_roster_version(LUser, LServer, odbc) -> of {selected, [<<"version">>], [[Version]]} -> Version; {selected, [<<"version">>], []} -> error + end; +read_roster_version(LServer, LUser, riak) -> + case ejabberd_riak:get(roster_version, roster_version_schema(), + {LUser, LServer}) of + {ok, #roster_version{version = V}} -> V; + _Err -> error end. write_roster_version(LUser, LServer) -> @@ -239,7 +245,12 @@ write_roster_version(LUser, LServer, InTransaction, Ver, odbc_queries:set_roster_version(Username, EVer) end) - end. + end; +write_roster_version(LUser, LServer, _InTransaction, Ver, + riak) -> + US = {LUser, LServer}, + ejabberd_riak:put(#roster_version{us = US, version = Ver}, + roster_version_schema()). %% Load roster from DB only if neccesary. %% It is neccesary if @@ -347,6 +358,12 @@ get_roster(LUser, LServer, mnesia) -> Items when is_list(Items)-> Items; _ -> [] end; +get_roster(LUser, LServer, riak) -> + case ejabberd_riak:get_by_index(roster, roster_schema(), + <<"us">>, {LUser, LServer}) of + {ok, Items} -> Items; + _Err -> [] + end; get_roster(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), case catch odbc_queries:get_roster(LServer, Username) of @@ -455,6 +472,17 @@ get_roster_by_jid_t(LUser, LServer, LJID, odbc) -> R#roster{usj = {LUser, LServer, LJID}, us = {LUser, LServer}, jid = LJID, name = <<"">>} end + end; +get_roster_by_jid_t(LUser, LServer, LJID, riak) -> + case ejabberd_riak:get(roster, roster_schema(), {LUser, LServer, LJID}) of + {ok, I} -> + I#roster{jid = LJID, name = <<"">>, groups = [], + xs = []}; + {error, notfound} -> + #roster{usj = {LUser, LServer, LJID}, + us = {LUser, LServer}, jid = LJID}; + Err -> + exit(Err) end. try_process_iq_set(From, To, #iq{sub_el = SubEl} = IQ) -> @@ -631,8 +659,14 @@ get_subscription_lists(_, LUser, LServer, odbc) -> <<"server">>, <<"subscribe">>, <<"type">>], Items} when is_list(Items) -> - Items; + lists:map(fun(I) -> raw_to_record(LServer, I) end, Items); _ -> [] + end; +get_subscription_lists(_, LUser, LServer, riak) -> + case ejabberd_riak:get_by_index(roster, roster_schema(), + <<"us">>, {LUser, LServer}) of + {ok, Items} -> Items; + _Err -> [] end. fill_subscription_lists(LServer, [#roster{} = I | Is], @@ -671,12 +705,16 @@ roster_subscribe_t(LUser, LServer, LJID, Item, odbc) -> Username = ejabberd_odbc:escape(LUser), SJID = ejabberd_odbc:escape(jlib:jid_to_string(LJID)), odbc_queries:roster_subscribe(LServer, Username, SJID, - ItemVals). + ItemVals); +roster_subscribe_t(LUser, LServer, _LJID, Item, riak) -> + ejabberd_riak:put(Item, roster_schema(), + [{'2i', [{<<"us">>, {LUser, LServer}}]}]). transaction(LServer, F) -> case gen_mod:db_type(LServer, ?MODULE) of mnesia -> mnesia:transaction(F); - odbc -> ejabberd_odbc:sql_transaction(LServer, F) + odbc -> ejabberd_odbc:sql_transaction(LServer, F); + riak -> {atomic, F()} end. in_subscription(_, User, Server, JID, Type, Reason) -> @@ -727,6 +765,16 @@ get_roster_by_jid_with_groups_t(LUser, LServer, LJID, []} -> #roster{usj = {LUser, LServer, LJID}, us = {LUser, LServer}, jid = LJID} + end; +get_roster_by_jid_with_groups_t(LUser, LServer, LJID, riak) -> + case ejabberd_riak:get(roster, roster_schema(), {LUser, LServer, LJID}) of + {ok, I} -> + I; + {error, notfound} -> + #roster{usj = {LUser, LServer, LJID}, + us = {LUser, LServer}, jid = LJID}; + Err -> + exit(Err) end. process_subscription(Direction, User, Server, JID1, @@ -924,12 +972,12 @@ in_auto_reply(_, _, _) -> none. remove_user(User, Server) -> LUser = jlib:nodeprep(User), LServer = jlib:nameprep(Server), + send_unsubscription_to_rosteritems(LUser, LServer), remove_user(LUser, LServer, gen_mod:db_type(LServer, ?MODULE)). remove_user(LUser, LServer, mnesia) -> US = {LUser, LServer}, - send_unsubscription_to_rosteritems(LUser, LServer), F = fun () -> lists:foreach(fun (R) -> mnesia:delete_object(R) end, mnesia:index_read(roster, US, #roster.us)) @@ -937,9 +985,10 @@ remove_user(LUser, LServer, mnesia) -> mnesia:transaction(F); remove_user(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), - send_unsubscription_to_rosteritems(LUser, LServer), odbc_queries:del_user_roster_t(LServer, Username), - ok. + ok; +remove_user(LUser, LServer, riak) -> + {atomic, ejabberd_riak:delete_by_index(roster, <<"us">>, {LUser, LServer})}. %% For each contact with Subscription: %% Both or From, send a "unsubscribed" presence stanza; @@ -1009,7 +1058,11 @@ update_roster_t(LUser, LServer, LJID, Item, odbc) -> SJID = ejabberd_odbc:escape(jlib:jid_to_string(LJID)), ItemVals = record_to_string(Item), ItemGroups = groups_to_string(Item), - odbc_queries:update_roster(LServer, Username, SJID, ItemVals, ItemGroups). + odbc_queries:update_roster(LServer, Username, SJID, ItemVals, + ItemGroups); +update_roster_t(LUser, LServer, _LJID, Item, riak) -> + ejabberd_riak:put(Item, roster_schema(), + [{'2i', [{<<"us">>, {LUser, LServer}}]}]). del_roster_t(LUser, LServer, LJID) -> DBType = gen_mod:db_type(LServer, ?MODULE), @@ -1020,7 +1073,9 @@ del_roster_t(LUser, LServer, LJID, mnesia) -> del_roster_t(LUser, LServer, LJID, odbc) -> Username = ejabberd_odbc:escape(LUser), SJID = ejabberd_odbc:escape(jlib:jid_to_string(LJID)), - odbc_queries:del_roster(LServer, Username, SJID). + odbc_queries:del_roster(LServer, Username, SJID); +del_roster_t(LUser, LServer, LJID, riak) -> + ejabberd_riak:delete(roster, {LUser, LServer, LJID}). process_item_set_t(LUser, LServer, #xmlel{attrs = Attrs, children = Els}) -> @@ -1086,40 +1141,35 @@ get_in_pending_subscriptions(Ls, User, Server) -> get_in_pending_subscriptions(Ls, User, Server, gen_mod:db_type(LServer, ?MODULE)). -get_in_pending_subscriptions(Ls, User, Server, - mnesia) -> +get_in_pending_subscriptions(Ls, User, Server, DBType) + when DBType == mnesia; DBType == riak -> JID = jlib:make_jid(User, Server, <<"">>), - US = {JID#jid.luser, JID#jid.lserver}, - case mnesia:dirty_index_read(roster, US, #roster.us) of - Result when is_list(Result) -> - Ls ++ - lists:map(fun (R) -> - Message = R#roster.askmessage, - Status = if is_binary(Message) -> (Message); - true -> <<"">> - end, - #xmlel{name = <<"presence">>, - attrs = - [{<<"from">>, - jlib:jid_to_string(R#roster.jid)}, - {<<"to">>, jlib:jid_to_string(JID)}, - {<<"type">>, <<"subscribe">>}], - children = - [#xmlel{name = <<"status">>, - attrs = [], - children = - [{xmlcdata, Status}]}]} - end, - lists:filter(fun (R) -> - case R#roster.ask of - in -> true; - both -> true; - _ -> false - end - end, - Result)); - _ -> Ls - end; + Result = get_roster(JID#jid.luser, JID#jid.lserver, DBType), + Ls ++ lists:map(fun (R) -> + Message = R#roster.askmessage, + Status = if is_binary(Message) -> (Message); + true -> <<"">> + end, + #xmlel{name = <<"presence">>, + attrs = + [{<<"from">>, + jlib:jid_to_string(R#roster.jid)}, + {<<"to">>, jlib:jid_to_string(JID)}, + {<<"type">>, <<"subscribe">>}], + children = + [#xmlel{name = <<"status">>, + attrs = [], + children = + [{xmlcdata, Status}]}]} + end, + lists:filter(fun (R) -> + case R#roster.ask of + in -> true; + both -> true; + _ -> false + end + end, + Result)); get_in_pending_subscriptions(Ls, User, Server, odbc) -> JID = jlib:make_jid(User, Server, <<"">>), LUser = JID#jid.luser, @@ -1188,7 +1238,7 @@ read_subscription_and_groups(LUser, LServer, LJID, case catch odbc_queries:get_subscription(LServer, Username, SJID) of - {selected, [<<"subscription">>], [{SSubscription}]} -> + {selected, [<<"subscription">>], [[SSubscription]]} -> Subscription = case SSubscription of <<"B">> -> both; <<"T">> -> to; @@ -1205,6 +1255,15 @@ read_subscription_and_groups(LUser, LServer, LJID, end, {Subscription, Groups}; _ -> error + end; +read_subscription_and_groups(LUser, LServer, LJID, + riak) -> + case ejabberd_riak:get(roster, roster_schema(), {LUser, LServer, LJID}) of + {ok, #roster{subscription = Subscription, + groups = Groups}} -> + {Subscription, Groups}; + _ -> + error end. get_jid_info(_, User, Server, JID) -> @@ -1319,7 +1378,8 @@ update_roster_table() -> iolist_to_binary(R2)}, name = iolist_to_binary(Name), groups = [iolist_to_binary(G) || G <- Gs], - askmessage = iolist_to_binary(Ask), + askmessage = try iolist_to_binary(Ask) + catch _:_ -> <<"">> end, xs = [xml:to_xmlel(X) || X <- Xs]} end); _ -> @@ -1642,6 +1702,11 @@ is_managed_from_id(<<"roster-remotely-managed">>) -> is_managed_from_id(_Id) -> false. +roster_schema() -> + {record_info(fields, roster), #roster{}}. + +roster_version_schema() -> + {record_info(fields, roster_version), #roster_version{}}. export(_Server) -> [{roster, @@ -1692,5 +1757,10 @@ import(_LServer, mnesia, #roster{} = R) -> mnesia:dirty_write(R); import(_LServer, mnesia, #roster_version{} = RV) -> mnesia:dirty_write(RV); +import(_LServer, riak, #roster{us = {LUser, LServer}} = R) -> + ejabberd_riak:put(R, roster_schema(), + [{'2i', [{<<"us">>, {LUser, LServer}}]}]); +import(_LServer, riak, #roster_version{} = RV) -> + ejabberd_riak:put(RV, roster_version_schema()); import(_, _, _) -> pass. diff --git a/src/mod_shared_roster.erl b/src/mod_shared_roster.erl index 8a1423c76..916285660 100644 --- a/src/mod_shared_roster.erl +++ b/src/mod_shared_roster.erl @@ -400,6 +400,13 @@ list_groups(Host, mnesia) -> mnesia:dirty_select(sr_group, [{#sr_group{group_host = {'$1', '$2'}, _ = '_'}, [{'==', '$2', Host}], ['$1']}]); +list_groups(Host, riak) -> + case ejabberd_riak:get_keys_by_index(sr_group, <<"host">>, Host) of + {ok, Gs} -> + [G || {G, _} <- Gs]; + _ -> + [] + end; list_groups(Host, odbc) -> case ejabberd_odbc:sql_query(Host, [<<"select name from sr_group;">>]) @@ -417,6 +424,14 @@ groups_with_opts(Host, mnesia) -> _ = '_'}, [], [['$1', '$2']]}]), lists:map(fun ([G, O]) -> {G, O} end, Gs); +groups_with_opts(Host, riak) -> + case ejabberd_riak:get_by_index(sr_group, sr_group_schema(), + <<"host">>, Host) of + {ok, Rs} -> + [{G, O} || #sr_group{group_host = {G, _}, opts = O} <- Rs]; + _ -> + [] + end; groups_with_opts(Host, odbc) -> case ejabberd_odbc:sql_query(Host, [<<"select name, opts from sr_group;">>]) @@ -438,6 +453,11 @@ create_group(Host, Group, Opts, mnesia) -> R = #sr_group{group_host = {Group, Host}, opts = Opts}, F = fun () -> mnesia:write(R) end, mnesia:transaction(F); +create_group(Host, Group, Opts, riak) -> + {atomic, ejabberd_riak:put(#sr_group{group_host = {Group, Host}, + opts = Opts}, + sr_group_schema(), + [{'2i', [{<<"host">>, Host}]}])}; create_group(Host, Group, Opts, odbc) -> SGroup = ejabberd_odbc:escape(Group), SOpts = ejabberd_odbc:encode_term(Opts), @@ -464,6 +484,15 @@ delete_group(Host, Group, mnesia) -> Users) end, mnesia:transaction(F); +delete_group(Host, Group, riak) -> + try + ok = ejabberd_riak:delete(sr_group, {Group, Host}), + ok = ejabberd_riak:delete_by_index(sr_user, <<"group_host">>, + {Group, Host}), + {atomic, ok} + catch _:{badmatch, Err} -> + {atomic, Err} + end; delete_group(Host, Group, odbc) -> SGroup = ejabberd_odbc:escape(Group), F = fun () -> @@ -472,7 +501,10 @@ delete_group(Host, Group, odbc) -> ejabberd_odbc:sql_query_t([<<"delete from sr_user where grp='">>, SGroup, <<"';">>]) end, - ejabberd_odbc:sql_transaction(Host, F). + case ejabberd_odbc:sql_transaction(Host, F) of + {atomic,{updated,_}} -> {atomic, ok}; + Res -> Res + end. get_group_opts(Host, Group) -> get_group_opts(Host, Group, @@ -483,6 +515,11 @@ get_group_opts(Host, Group, mnesia) -> [#sr_group{opts = Opts}] -> Opts; _ -> error end; +get_group_opts(Host, Group, riak) -> + case ejabberd_riak:get(sr_group, sr_group_schema(), {Group, Host}) of + {ok, #sr_group{opts = Opts}} -> Opts; + _ -> error + end; get_group_opts(Host, Group, odbc) -> SGroup = ejabberd_odbc:escape(Group), case catch ejabberd_odbc:sql_query(Host, @@ -502,6 +539,11 @@ set_group_opts(Host, Group, Opts, mnesia) -> R = #sr_group{group_host = {Group, Host}, opts = Opts}, F = fun () -> mnesia:write(R) end, mnesia:transaction(F); +set_group_opts(Host, Group, Opts, riak) -> + {atomic, ejabberd_riak:put(#sr_group{group_host = {Group, Host}, + opts = Opts}, + sr_group_schema(), + [{'2i', [{<<"host">>, Host}]}])}; set_group_opts(Host, Group, Opts, odbc) -> SGroup = ejabberd_odbc:escape(Group), SOpts = ejabberd_odbc:encode_term(Opts), @@ -525,6 +567,13 @@ get_user_groups(US, Host, mnesia) -> || #sr_user{group_host = {Group, H}} <- Rs, H == Host]; _ -> [] end; +get_user_groups(US, Host, riak) -> + case ejabberd_riak:get_by_index(sr_user, sr_user_schema(), <<"us">>, US) of + {ok, Rs} -> + [Group || #sr_user{group_host = {Group, H}} <- Rs, H == Host]; + _ -> + [] + end; get_user_groups(US, Host, odbc) -> SJID = make_jid_s(US), case catch ejabberd_odbc:sql_query(Host, @@ -595,6 +644,14 @@ get_group_explicit_users(Host, Group, mnesia) -> Rs when is_list(Rs) -> [R#sr_user.us || R <- Rs]; _ -> [] end; +get_group_explicit_users(Host, Group, riak) -> + case ejabberd_riak:get_by_index(sr_user, sr_user_schema(), + <<"group_host">>, {Group, Host}) of + {ok, Rs} -> + [R#sr_user.us || R <- Rs]; + _ -> + [] + end; get_group_explicit_users(Host, Group, odbc) -> SGroup = ejabberd_odbc:escape(Group), case catch ejabberd_odbc:sql_query(Host, @@ -681,6 +738,16 @@ get_user_displayed_groups(LUser, LServer, GroupsOpts, _ -> [] end; get_user_displayed_groups(LUser, LServer, GroupsOpts, + riak) -> + case ejabberd_riak:get_by_index(sr_user, sr_user_schema(), + <<"us">>, {LUser, LServer}) of + {ok, Rs} -> + [{Group, proplists:get_value(Group, GroupsOpts, [])} + || #sr_user{group_host = {Group, _}} <- Rs]; + _ -> + [] + end; +get_user_displayed_groups(LUser, LServer, GroupsOpts, odbc) -> SJID = make_jid_s(LUser, LServer), case catch ejabberd_odbc:sql_query(LServer, @@ -726,6 +793,21 @@ is_user_in_group(US, Group, Host, mnesia) -> [] -> lists:member(US, get_group_users(Host, Group)); _ -> true end; +is_user_in_group(US, Group, Host, riak) -> + case ejabberd_riak:get_by_index(sr_user, sr_user_schema(), <<"us">>, US) of + {ok, Rs} -> + case lists:any( + fun(#sr_user{group_host = {G, H}}) -> + (Group == G) and (Host == H) + end, Rs) of + false -> + lists:member(US, get_group_users(Host, Group)); + true -> + true + end; + _Err -> + false + end; is_user_in_group(US, Group, Host, odbc) -> SJID = make_jid_s(US), SGroup = ejabberd_odbc:escape(Group), @@ -765,6 +847,13 @@ add_user_to_group(Host, US, Group, mnesia) -> R = #sr_user{us = US, group_host = {Group, Host}}, F = fun () -> mnesia:write(R) end, mnesia:transaction(F); +add_user_to_group(Host, US, Group, riak) -> + {atomic, ejabberd_riak:put( + #sr_user{us = US, group_host = {Group, Host}}, + sr_user_schema(), + [{i, {US, {Group, Host}}}, + {'2i', [{<<"us">>, US}, + {<<"group_host">>, {Group, Host}}]}])}; add_user_to_group(Host, US, Group, odbc) -> SJID = make_jid_s(US), SGroup = ejabberd_odbc:escape(Group), @@ -816,6 +905,8 @@ remove_user_from_group(Host, US, Group, mnesia) -> R = #sr_user{us = US, group_host = {Group, Host}}, F = fun () -> mnesia:delete_object(R) end, mnesia:transaction(F); +remove_user_from_group(Host, US, Group, riak) -> + {atomic, ejabberd_riak:delete(sr_group, {US, {Group, Host}})}; remove_user_from_group(Host, US, Group, odbc) -> SJID = make_jid_s(US), SGroup = ejabberd_odbc:escape(Group), @@ -1274,6 +1365,12 @@ opts_to_binary(Opts) -> Opt end, Opts). +sr_group_schema() -> + {record_info(fields, sr_group), #sr_group{}}. + +sr_user_schema() -> + {record_info(fields, sr_user), #sr_user{}}. + update_tables() -> update_sr_group_table(), update_sr_user_table(). @@ -1355,7 +1452,15 @@ import(LServer) -> import(_LServer, mnesia, #sr_group{} = G) -> mnesia:dirty_write(G); + import(_LServer, mnesia, #sr_user{} = U) -> mnesia:dirty_write(U); +import(_LServer, riak, #sr_group{group_host = {_, Host}} = G) -> + ejabberd_riak:put(G, sr_group_schema(), [{'2i', [{<<"host">>, Host}]}]); +import(_LServer, riak, #sr_user{us = US, group_host = {Group, Host}} = User) -> + ejabberd_riak:put(User, sr_user_schema(), + [{i, {US, {Group, Host}}}, + {'2i', [{<<"us">>, US}, + {<<"group_host">>, {Group, Host}}]}]); import(_, _, _) -> pass. diff --git a/src/mod_sip.erl b/src/mod_sip.erl index 4b733c623..bf57de75c 100644 --- a/src/mod_sip.erl +++ b/src/mod_sip.erl @@ -20,7 +20,7 @@ -include("ejabberd.hrl"). -include("logger.hrl"). --include("esip.hrl"). +-include_lib("esip/include/esip.hrl"). %%%=================================================================== %%% API @@ -68,6 +68,8 @@ message_in(#sip{type = request, method = M} = Req, SIPSock) Action -> request(Req, SIPSock, undefined, Action) end; +message_in(ping, SIPSock) -> + mod_sip_registrar:ping(SIPSock); message_in(_, _) -> ok. @@ -80,12 +82,14 @@ response(_Resp, _SIPSock) -> request(#sip{method = <<"ACK">>} = Req, SIPSock) -> case action(Req, SIPSock) of {relay, LServer} -> - mod_sip_proxy:route(Req, LServer, []); + mod_sip_proxy:route(Req, LServer, [{authenticated, true}]); + {proxy_auth, LServer} -> + mod_sip_proxy:route(Req, LServer, [{authenticated, false}]); _ -> - error + ok end; request(_Req, _SIPSock) -> - error. + ok. request(Req, SIPSock, TrID) -> request(Req, SIPSock, TrID, action(Req, SIPSock)). @@ -112,20 +116,20 @@ request(Req, SIPSock, TrID, Action) -> ?INFO_MSG("failed to proxy request ~p: ~p", [Req, Err]), Err end; - {proxy_auth, Host} -> + {proxy_auth, LServer} -> make_response( Req, #sip{status = 407, type = response, hdrs = [{'proxy-authenticate', - make_auth_hdr(Host)}]}); - {auth, Host} -> + make_auth_hdr(LServer)}]}); + {auth, LServer} -> make_response( Req, #sip{status = 401, type = response, hdrs = [{'www-authenticate', - make_auth_hdr(Host)}]}); + make_auth_hdr(LServer)}]}); deny -> make_response(Req, #sip{status = 403, type = response}); @@ -158,8 +162,9 @@ action(#sip{method = <<"REGISTER">>, type = request, hdrs = Hdrs, uri = #uri{user = <<"">>} = URI} = Req, SIPSock) -> case at_my_host(URI) of true -> - case esip:get_hdrs('require', Hdrs) of - [_|_] = Require -> + Require = esip:get_hdrs('require', Hdrs) -- supported(), + case Require of + [_|_] -> {unsupported, Require}; _ -> {_, ToURI, _} = esip:get_hdr('to', Hdrs), @@ -169,7 +174,7 @@ action(#sip{method = <<"REGISTER">>, type = request, hdrs = Hdrs, true -> register; false -> - {auth, ToURI#uri.host} + {auth, jlib:nameprep(ToURI#uri.host)} end; false -> deny @@ -185,8 +190,9 @@ action(#sip{method = Method, hdrs = Hdrs, type = request} = Req, SIPSock) -> 0 -> loop; _ -> - case esip:get_hdrs('proxy-require', Hdrs) of - [_|_] = Require -> + Require = esip:get_hdrs('proxy-require', Hdrs) -- supported(), + case Require of + [_|_] -> {unsupported, Require}; _ -> {_, ToURI, _} = esip:get_hdr('to', Hdrs), @@ -249,9 +255,13 @@ check_auth(#sip{method = Method, hdrs = Hdrs, body = Body}, AuthHdr, _SIPSock) - allow() -> [<<"OPTIONS">>, <<"REGISTER">>]. +supported() -> + [<<"path">>, <<"outbound">>]. + process(#sip{method = <<"OPTIONS">>} = Req, _) -> make_response(Req, #sip{type = response, status = 200, - hdrs = [{'allow', allow()}]}); + hdrs = [{'allow', allow()}, + {'supported', supported()}]}); process(#sip{method = <<"REGISTER">>} = Req, _) -> make_response(Req, #sip{type = response, status = 400}); process(Req, _) -> @@ -259,8 +269,7 @@ process(Req, _) -> hdrs = [{'allow', allow()}]}). make_auth_hdr(LServer) -> - Realm = jlib:nameprep(LServer), - {<<"Digest">>, [{<<"realm">>, esip:quote(Realm)}, + {<<"Digest">>, [{<<"realm">>, esip:quote(LServer)}, {<<"qop">>, esip:quote(<<"auth">>)}, {<<"nonce">>, esip:quote(esip:make_hexstr(20))}]}. diff --git a/src/mod_sip_proxy.erl b/src/mod_sip_proxy.erl index b05c49061..b2f76dbb3 100644 --- a/src/mod_sip_proxy.erl +++ b/src/mod_sip_proxy.erl @@ -21,7 +21,9 @@ -include("ejabberd.hrl"). -include("logger.hrl"). --include("esip.hrl"). +-include_lib("esip/include/esip.hrl"). + +-define(SIGN_LIFETIME, 300). %% in seconds. -record(state, {host = <<"">> :: binary(), opts = [] :: [{certfile, binary()}], @@ -42,15 +44,39 @@ start_link(LServer, Opts) -> route(SIPMsg, _SIPSock, TrID, Pid) -> ?GEN_FSM:send_event(Pid, {SIPMsg, TrID}). -route(Req, LServer, Opts) -> +route(#sip{hdrs = Hdrs} = Req, LServer, Opts) -> + case proplists:get_bool(authenticated, Opts) of + true -> + route_statelessly(Req, LServer, Opts); + false -> + ConfiguredRRoute = get_configured_record_route(LServer), + case esip:get_hdrs('route', Hdrs) of + [{_, URI, _}|_] -> + case cmp_uri(URI, ConfiguredRRoute) of + true -> + case is_signed_by_me(URI#uri.user, Hdrs) of + true -> + route_statelessly(Req, LServer, Opts); + false -> + error + end; + false -> + error + end; + [] -> + error + end + end. + +route_statelessly(Req, LServer, Opts) -> Req1 = prepare_request(LServer, Req), case connect(Req1, add_certfile(LServer, Opts)) of - {ok, SIPSockets} -> + {ok, SIPSocketsWithURIs} -> lists:foreach( - fun(SIPSocket) -> + fun({SIPSocket, _URI}) -> Req2 = add_via(SIPSocket, LServer, Req1), esip:send(SIPSocket, Req2) - end, SIPSockets); + end, SIPSocketsWithURIs); _ -> error end. @@ -66,13 +92,14 @@ wait_for_request({#sip{type = request} = Req, TrID}, State) -> Opts = State#state.opts, Req1 = prepare_request(State#state.host, Req), case connect(Req1, Opts) of - {ok, SIPSockets} -> + {ok, SIPSocketsWithURIs} -> NewState = lists:foldl( - fun(_SIPSocket, {error, _} = Err) -> + fun(_SIPSocketWithURI, {error, _} = Err) -> Err; - (SIPSocket, #state{tr_ids = TrIDs} = AccState) -> - Req2 = add_record_route(SIPSocket, State#state.host, Req1), + ({SIPSocket, URI}, #state{tr_ids = TrIDs} = AccState) -> + Req2 = add_record_route_and_set_uri( + URI, State#state.host, Req1), Req3 = add_via(SIPSocket, State#state.host, Req2), case esip:request(SIPSocket, Req3, {?MODULE, route, [self()]}) of @@ -83,7 +110,7 @@ wait_for_request({#sip{type = request} = Req, TrID}, State) -> cancel_pending_transactions(AccState), Err end - end, State, SIPSockets), + end, State, SIPSocketsWithURIs), case NewState of {error, _} = Err -> {Status, Reason} = esip:error_status(Err), @@ -214,7 +241,7 @@ connect(#sip{hdrs = Hdrs} = Req, Opts) -> false -> case esip:connect(Req, Opts) of {ok, SIPSock} -> - {ok, [SIPSock]}; + {ok, [{SIPSock, Req#sip.uri}]}; {error, _} = Err -> Err end @@ -244,18 +271,68 @@ add_via(#sip_socket{type = Transport}, LServer, #sip{hdrs = Hdrs} = Req) -> Via = #via{transport = ViaTransport, host = ViaHost, port = ViaPort, - params = [{<<"branch">>, esip:make_branch()}, - {<<"rport">>, <<"">>}]}, + params = [{<<"branch">>, esip:make_branch()}]}, Req#sip{hdrs = [{'via', [Via]}|Hdrs]}. -add_record_route(_SIPSocket, LServer, #sip{hdrs = Hdrs} = Req) -> - URI = #uri{host = LServer, params = [{<<"lr">>, <<"">>}]}, - Hdrs1 = [{'record-route', [{<<>>, URI, []}]}|Hdrs], - Req#sip{hdrs = Hdrs1}. +add_record_route_and_set_uri(URI, LServer, #sip{hdrs = Hdrs} = Req) -> + case is_request_within_dialog(Req) of + false -> + case need_record_route(LServer) of + true -> + RR_URI = get_configured_record_route(LServer), + {MSecs, Secs, _} = now(), + TS = list_to_binary(integer_to_list(MSecs*1000000 + Secs)), + Sign = make_sign(TS, Hdrs), + User = <<TS/binary, $-, Sign/binary>>, + NewRR_URI = RR_URI#uri{user = User}, + Hdrs1 = [{'record-route', [{<<>>, NewRR_URI, []}]}|Hdrs], + Req#sip{uri = URI, hdrs = Hdrs1}; + false -> + Req + end; + true -> + Req + end. + +is_request_within_dialog(#sip{hdrs = Hdrs}) -> + {_, _, Params} = esip:get_hdr('to', Hdrs), + esip:has_param(<<"tag">>, Params). + +need_record_route(LServer) -> + gen_mod:get_module_opt( + LServer, mod_sip, always_record_route, + fun(true) -> true; + (false) -> false + end, true). + +make_sign(TS, Hdrs) -> + {_, #uri{user = FUser, host = FServer}, FParams} = esip:get_hdr('from', Hdrs), + {_, #uri{user = TUser, host = TServer}, _} = esip:get_hdr('to', Hdrs), + LFUser = safe_nodeprep(FUser), + LTUser = safe_nodeprep(TUser), + LFServer = safe_nameprep(FServer), + LTServer = safe_nameprep(TServer), + FromTag = esip:get_param(<<"tag">>, FParams), + CallID = esip:get_hdr('call-id', Hdrs), + SharedKey = ejabberd_config:get_option(shared_key, fun(V) -> V end), + p1_sha:sha([SharedKey, LFUser, LFServer, LTUser, LTServer, + FromTag, CallID, TS]). + +is_signed_by_me(TS_Sign, Hdrs) -> + try + [TSBin, Sign] = str:tokens(TS_Sign, <<"-">>), + TS = list_to_integer(binary_to_list(TSBin)), + {MSecs, Secs, _} = now(), + NowTS = MSecs*1000000 + Secs, + true = (NowTS - TS) =< ?SIGN_LIFETIME, + Sign == make_sign(TSBin, Hdrs) + catch _:_ -> + false + end. get_configured_vias(LServer) -> gen_mod:get_module_opt( - LServer, ?MODULE, via, + LServer, mod_sip, via, fun(L) -> lists:map( fun(Opts) -> @@ -271,6 +348,25 @@ get_configured_vias(LServer) -> end, L) end, []). +get_configured_record_route(LServer) -> + gen_mod:get_module_opt( + LServer, mod_sip, record_route, + fun(IOList) -> + S = iolist_to_binary(IOList), + #uri{} = esip:decode_uri(S) + end, #uri{host = LServer, params = [{<<"lr">>, <<"">>}]}). + +get_configured_routes(LServer) -> + gen_mod:get_module_opt( + LServer, mod_sip, routes, + fun(L) -> + lists:map( + fun(IOList) -> + S = iolist_to_binary(IOList), + #uri{} = esip:decode_uri(S) + end, L) + end, [#uri{host = LServer, params = [{<<"lr">>, <<"">>}]}]). + mark_transaction_as_complete(TrID, State) -> NewTrIDs = lists:delete(TrID, State#state.tr_ids), State#state{tr_ids = NewTrIDs}. @@ -295,13 +391,23 @@ choose_best_response(#state{responses = Responses} = State) -> end end. -prepare_request(Host, #sip{hdrs = Hdrs} = Req) -> +%% Just compare host part only. +cmp_uri(#uri{host = H1}, #uri{host = H2}) -> + jlib:nameprep(H1) == jlib:nameprep(H2). + +is_my_route(URI, URIs) -> + lists:any(fun(U) -> cmp_uri(URI, U) end, URIs). + +prepare_request(LServer, #sip{hdrs = Hdrs} = Req) -> + ConfiguredRRoute = get_configured_record_route(LServer), + ConfiguredRoutes = get_configured_routes(LServer), Hdrs1 = lists:flatmap( fun({Hdr, HdrList}) when Hdr == 'route'; Hdr == 'record-route' -> case lists:filter( - fun({_, #uri{user = <<"">>, host = Host1}, _}) -> - Host1 /= Host + fun({_, URI, _}) -> + not cmp_uri(URI, ConfiguredRRoute) + and not is_my_route(URI, ConfiguredRoutes) end, HdrList) of [] -> []; @@ -321,3 +427,15 @@ prepare_request(Host, #sip{hdrs = Hdrs} = Req) -> true end, Hdrs2), Req#sip{hdrs = Hdrs3}. + +safe_nodeprep(S) -> + case jlib:nodeprep(S) of + error -> S; + S1 -> S1 + end. + +safe_nameprep(S) -> + case jlib:nameprep(S) of + error -> S; + S1 -> S1 + end. diff --git a/src/mod_sip_registrar.erl b/src/mod_sip_registrar.erl index 57c55be08..298c7108b 100644 --- a/src/mod_sip_registrar.erl +++ b/src/mod_sip_registrar.erl @@ -12,7 +12,7 @@ -behaviour(?GEN_SERVER). %% API --export([start_link/0, request/2, find_sockets/2]). +-export([start_link/0, request/2, find_sockets/2, ping/1]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, @@ -20,19 +20,23 @@ -include("ejabberd.hrl"). -include("logger.hrl"). --include("esip.hrl"). +-include_lib("esip/include/esip.hrl"). -define(CALL_TIMEOUT, timer:seconds(30)). - --record(binding, {socket = #sip_socket{}, - call_id = <<"">> :: binary(), - cseq = 0 :: non_neg_integer(), - timestamp = now() :: erlang:timestamp(), - tref = make_ref() :: reference(), - expires = 0 :: non_neg_integer()}). +-define(DEFAULT_EXPIRES, 3600). +-define(FLOW_TIMEOUT_UDP, 29). +-define(FLOW_TIMEOUT_TCP, 120). -record(sip_session, {us = {<<"">>, <<"">>} :: {binary(), binary()}, - bindings = [] :: [#binding{}]}). + socket = #sip_socket{} :: #sip_socket{}, + call_id = <<"">> :: binary(), + cseq = 0 :: non_neg_integer(), + timestamp = now() :: erlang:timestamp(), + contact :: {binary(), #uri{}, [{binary(), binary()}]}, + flow_tref :: reference(), + reg_tref = make_ref() :: reference(), + conn_mref = make_ref() :: reference(), + expires = 0 :: non_neg_integer()}). -record(state, {}). @@ -50,15 +54,21 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) -> US = {LUser, LServer}, CallID = esip:get_hdr('call-id', Hdrs), CSeq = esip:get_hdr('cseq', Hdrs), - Expires = esip:get_hdr('expires', Hdrs, 0), + Expires = esip:get_hdr('expires', Hdrs, ?DEFAULT_EXPIRES), + Supported = esip:get_hdrs('supported', Hdrs), + IsOutboundSupported = lists:member(<<"outbound">>, Supported), case esip:get_hdrs('contact', Hdrs) of [<<"*">>] when Expires == 0 -> - case unregister_session(US, SIPSock, CallID, CSeq) of - ok -> + case unregister_session(US, CallID, CSeq) of + {ok, ContactsWithExpires} -> ?INFO_MSG("unregister SIP session for user ~s@~s from ~s", [LUser, LServer, inet_parse:ntoa(PeerIP)]), + Cs = prepare_contacts_to_send(ContactsWithExpires), mod_sip:make_response( - Req, #sip{type = response, status = 200}); + Req, + #sip{type = response, + status = 200, + hdrs = [{'contact', Cs}]}); {error, Why} -> {Status, Reason} = make_status(Why), mod_sip:make_response( @@ -67,51 +77,40 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) -> reason = Reason}) end; [{_, _URI, _Params}|_] = Contacts -> - ExpiresList = lists:map( - fun({_, _, Params}) -> - case to_integer( - esip:get_param( - <<"expires">>, Params), - 0, (1 bsl 32)-1) of - {ok, E} -> E; - _ -> Expires - end - end, Contacts), - Expires1 = lists:max(ExpiresList), - Contact = {<<"">>, #uri{user = LUser, host = LServer}, - [{<<"expires">>, jlib:integer_to_binary(Expires1)}]}, + ContactsWithExpires = make_contacts_with_expires(Contacts, Expires), + ContactsHaveManyRegID = contacts_have_many_reg_id(Contacts), + Expires1 = lists:max([E || {_, E} <- ContactsWithExpires]), MinExpires = min_expires(), - if Expires1 >= MinExpires -> - case register_session(US, SIPSock, CallID, CSeq, Expires1) of - ok -> - ?INFO_MSG("register SIP session for user ~s@~s from ~s", - [LUser, LServer, inet_parse:ntoa(PeerIP)]), - mod_sip:make_response( - Req, - #sip{type = response, - status = 200, - hdrs = [{'contact', [Contact]}]}); - {error, Why} -> - {Status, Reason} = make_status(Why), - mod_sip:make_response( - Req, #sip{type = response, - status = Status, - reason = Reason}) - end; - Expires1 > 0, Expires1 < MinExpires -> - mod_sip:make_response( + if Expires1 > 0, Expires1 < MinExpires -> + mod_sip:make_response( Req, #sip{type = response, status = 423, hdrs = [{'min-expires', MinExpires}]}); - true -> - case unregister_session(US, SIPSock, CallID, CSeq) of - ok -> - ?INFO_MSG("unregister SIP session for user ~s@~s from ~s", - [LUser, LServer, inet_parse:ntoa(PeerIP)]), + ContactsHaveManyRegID -> + mod_sip:make_response( + Req, #sip{type = response, status = 400, + reason = <<"Multiple 'reg-id' parameter">>}); + true -> + case register_session(US, SIPSock, CallID, CSeq, + IsOutboundSupported, + ContactsWithExpires) of + {ok, Res} -> + ?INFO_MSG("~s SIP session for user ~s@~s from ~s", + [Res, LUser, LServer, + inet_parse:ntoa(PeerIP)]), + Cs = prepare_contacts_to_send(ContactsWithExpires), + Require = case need_ob_hdrs( + Contacts, IsOutboundSupported) of + true -> [{'require', [<<"outbound">>]}, + {'flow-timer', + get_flow_timeout(LServer, SIPSock)}]; + false -> [] + end, mod_sip:make_response( Req, - #sip{type = response, status = 200, - hdrs = [{'contact', [Contact]}]}); + #sip{type = response, + status = 200, + hdrs = [{'contact', Cs}|Require]}); {error, Why} -> {Status, Reason} = make_status(Why), mod_sip:make_response( @@ -122,23 +121,16 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) -> end; [] -> case mnesia:dirty_read(sip_session, US) of - [#sip_session{bindings = Bindings}] -> - case pop_previous_binding(SIPSock, Bindings) of - {ok, #binding{expires = Expires1}, _} -> - Contact = {<<"">>, - #uri{user = LUser, host = LServer}, - [{<<"expires">>, - jlib:integer_to_binary(Expires1)}]}, - mod_sip:make_response( - Req, #sip{type = response, status = 200, - hdrs = [{'contact', [Contact]}]}); - {error, notfound} -> - {Status, Reason} = make_status(notfound), - mod_sip:make_response( - Req, #sip{type = response, - status = Status, - reason = Reason}) - end; + [_|_] = Sessions -> + ContactsWithExpires = + lists:map( + fun(#sip_session{contact = Contact, expires = Es}) -> + {Contact, Es} + end, Sessions), + Cs = prepare_contacts_to_send(ContactsWithExpires), + mod_sip:make_response( + Req, #sip{type = response, status = 200, + hdrs = [{'contact', Cs}]}); [] -> {Status, Reason} = make_status(notfound), mod_sip:make_response( @@ -152,27 +144,41 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) -> find_sockets(U, S) -> case mnesia:dirty_read(sip_session, {U, S}) of - [#sip_session{bindings = Bindings}] -> - [Binding#binding.socket || Binding <- Bindings]; + [_|_] = Sessions -> + lists:map( + fun(#sip_session{contact = {_, URI, _}, + socket = Socket}) -> + {Socket, URI} + end, Sessions); [] -> [] end. +ping(SIPSocket) -> + call({ping, SIPSocket}). + %%%=================================================================== %%% gen_server callbacks %%%=================================================================== init([]) -> + update_table(), mnesia:create_table(sip_session, [{ram_copies, [node()]}, + {type, bag}, {attributes, record_info(fields, sip_session)}]), + mnesia:add_table_index(sip_session, conn_mref), + mnesia:add_table_index(sip_session, socket), mnesia:add_table_copy(sip_session, node(), ram_copies), {ok, #state{}}. -handle_call({write, Session}, _From, State) -> - Res = write_session(Session), +handle_call({write, Sessions, Supported}, _From, State) -> + Res = write_session(Sessions, Supported), + {reply, Res, State}; +handle_call({delete, US, CallID, CSeq}, _From, State) -> + Res = delete_session(US, CallID, CSeq), {reply, Res, State}; -handle_call({delete, US, SIPSocket, CallID, CSeq}, _From, State) -> - Res = delete_session(US, SIPSocket, CallID, CSeq), +handle_call({ping, SIPSocket}, _From, State) -> + Res = process_ping(SIPSocket), {reply, Res, State}; handle_call(_Request, _From, State) -> Reply = ok, @@ -181,15 +187,23 @@ handle_call(_Request, _From, State) -> handle_cast(_Msg, State) -> {noreply, State}. -handle_info({write, Session}, State) -> - write_session(Session), +handle_info({write, Sessions, Supported}, State) -> + write_session(Sessions, Supported), {noreply, State}; -handle_info({delete, US, SIPSocket, CallID, CSeq}, State) -> - delete_session(US, SIPSocket, CallID, CSeq), +handle_info({delete, US, CallID, CSeq}, State) -> + delete_session(US, CallID, CSeq), {noreply, State}; handle_info({timeout, TRef, US}, State) -> delete_expired_session(US, TRef), {noreply, State}; +handle_info({'DOWN', MRef, process, _Pid, _Reason}, State) -> + case mnesia:dirty_index_read(sip_session, MRef, #sip_session.conn_mref) of + [Session] -> + mnesia:dirty_delete_object(Session); + _ -> + ok + end, + {noreply, State}; handle_info(_Info, State) -> ?ERROR_MSG("got unexpected info: ~p", [_Info]), {noreply, State}. @@ -203,70 +217,98 @@ code_change(_OldVsn, State, _Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== -register_session(US, SIPSocket, CallID, CSeq, Expires) -> - Session = #sip_session{us = US, - bindings = [#binding{socket = SIPSocket, - call_id = CallID, - cseq = CSeq, - timestamp = now(), - expires = Expires}]}, - call({write, Session}). - -unregister_session(US, SIPSocket, CallID, CSeq) -> - Msg = {delete, US, SIPSocket, CallID, CSeq}, +register_session(US, SIPSocket, CallID, CSeq, IsOutboundSupported, + ContactsWithExpires) -> + Sessions = lists:map( + fun({Contact, Expires}) -> + #sip_session{us = US, + socket = SIPSocket, + call_id = CallID, + cseq = CSeq, + timestamp = now(), + contact = Contact, + expires = Expires} + end, ContactsWithExpires), + Msg = {write, Sessions, IsOutboundSupported}, call(Msg). -write_session(#sip_session{us = {U, S} = US, - bindings = [#binding{socket = SIPSocket, - call_id = CallID, - expires = Expires, - cseq = CSeq} = Binding]}) -> - case mnesia:dirty_read(sip_session, US) of - [#sip_session{bindings = Bindings}] -> - case pop_previous_binding(SIPSocket, Bindings) of - {ok, #binding{call_id = CallID, cseq = PrevCSeq}, _} - when PrevCSeq > CSeq -> - {error, cseq_out_of_order}; - {ok, #binding{tref = Tref}, Bindings1} -> - erlang:cancel_timer(Tref), - NewTRef = erlang:start_timer(Expires * 1000, self(), US), - NewBindings = [Binding#binding{tref = NewTRef}|Bindings1], - mnesia:dirty_write( - #sip_session{us = US, bindings = NewBindings}); - {error, notfound} -> - MaxSessions = ejabberd_sm:get_max_user_sessions(U, S), - if length(Bindings) < MaxSessions -> - NewTRef = erlang:start_timer(Expires * 1000, self(), US), - NewBindings = [Binding#binding{tref = NewTRef}|Bindings], - mnesia:dirty_write( - #sip_session{us = US, bindings = NewBindings}); - true -> - {error, too_many_sessions} +unregister_session(US, CallID, CSeq) -> + Msg = {delete, US, CallID, CSeq}, + call(Msg). + +write_session([#sip_session{us = {U, S} = US}|_] = NewSessions, + IsOutboundSupported) -> + PrevSessions = mnesia:dirty_read(sip_session, US), + Res = lists:foldl( + fun(_, {error, _} = Err) -> + Err; + (#sip_session{call_id = CallID, + expires = Expires, + cseq = CSeq} = Session, {Add, Del}) -> + case find_session(Session, PrevSessions, + IsOutboundSupported) of + {ok, normal, #sip_session{call_id = CallID, + cseq = PrevCSeq}} + when PrevCSeq > CSeq -> + {error, cseq_out_of_order}; + {ok, _Type, PrevSession} when Expires == 0 -> + {Add, [PrevSession|Del]}; + {ok, _Type, PrevSession} -> + {[Session|Add], [PrevSession|Del]}; + {error, notfound} when Expires == 0 -> + {error, notfound}; + {error, notfound} -> + {[Session|Add], Del} end - end; - [] -> - NewTRef = erlang:start_timer(Expires * 1000, self(), US), - NewBindings = [Binding#binding{tref = NewTRef}], - mnesia:dirty_write(#sip_session{us = US, bindings = NewBindings}) + end, {[], []}, NewSessions), + MaxSessions = ejabberd_sm:get_max_user_sessions(U, S), + case Res of + {error, Why} -> + {error, Why}; + {AddSessions, DelSessions} -> + MaxSessions = ejabberd_sm:get_max_user_sessions(U, S), + AllSessions = AddSessions ++ PrevSessions -- DelSessions, + if length(AllSessions) > MaxSessions -> + {error, too_many_sessions}; + true -> + lists:foreach(fun delete_session/1, DelSessions), + lists:foreach( + fun(Session) -> + NewSession = set_monitor_and_timer( + Session, IsOutboundSupported), + mnesia:dirty_write(NewSession) + end, AddSessions), + case {AllSessions, AddSessions} of + {[], _} -> + {ok, unregister}; + {_, []} -> + {ok, unregister}; + _ -> + {ok, register} + end + end end. -delete_session(US, SIPSocket, CallID, CSeq) -> +delete_session(US, CallID, CSeq) -> case mnesia:dirty_read(sip_session, US) of - [#sip_session{bindings = Bindings}] -> - case pop_previous_binding(SIPSocket, Bindings) of - {ok, #binding{call_id = CallID, cseq = PrevCSeq}, _} - when PrevCSeq > CSeq -> - {error, cseq_out_of_order}; - {ok, #binding{tref = TRef}, []} -> - erlang:cancel_timer(TRef), - mnesia:dirty_delete(sip_session, US); - {ok, #binding{tref = TRef}, NewBindings} -> - erlang:cancel_timer(TRef), - mnesia:dirty_write(sip_session, - #sip_session{us = US, - bindings = NewBindings}); - {error, notfound} -> - {error, notfound} + [_|_] = Sessions -> + case lists:all( + fun(S) when S#sip_session.call_id == CallID, + S#sip_session.cseq > CSeq -> + false; + (_) -> + true + end, Sessions) of + true -> + ContactsWithExpires = + lists:map( + fun(#sip_session{contact = Contact} = Session) -> + delete_session(Session), + {Contact, 0} + end, Sessions), + {ok, ContactsWithExpires}; + false -> + {error, cseq_out_of_order} end; [] -> {error, notfound} @@ -274,20 +316,20 @@ delete_session(US, SIPSocket, CallID, CSeq) -> delete_expired_session(US, TRef) -> case mnesia:dirty_read(sip_session, US) of - [#sip_session{bindings = Bindings}] -> - case lists:filter( - fun(#binding{tref = TRef1}) when TRef1 == TRef -> - false; - (_) -> - true - end, Bindings) of - [] -> - mnesia:dirty_delete(sip_session, US); - NewBindings -> - mnesia:dirty_write(sip_session, - #sip_session{us = US, - bindings = NewBindings}) - end; + [_|_] = Sessions -> + lists:foreach( + fun(#sip_session{reg_tref = T1, + flow_tref = T2} = Session) + when T1 == TRef; T2 == TRef -> + if T2 /= undefined -> + close_socket(Session); + true -> + ok + end, + delete_session(Session); + (_) -> + ok + end, Sessions); [] -> ok end. @@ -303,17 +345,6 @@ to_integer(Bin, Min, Max) -> error end. -pop_previous_binding(#sip_socket{peer = Peer}, Bindings) -> - case lists:partition( - fun(#binding{socket = #sip_socket{peer = Peer1}}) -> - Peer1 == Peer - end, Bindings) of - {[Binding], RestBindings} -> - {ok, Binding, RestBindings}; - _ -> - {error, notfound} - end. - call(Msg) -> case catch ?GEN_SERVER:call(?MODULE, Msg, ?CALL_TIMEOUT) of {'EXIT', {timeout, _}} -> @@ -324,6 +355,87 @@ call(Msg) -> Reply end. +make_contacts_with_expires(Contacts, Expires) -> + lists:map( + fun({Name, URI, Params}) -> + E1 = case to_integer(esip:get_param(<<"expires">>, Params), + 0, (1 bsl 32)-1) of + {ok, E} -> E; + _ -> Expires + end, + Params1 = lists:keydelete(<<"expires">>, 1, Params), + {{Name, URI, Params1}, E1} + end, Contacts). + +prepare_contacts_to_send(ContactsWithExpires) -> + lists:map( + fun({{Name, URI, Params}, Expires}) -> + Params1 = esip:set_param(<<"expires">>, + list_to_binary( + integer_to_list(Expires)), + Params), + {Name, URI, Params1} + end, ContactsWithExpires). + +contacts_have_many_reg_id(Contacts) -> + Sum = lists:foldl( + fun({_Name, _URI, Params}, Acc) -> + case get_ob_params(Params) of + error -> + Acc; + {_, _} -> + Acc + 1 + end + end, 0, Contacts), + if Sum > 1 -> + true; + true -> + false + end. + +find_session(#sip_session{contact = {_, URI, Params}}, Sessions, + IsOutboundSupported) -> + if IsOutboundSupported -> + case get_ob_params(Params) of + {InstanceID, RegID} -> + find_session_by_ob({InstanceID, RegID}, Sessions); + error -> + find_session_by_uri(URI, Sessions) + end; + true -> + find_session_by_uri(URI, Sessions) + end. + +find_session_by_ob({InstanceID, RegID}, + [#sip_session{contact = {_, _, Params}} = Session|Sessions]) -> + case get_ob_params(Params) of + {InstanceID, RegID} -> + {ok, flow, Session}; + _ -> + find_session_by_ob({InstanceID, RegID}, Sessions) + end; +find_session_by_ob(_, []) -> + {error, notfound}. + +find_session_by_uri(URI1, + [#sip_session{contact = {_, URI2, _}} = Session|Sessions]) -> + case cmp_uri(URI1, URI2) of + true -> + {ok, normal, Session}; + false -> + find_session_by_uri(URI1, Sessions) + end; +find_session_by_uri(_, []) -> + {error, notfound}. + +%% TODO: this is *totally* wrong. +%% Rewrite this using URI comparison rules +cmp_uri(#uri{user = U, host = H, port = P}, + #uri{user = U, host = H, port = P}) -> + true; +cmp_uri(_, _) -> + false. + make_status(notfound) -> {404, esip:reason(404)}; make_status(cseq_out_of_order) -> @@ -334,3 +446,119 @@ make_status(too_many_sessions) -> {503, <<"Too Many Registered Sessions">>}; make_status(_) -> {500, esip:reason(500)}. + +get_ob_params(Params) -> + case esip:get_param(<<"+sip.instance">>, Params) of + <<>> -> + error; + InstanceID -> + case to_integer(esip:get_param(<<"reg-id">>, Params), + 0, (1 bsl 32)-1) of + {ok, RegID} -> + {InstanceID, RegID}; + error -> + error + end + end. + +need_ob_hdrs(_Contacts, _IsOutboundSupported = false) -> + false; +need_ob_hdrs(Contacts, _IsOutboundSupported = true) -> + lists:any( + fun({_Name, _URI, Params}) -> + case get_ob_params(Params) of + error -> false; + {_, _} -> true + end + end, Contacts). + +get_flow_timeout(LServer, #sip_socket{type = Type}) -> + {Option, Default} = + case Type of + udp -> {flow_timeout_udp, ?FLOW_TIMEOUT_UDP}; + _ -> {flow_timeout_tcp, ?FLOW_TIMEOUT_TCP} + end, + gen_mod:get_module_opt( + LServer, mod_sip, Option, + fun(I) when is_integer(I), I>0 -> I end, + Default). + +update_table() -> + Fields = record_info(fields, sip_session), + case catch mnesia:table_info(sip_session, attributes) of + Fields -> + ok; + [_|_] -> + mnesia:delete_table(sip_session); + {'EXIT', _} -> + ok + end. + +set_monitor_and_timer(#sip_session{socket = #sip_socket{type = Type, + pid = Pid} = SIPSock, + conn_mref = MRef, + expires = Expires, + us = {_, LServer}, + contact = {_, _, Params}} = Session, + IsOutboundSupported) -> + RegTRef = set_timer(Session, Expires), + Session1 = Session#sip_session{reg_tref = RegTRef}, + if IsOutboundSupported -> + case get_ob_params(Params) of + error -> + Session1; + {_, _} -> + FlowTimeout = get_flow_timeout(LServer, SIPSock), + FlowTRef = set_timer(Session1, FlowTimeout), + NewMRef = if Type == udp -> MRef; + true -> erlang:monitor(process, Pid) + end, + Session1#sip_session{conn_mref = NewMRef, + flow_tref = FlowTRef} + end; + true -> + Session1 + end. + +set_timer(#sip_session{us = US}, Timeout) -> + erlang:start_timer(Timeout * 1000, self(), US). + +close_socket(#sip_session{socket = SIPSocket}) -> + if SIPSocket#sip_socket.type /= udp -> + esip_socket:close(SIPSocket); + true -> + ok + end. + +delete_session(#sip_session{reg_tref = RegTRef, + flow_tref = FlowTRef, + conn_mref = MRef} = Session) -> + erlang:cancel_timer(RegTRef), + catch erlang:cancel_timer(FlowTRef), + catch erlang:demonitor(MRef, [flush]), + mnesia:dirty_delete_object(Session). + +process_ping(SIPSocket) -> + ErrResponse = if SIPSocket#sip_socket.type == udp -> pang; + true -> drop + end, + Sessions = mnesia:dirty_index_read( + sip_session, SIPSocket, #sip_session.socket), + lists:foldl( + fun(#sip_session{flow_tref = TRef, + us = {_, LServer}} = Session, _) + when TRef /= undefined -> + erlang:cancel_timer(TRef), + mnesia:dirty_delete_object(Session), + Timeout = get_flow_timeout(LServer, SIPSocket), + NewTRef = set_timer(Session, Timeout), + case mnesia:dirty_write( + Session#sip_session{flow_tref = NewTRef}) of + ok -> + pong; + _Err -> + pang + end; + (_, Acc) -> + Acc + end, ErrResponse, Sessions). diff --git a/src/mod_vcard.erl b/src/mod_vcard.erl index c98750ed0..8afac260b 100644 --- a/src/mod_vcard.erl +++ b/src/mod_vcard.erl @@ -46,7 +46,7 @@ lbday, ctry, lctry, locality, llocality, email, lemail, orgname, lorgname, orgunit, lorgunit}). --record(vcard, {us = {<<"">>, <<"">>} :: {binary(), binary()}, +-record(vcard, {us = {<<"">>, <<"">>} :: {binary(), binary()} | binary(), vcard = #xmlel{} :: xmlel()}). -define(PROCNAME, ejabberd_mod_vcard). @@ -186,6 +186,11 @@ process_sm_iq(From, To, error -> IQ#iq{type = error, sub_el = [SubEl, ?ERR_INTERNAL_SERVER_ERROR]}; + [] -> + IQ#iq{type = result, + sub_el = [#xmlel{name = <<"vCard">>, + attrs = [{<<"xmlns">>, ?NS_VCARD}], + children = []}]}; Els -> IQ#iq{type = result, sub_el = Els} end end. @@ -212,6 +217,15 @@ get_vcard(LUser, LServer, odbc) -> end; {selected, [<<"vcard">>], []} -> []; _ -> error + end; +get_vcard(LUser, LServer, riak) -> + case ejabberd_riak:get(vcard, vcard_schema(), {LUser, LServer}) of + {ok, R} -> + [R#vcard.vcard]; + {error, notfound} -> + []; + _ -> + error end. set_vcard(User, LServer, VCARD) -> @@ -289,6 +303,34 @@ set_vcard(User, LServer, VCARD) -> lorgunit = LOrgUnit}) end, mnesia:transaction(F); + riak -> + US = {LUser, LServer}, + ejabberd_riak:put(#vcard{us = US, vcard = VCARD}, + vcard_schema(), + [{'2i', [{<<"user">>, User}, + {<<"luser">>, LUser}, + {<<"fn">>, FN}, + {<<"lfn">>, LFN}, + {<<"family">>, Family}, + {<<"lfamily">>, LFamily}, + {<<"given">>, Given}, + {<<"lgiven">>, LGiven}, + {<<"middle">>, Middle}, + {<<"lmiddle">>, LMiddle}, + {<<"nickname">>, Nickname}, + {<<"lnickname">>, LNickname}, + {<<"bday">>, BDay}, + {<<"lbday">>, LBDay}, + {<<"ctry">>, CTRY}, + {<<"lctry">>, LCTRY}, + {<<"locality">>, Locality}, + {<<"llocality">>, LLocality}, + {<<"email">>, EMail}, + {<<"lemail">>, LEMail}, + {<<"orgname">>, OrgName}, + {<<"lorgname">>, LOrgName}, + {<<"orgunit">>, OrgUnit}, + {<<"lorgunit">>, LOrgUnit}]}]); odbc -> Username = ejabberd_odbc:escape(User), LUsername = ejabberd_odbc:escape(LUser), @@ -687,14 +729,18 @@ search(LServer, MatchSpec, AllowReturnAll, odbc) -> Rs; Error -> ?ERROR_MSG("~p", [Error]), [] end - end. + end; +search(_LServer, _MatchSpec, _AllowReturnAll, riak) -> + []. make_matchspec(LServer, Data, mnesia) -> GlobMatch = #vcard_search{_ = '_'}, Match = filter_fields(Data, GlobMatch, LServer, mnesia), Match; make_matchspec(LServer, Data, odbc) -> - filter_fields(Data, <<"">>, LServer, odbc). + filter_fields(Data, <<"">>, LServer, odbc); +make_matchspec(_LServer, _Data, riak) -> + []. filter_fields([], Match, _LServer, mnesia) -> Match; filter_fields([], Match, _LServer, odbc) -> @@ -884,7 +930,9 @@ remove_user(LUser, LServer, odbc) -> [[<<"delete from vcard where username='">>, Username, <<"';">>], [<<"delete from vcard_search where lusername='">>, - Username, <<"';">>]]). + Username, <<"';">>]]); +remove_user(LUser, LServer, riak) -> + {atomic, ejabberd_riak:delete(vcard, {LUser, LServer})}. update_tables() -> update_vcard_table(), @@ -930,6 +978,9 @@ update_vcard_search_table() -> mnesia:transform_table(vcard_search, ignore, Fields) end. +vcard_schema() -> + {record_info(fields, vcard), #vcard{}}. + export(_Server) -> [{vcard, fun(Host, #vcard{us = {LUser, LServer}, vcard = VCARD}) @@ -1039,5 +1090,72 @@ import(_LServer, mnesia, #vcard{} = VCard) -> mnesia:dirty_write(VCard); import(_LServer, mnesia, #vcard_search{} = S) -> mnesia:dirty_write(S); +import(_LServer, riak, #vcard{us = {LUser, _}, vcard = El} = VCard) -> + FN = xml:get_path_s(El, [{elem, <<"FN">>}, cdata]), + Family = xml:get_path_s(El, + [{elem, <<"N">>}, {elem, <<"FAMILY">>}, cdata]), + Given = xml:get_path_s(El, + [{elem, <<"N">>}, {elem, <<"GIVEN">>}, cdata]), + Middle = xml:get_path_s(El, + [{elem, <<"N">>}, {elem, <<"MIDDLE">>}, cdata]), + Nickname = xml:get_path_s(El, + [{elem, <<"NICKNAME">>}, cdata]), + BDay = xml:get_path_s(El, + [{elem, <<"BDAY">>}, cdata]), + CTRY = xml:get_path_s(El, + [{elem, <<"ADR">>}, {elem, <<"CTRY">>}, cdata]), + Locality = xml:get_path_s(El, + [{elem, <<"ADR">>}, {elem, <<"LOCALITY">>}, + cdata]), + EMail1 = xml:get_path_s(El, + [{elem, <<"EMAIL">>}, {elem, <<"USERID">>}, cdata]), + EMail2 = xml:get_path_s(El, + [{elem, <<"EMAIL">>}, cdata]), + OrgName = xml:get_path_s(El, + [{elem, <<"ORG">>}, {elem, <<"ORGNAME">>}, cdata]), + OrgUnit = xml:get_path_s(El, + [{elem, <<"ORG">>}, {elem, <<"ORGUNIT">>}, cdata]), + EMail = case EMail1 of + <<"">> -> EMail2; + _ -> EMail1 + end, + LFN = string2lower(FN), + LFamily = string2lower(Family), + LGiven = string2lower(Given), + LMiddle = string2lower(Middle), + LNickname = string2lower(Nickname), + LBDay = string2lower(BDay), + LCTRY = string2lower(CTRY), + LLocality = string2lower(Locality), + LEMail = string2lower(EMail), + LOrgName = string2lower(OrgName), + LOrgUnit = string2lower(OrgUnit), + ejabberd_riak:put(VCard, vcard_schema(), + [{'2i', [{<<"user">>, LUser}, + {<<"luser">>, LUser}, + {<<"fn">>, FN}, + {<<"lfn">>, LFN}, + {<<"family">>, Family}, + {<<"lfamily">>, LFamily}, + {<<"given">>, Given}, + {<<"lgiven">>, LGiven}, + {<<"middle">>, Middle}, + {<<"lmiddle">>, LMiddle}, + {<<"nickname">>, Nickname}, + {<<"lnickname">>, LNickname}, + {<<"bday">>, BDay}, + {<<"lbday">>, LBDay}, + {<<"ctry">>, CTRY}, + {<<"lctry">>, LCTRY}, + {<<"locality">>, Locality}, + {<<"llocality">>, LLocality}, + {<<"email">>, EMail}, + {<<"lemail">>, LEMail}, + {<<"orgname">>, OrgName}, + {<<"lorgname">>, LOrgName}, + {<<"orgunit">>, OrgUnit}, + {<<"lorgunit">>, LOrgUnit}]}]); +import(_LServer, riak, #vcard_search{}) -> + ok; import(_, _, _) -> pass. diff --git a/src/mod_vcard_xupdate.erl b/src/mod_vcard_xupdate.erl index b2ea34419..97d9abbb4 100644 --- a/src/mod_vcard_xupdate.erl +++ b/src/mod_vcard_xupdate.erl @@ -88,6 +88,10 @@ add_xupdate(LUser, LServer, Hash, mnesia) -> hash = Hash}) end, mnesia:transaction(F); +add_xupdate(LUser, LServer, Hash, riak) -> + {atomic, ejabberd_riak:put(#vcard_xupdate{us = {LUser, LServer}, + hash = Hash}, + vcard_xupdate_schema())}; add_xupdate(LUser, LServer, Hash, odbc) -> Username = ejabberd_odbc:escape(LUser), SHash = ejabberd_odbc:escape(Hash), @@ -109,6 +113,12 @@ get_xupdate(LUser, LServer, mnesia) -> [#vcard_xupdate{hash = Hash}] -> Hash; _ -> undefined end; +get_xupdate(LUser, LServer, riak) -> + case ejabberd_riak:get(vcard_xupdate, vcard_xupdate_schema(), + {LUser, LServer}) of + {ok, #vcard_xupdate{hash = Hash}} -> Hash; + _ -> undefined + end; get_xupdate(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), case ejabberd_odbc:sql_query(LServer, @@ -129,6 +139,8 @@ remove_xupdate(LUser, LServer, mnesia) -> mnesia:delete({vcard_xupdate, {LUser, LServer}}) end, mnesia:transaction(F); +remove_xupdate(LUser, LServer, riak) -> + {atomic, ejabberd_riak:delete(vcard_xupdate, {LUser, LServer})}; remove_xupdate(LUser, LServer, odbc) -> Username = ejabberd_odbc:escape(LUser), F = fun () -> @@ -172,6 +184,9 @@ build_xphotoel(User, Host) -> attrs = [{<<"xmlns">>, ?NS_VCARD_UPDATE}], children = PhotoEl}. +vcard_xupdate_schema() -> + {record_info(fields, vcard_xupdate), #vcard_xupdate{}}. + update_table() -> Fields = record_info(fields, vcard_xupdate), case mnesia:table_info(vcard_xupdate, attributes) of @@ -212,5 +227,7 @@ import(LServer) -> import(_LServer, mnesia, #vcard_xupdate{} = R) -> mnesia:dirty_write(R); +import(_LServer, riak, #vcard_xupdate{} = R) -> + ejabberd_riak:put(R, vcard_xupdate_schema()); import(_, _, _) -> pass. diff --git a/src/node_hometree_odbc.erl b/src/node_hometree_odbc.erl index dfb9886f2..9a4a3b2e7 100644 --- a/src/node_hometree_odbc.erl +++ b/src/node_hometree_odbc.erl @@ -1317,6 +1317,7 @@ get_items(NodeId, _From, first = <<"modification@", F/binary>>, last = <<"modification@", (jlib:i2l(L))/binary>>}, {result, {[raw_to_item(NodeId, RItem) || RItem <- RItems], RsmOut}}; + [] -> {result, {[], #rsm_out{count = Count}}}; 0 -> {result, {[], #rsm_out{count = Count}}} end; _ -> {result, {[], none}} diff --git a/src/odbc_queries.erl b/src/odbc_queries.erl index e0637f840..09549c0a2 100644 --- a/src/odbc_queries.erl +++ b/src/odbc_queries.erl @@ -97,10 +97,14 @@ update_t(Table, Fields, Vals, Where) -> of {updated, 1} -> ok; _ -> - ejabberd_odbc:sql_query_t([<<"insert into ">>, Table, + Res = ejabberd_odbc:sql_query_t([<<"insert into ">>, Table, <<"(">>, join(Fields, <<", ">>), <<") values ('">>, join(Vals, <<"', '">>), - <<"');">>]) + <<"');">>]), + case Res of + {updated,1} -> ok; + _ -> Res + end end. update(LServer, Table, Fields, Vals, Where) -> @@ -115,10 +119,14 @@ update(LServer, Table, Fields, Vals, Where) -> of {updated, 1} -> ok; _ -> - ejabberd_odbc:sql_query(LServer, + Res = ejabberd_odbc:sql_query(LServer, [<<"insert into ">>, Table, <<"(">>, join(Fields, <<", ">>), <<") values ('">>, - join(Vals, <<"', '">>), <<"');">>]) + join(Vals, <<"', '">>), <<"');">>]), + case Res of + {updated,1} -> ok; + _ -> Res + end end. %% F can be either a fun or a list of queries diff --git a/src/p1_fsm.erl b/src/p1_fsm.erl deleted file mode 100644 index 80f08c609..000000000 --- a/src/p1_fsm.erl +++ /dev/null @@ -1,848 +0,0 @@ -%% ``The contents of this file are subject to the Erlang Public License, -%% Version 1.1, (the "License"); you may not use this file except in -%% compliance with the License. You should have received a copy of the -%% Erlang Public License along with this software. If not, it can be -%% retrieved via the world wide web at http://www.erlang.org/. -%% -%% Software distributed under the License is distributed on an "AS IS" -%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See -%% the License for the specific language governing rights and limitations -%% under the License. -%% -%% The Initial Developer of the Original Code is Ericsson Utvecklings AB. -%% Portions created by Ericsson are Copyright 1999, Ericsson Utvecklings -%% AB. All Rights Reserved.'' -%% -%% The code has been modified and improved by ProcessOne. -%% Copyright 2007-2014, ProcessOne -%% -%% The change adds the following features: -%% - You can send exit(priority_shutdown) to the p1_fsm process to -%% terminate immediatetly. If the fsm trap_exit process flag has been -%% set to true, the FSM terminate function will called. -%% - You can pass the gen_fsm options to control resource usage. -%% {max_queue, N} will exit the process with priority_shutdown -%% - You can limit the time processing a message (TODO): If the -%% message processing does not return in a given period of time, the -%% process will be terminated. -%% - You might customize the State data before sending it to error_logger -%% in case of a crash (just export the function print_state/1) -%% $Id$ -%% --module(p1_fsm). - -%%%----------------------------------------------------------------- -%%% -%%% This state machine is somewhat more pure than state_lib. It is -%%% still based on State dispatching (one function per state), but -%%% allows a function handle_event to take care of events in all states. -%%% It's not that pure anymore :( We also allow synchronized event sending. -%%% -%%% If the Parent process terminates the Module:terminate/2 -%%% function is called. -%%% -%%% The user module should export: -%%% -%%% init(Args) -%%% ==> {ok, StateName, StateData} -%%% {ok, StateName, StateData, Timeout} -%%% ignore -%%% {stop, Reason} -%%% -%%% StateName(Msg, StateData) -%%% -%%% ==> {next_state, NewStateName, NewStateData} -%%% {next_state, NewStateName, NewStateData, Timeout} -%%% {stop, Reason, NewStateData} -%%% Reason = normal | shutdown | Term terminate(State) is called -%%% -%%% StateName(Msg, From, StateData) -%%% -%%% ==> {next_state, NewStateName, NewStateData} -%%% {next_state, NewStateName, NewStateData, Timeout} -%%% {reply, Reply, NewStateName, NewStateData} -%%% {reply, Reply, NewStateName, NewStateData, Timeout} -%%% {stop, Reason, NewStateData} -%%% Reason = normal | shutdown | Term terminate(State) is called -%%% -%%% handle_event(Msg, StateName, StateData) -%%% -%%% ==> {next_state, NewStateName, NewStateData} -%%% {next_state, NewStateName, NewStateData, Timeout} -%%% {stop, Reason, Reply, NewStateData} -%%% {stop, Reason, NewStateData} -%%% Reason = normal | shutdown | Term terminate(State) is called -%%% -%%% handle_sync_event(Msg, From, StateName, StateData) -%%% -%%% ==> {next_state, NewStateName, NewStateData} -%%% {next_state, NewStateName, NewStateData, Timeout} -%%% {reply, Reply, NewStateName, NewStateData} -%%% {reply, Reply, NewStateName, NewStateData, Timeout} -%%% {stop, Reason, Reply, NewStateData} -%%% {stop, Reason, NewStateData} -%%% Reason = normal | shutdown | Term terminate(State) is called -%%% -%%% handle_info(Info, StateName) (e.g. {'EXIT', P, R}, {nodedown, N}, ... -%%% -%%% ==> {next_state, NewStateName, NewStateData} -%%% {next_state, NewStateName, NewStateData, Timeout} -%%% {stop, Reason, NewStateData} -%%% Reason = normal | shutdown | Term terminate(State) is called -%%% -%%% terminate(Reason, StateName, StateData) Let the user module clean up -%%% always called when server terminates -%%% -%%% ==> the return value is ignored -%%% -%%% -%%% The work flow (of the fsm) can be described as follows: -%%% -%%% User module fsm -%%% ----------- ------- -%%% start -----> start -%%% init <----- . -%%% -%%% loop -%%% StateName <----- . -%%% -%%% handle_event <----- . -%%% -%%% handle__sunc_event <----- . -%%% -%%% handle_info <----- . -%%% -%%% terminate <----- . -%%% -%%% -%%% --------------------------------------------------- - --export([start/3, start/4, - start_link/3, start_link/4, - send_event/2, sync_send_event/2, sync_send_event/3, - send_all_state_event/2, - sync_send_all_state_event/2, sync_send_all_state_event/3, - reply/2, - start_timer/2,send_event_after/2,cancel_timer/1, - enter_loop/4, enter_loop/5, enter_loop/6, wake_hib/7]). - -%% Internal exports --export([init_it/6, print_event/3, - system_continue/3, - system_terminate/4, - system_code_change/4, - format_status/2]). - --import(error_logger , [format/2]). - -%%% Internal gen_fsm state -%%% This state is used to defined resource control values: --record(limits, {max_queue :: non_neg_integer()}). - -%%% --------------------------------------------------- -%%% Interface functions. -%%% --------------------------------------------------- - --callback init(Args :: term()) -> - {ok, StateName :: atom(), StateData :: term()} | - {ok, StateName :: atom(), StateData :: term(), timeout() | hibernate} | - {stop, Reason :: term()} | ignore. --callback handle_event(Event :: term(), StateName :: atom(), - StateData :: term()) -> - {next_state, NextStateName :: atom(), NewStateData :: term()} | - {next_state, NextStateName :: atom(), NewStateData :: term(), - timeout() | hibernate} | - {migrate, NewStateData :: term(), - {Node :: atom(), M :: atom(), F :: atom(), A :: list()}, - Timeout :: timeout()} | - {stop, Reason :: term(), NewStateData :: term()}. --callback handle_sync_event(Event :: term(), From :: {pid(), Tag :: term()}, - StateName :: atom(), StateData :: term()) -> - {reply, Reply :: term(), NextStateName :: atom(), NewStateData :: term()} | - {reply, Reply :: term(), NextStateName :: atom(), NewStateData :: term(), - timeout() | hibernate} | - {next_state, NextStateName :: atom(), NewStateData :: term()} | - {next_state, NextStateName :: atom(), NewStateData :: term(), - timeout() | hibernate} | - {migrate, NewStateData :: term(), - {Node :: atom(), M :: atom(), F :: atom(), A :: list()}, - Timeout :: timeout()} | - {stop, Reason :: term(), Reply :: term(), NewStateData :: term()} | - {stop, Reason :: term(), NewStateData :: term()}. --callback handle_info(Info :: term(), StateName :: atom(), - StateData :: term()) -> - {next_state, NextStateName :: atom(), NewStateData :: term()} | - {next_state, NextStateName :: atom(), NewStateData :: term(), - timeout() | hibernate} | - {migrate, NewStateData :: term(), - {Node :: atom(), M :: atom(), F :: atom(), A :: list()}, - Timeout :: timeout()} | - {stop, Reason :: normal | term(), NewStateData :: term()}. --callback terminate(Reason :: normal | shutdown | {shutdown, term()} - | term(), StateName :: atom(), StateData :: term()) -> - term(). --callback code_change(OldVsn :: term() | {down, term()}, StateName :: atom(), - StateData :: term(), Extra :: term()) -> - {ok, NextStateName :: atom(), NewStateData :: term()}. - -%%% --------------------------------------------------- -%%% Starts a generic state machine. -%%% start(Mod, Args, Options) -%%% start(Name, Mod, Args, Options) -%%% start_link(Mod, Args, Options) -%%% start_link(Name, Mod, Args, Options) where: -%%% Name ::= {local, atom()} | {global, atom()} -%%% Mod ::= atom(), callback module implementing the 'real' fsm -%%% Args ::= term(), init arguments (to Mod:init/1) -%%% Options ::= [{debug, [Flag]}] -%%% Flag ::= trace | log | {logfile, File} | statistics | debug -%%% (debug == log && statistics) -%%% Returns: {ok, Pid} | -%%% {error, {already_started, Pid}} | -%%% {error, Reason} -%%% --------------------------------------------------- -start(Mod, Args, Options) -> - gen:start(?MODULE, nolink, Mod, Args, Options). - -start(Name, Mod, Args, Options) -> - gen:start(?MODULE, nolink, Name, Mod, Args, Options). - -start_link(Mod, Args, Options) -> - gen:start(?MODULE, link, Mod, Args, Options). - -start_link(Name, Mod, Args, Options) -> - gen:start(?MODULE, link, Name, Mod, Args, Options). - - -send_event({global, Name}, Event) -> - catch global:send(Name, {'$gen_event', Event}), - ok; -send_event(Name, Event) -> - Name ! {'$gen_event', Event}, - ok. - -sync_send_event(Name, Event) -> - case catch gen:call(Name, '$gen_sync_event', Event) of - {ok,Res} -> - Res; - {'EXIT',Reason} -> - exit({Reason, {?MODULE, sync_send_event, [Name, Event]}}) - end. - -sync_send_event(Name, Event, Timeout) -> - case catch gen:call(Name, '$gen_sync_event', Event, Timeout) of - {ok,Res} -> - Res; - {'EXIT',Reason} -> - exit({Reason, {?MODULE, sync_send_event, [Name, Event, Timeout]}}) - end. - -send_all_state_event({global, Name}, Event) -> - catch global:send(Name, {'$gen_all_state_event', Event}), - ok; -send_all_state_event(Name, Event) -> - Name ! {'$gen_all_state_event', Event}, - ok. - -sync_send_all_state_event(Name, Event) -> - case catch gen:call(Name, '$gen_sync_all_state_event', Event) of - {ok,Res} -> - Res; - {'EXIT',Reason} -> - exit({Reason, {?MODULE, sync_send_all_state_event, [Name, Event]}}) - end. - -sync_send_all_state_event(Name, Event, Timeout) -> - case catch gen:call(Name, '$gen_sync_all_state_event', Event, Timeout) of - {ok,Res} -> - Res; - {'EXIT',Reason} -> - exit({Reason, {?MODULE, sync_send_all_state_event, - [Name, Event, Timeout]}}) - end. - -%% Designed to be only callable within one of the callbacks -%% hence using the self() of this instance of the process. -%% This is to ensure that timers don't go astray in global -%% e.g. when straddling a failover, or turn up in a restarted -%% instance of the process. - -%% Returns Ref, sends event {timeout,Ref,Msg} after Time -%% to the (then) current state. -start_timer(Time, Msg) -> - erlang:start_timer(Time, self(), {'$gen_timer', Msg}). - -%% Returns Ref, sends Event after Time to the (then) current state. -send_event_after(Time, Event) -> - erlang:start_timer(Time, self(), {'$gen_event', Event}). - -%% Returns the remaing time for the timer if Ref referred to -%% an active timer/send_event_after, false otherwise. -cancel_timer(Ref) -> - case erlang:cancel_timer(Ref) of - false -> - receive {timeout, Ref, _} -> 0 - after 0 -> false - end; - RemainingTime -> - RemainingTime - end. - -%% enter_loop/4,5,6 -%% Makes an existing process into a gen_fsm. -%% The calling process will enter the gen_fsm receive loop and become a -%% gen_fsm process. -%% The process *must* have been started using one of the start functions -%% in proc_lib, see proc_lib(3). -%% The user is responsible for any initialization of the process, -%% including registering a name for it. -enter_loop(Mod, Options, StateName, StateData) -> - enter_loop(Mod, Options, StateName, StateData, self(), infinity). - -enter_loop(Mod, Options, StateName, StateData, ServerName = {_,_}) -> - enter_loop(Mod, Options, StateName, StateData, ServerName,infinity); -enter_loop(Mod, Options, StateName, StateData, Timeout) -> - enter_loop(Mod, Options, StateName, StateData, self(), Timeout). - -enter_loop(Mod, Options, StateName, StateData, ServerName, Timeout) -> - Name = get_proc_name(ServerName), - Parent = get_parent(), - Debug = gen:debug_options(Options), - Limits = limit_options(Options), - Queue = queue:new(), - QueueLen = 0, - loop(Parent, Name, StateName, StateData, Mod, Timeout, Debug, - Limits, Queue, QueueLen). - -get_proc_name(Pid) when is_pid(Pid) -> - Pid; -get_proc_name({local, Name}) -> - case process_info(self(), registered_name) of - {registered_name, Name} -> - Name; - {registered_name, _Name} -> - exit(process_not_registered); - [] -> - exit(process_not_registered) - end; -get_proc_name({global, Name}) -> - case global:whereis_name(Name) of - undefined -> - exit(process_not_registered_globally); - Pid when Pid==self() -> - Name; - _Pid -> - exit(process_not_registered_globally) - end. - -get_parent() -> - case get('$ancestors') of - [Parent | _] when is_pid(Parent) -> - Parent; - [Parent | _] when is_atom(Parent) -> - name_to_pid(Parent); - _ -> - exit(process_was_not_started_by_proc_lib) - end. - -name_to_pid(Name) -> - case whereis(Name) of - undefined -> - case global:whereis_name(Name) of - undefined -> - exit(could_not_find_registerd_name); - Pid -> - Pid - end; - Pid -> - Pid - end. - -%%% --------------------------------------------------- -%%% Initiate the new process. -%%% Register the name using the Rfunc function -%%% Calls the Mod:init/Args function. -%%% Finally an acknowledge is sent to Parent and the main -%%% loop is entered. -%%% --------------------------------------------------- -init_it(Starter, self, Name, Mod, Args, Options) -> - init_it(Starter, self(), Name, Mod, Args, Options); -init_it(Starter, Parent, Name0, Mod, Args, Options) -> - Name = name(Name0), - Debug = gen:debug_options(Options), - Limits = limit_options(Options), - Queue = queue:new(), - QueueLen = 0, - case catch Mod:init(Args) of - {ok, StateName, StateData} -> - proc_lib:init_ack(Starter, {ok, self()}), - loop(Parent, Name, StateName, StateData, Mod, infinity, Debug, Limits, Queue, QueueLen); - {ok, StateName, StateData, Timeout} -> - proc_lib:init_ack(Starter, {ok, self()}), - loop(Parent, Name, StateName, StateData, Mod, Timeout, Debug, Limits, Queue, QueueLen); - {stop, Reason} -> - proc_lib:init_ack(Starter, {error, Reason}), - exit(Reason); - ignore -> - proc_lib:init_ack(Starter, ignore), - exit(normal); - {'EXIT', Reason} -> - proc_lib:init_ack(Starter, {error, Reason}), - exit(Reason); - Else -> - Error = {bad_return_value, Else}, - proc_lib:init_ack(Starter, {error, Error}), - exit(Error) - end. - -name({local,Name}) -> Name; -name({global,Name}) -> Name; -name(Pid) when is_pid(Pid) -> Pid. - -%%----------------------------------------------------------------- -%% The MAIN loop -%%----------------------------------------------------------------- -loop(Parent, Name, StateName, StateData, Mod, hibernate, Debug, - Limits, Queue, QueueLen) - when QueueLen > 0 -> - case queue:out(Queue) of - {{value, Msg}, Queue1} -> - decode_msg(Msg, Parent, Name, StateName, StateData, Mod, hibernate, - Debug, Limits, Queue1, QueueLen - 1, false); - {empty, _} -> - Reason = internal_queue_error, - error_info(Mod, Reason, Name, hibernate, StateName, StateData, Debug), - exit(Reason) - end; -loop(Parent, Name, StateName, StateData, Mod, hibernate, Debug, - Limits, _Queue, _QueueLen) -> - proc_lib:hibernate(?MODULE,wake_hib, - [Parent, Name, StateName, StateData, Mod, - Debug, Limits]); -%% First we test if we have reach a defined limit ... -loop(Parent, Name, StateName, StateData, Mod, Time, Debug, - Limits, Queue, QueueLen) -> - try - message_queue_len(Limits, QueueLen) - %% TODO: We can add more limit checking here... - catch - {process_limit, Limit} -> - Reason = {process_limit, Limit}, - Msg = {'EXIT', Parent, {error, {process_limit, Limit}}}, - terminate(Reason, Name, Msg, Mod, StateName, StateData, Debug) - end, - process_message(Parent, Name, StateName, StateData, - Mod, Time, Debug, Limits, Queue, QueueLen). -%% ... then we can process a new message: -process_message(Parent, Name, StateName, StateData, Mod, Time, Debug, - Limits, Queue, QueueLen) -> - {Msg, Queue1, QueueLen1} = collect_messages(Queue, QueueLen, Time), - decode_msg(Msg,Parent, Name, StateName, StateData, Mod, Time, - Debug, Limits, Queue1, QueueLen1, false). - -collect_messages(Queue, QueueLen, Time) -> - receive - Input -> - case Input of - {'EXIT', _Parent, priority_shutdown} -> - {Input, Queue, QueueLen}; - _ -> - collect_messages( - queue:in(Input, Queue), QueueLen + 1, Time) - end - after 0 -> - case queue:out(Queue) of - {{value, Msg}, Queue1} -> - {Msg, Queue1, QueueLen - 1}; - {empty, _} -> - receive - Input -> - {Input, Queue, QueueLen} - after Time -> - {{'$gen_event', timeout}, Queue, QueueLen} - end - end - end. - - -wake_hib(Parent, Name, StateName, StateData, Mod, Debug, - Limits) -> - Msg = receive - Input -> - Input - end, - Queue = queue:new(), - QueueLen = 0, - decode_msg(Msg, Parent, Name, StateName, StateData, Mod, hibernate, - Debug, Limits, Queue, QueueLen, true). - -decode_msg(Msg,Parent, Name, StateName, StateData, Mod, Time, Debug, - Limits, Queue, QueueLen, Hib) -> - put('$internal_queue_len', QueueLen), - case Msg of - {system, From, Req} -> - sys:handle_system_msg(Req, From, Parent, ?MODULE, Debug, - [Name, StateName, StateData, - Mod, Time, Limits, Queue, QueueLen], Hib); - {'EXIT', Parent, Reason} -> - terminate(Reason, Name, Msg, Mod, StateName, StateData, Debug); - _Msg when Debug == [] -> - handle_msg(Msg, Parent, Name, StateName, StateData, - Mod, Time, Limits, Queue, QueueLen); - _Msg -> - Debug1 = sys:handle_debug(Debug, fun print_event/3, - {Name, StateName}, {in, Msg}), - handle_msg(Msg, Parent, Name, StateName, StateData, - Mod, Time, Debug1, Limits, Queue, QueueLen) - end. - -%%----------------------------------------------------------------- -%% Callback functions for system messages handling. -%%----------------------------------------------------------------- -system_continue(Parent, Debug, [Name, StateName, StateData, - Mod, Time, Limits, Queue, QueueLen]) -> - loop(Parent, Name, StateName, StateData, Mod, Time, Debug, - Limits, Queue, QueueLen). - --spec system_terminate(term(), _, _, [term(),...]) -> no_return(). - -system_terminate(Reason, _Parent, Debug, - [Name, StateName, StateData, Mod, _Time, _Limits]) -> - terminate(Reason, Name, [], Mod, StateName, StateData, Debug). - -system_code_change([Name, StateName, StateData, Mod, Time, - Limits, Queue, QueueLen], - _Module, OldVsn, Extra) -> - case catch Mod:code_change(OldVsn, StateName, StateData, Extra) of - {ok, NewStateName, NewStateData} -> - {ok, [Name, NewStateName, NewStateData, Mod, Time, - Limits, Queue, QueueLen]}; - Else -> Else - end. - -%%----------------------------------------------------------------- -%% Format debug messages. Print them as the call-back module sees -%% them, not as the real erlang messages. Use trace for that. -%%----------------------------------------------------------------- -print_event(Dev, {in, Msg}, {Name, StateName}) -> - case Msg of - {'$gen_event', Event} -> - io:format(Dev, "*DBG* ~p got event ~p in state ~w~n", - [Name, Event, StateName]); - {'$gen_all_state_event', Event} -> - io:format(Dev, - "*DBG* ~p got all_state_event ~p in state ~w~n", - [Name, Event, StateName]); - {timeout, Ref, {'$gen_timer', Message}} -> - io:format(Dev, - "*DBG* ~p got timer ~p in state ~w~n", - [Name, {timeout, Ref, Message}, StateName]); - {timeout, _Ref, {'$gen_event', Event}} -> - io:format(Dev, - "*DBG* ~p got timer ~p in state ~w~n", - [Name, Event, StateName]); - _ -> - io:format(Dev, "*DBG* ~p got ~p in state ~w~n", - [Name, Msg, StateName]) - end; -print_event(Dev, {out, Msg, To, StateName}, Name) -> - io:format(Dev, "*DBG* ~p sent ~p to ~w~n" - " and switched to state ~w~n", - [Name, Msg, To, StateName]); -print_event(Dev, return, {Name, StateName}) -> - io:format(Dev, "*DBG* ~p switched to state ~w~n", - [Name, StateName]). - -relay_messages(MRef, TRef, Clone, Queue) -> - lists:foreach( - fun(Msg) -> Clone ! Msg end, - queue:to_list(Queue)), - relay_messages(MRef, TRef, Clone). - -relay_messages(MRef, TRef, Clone) -> - receive - {'DOWN', MRef, process, Clone, Reason} -> - Reason; - {'EXIT', _Parent, _Reason} -> - {migrated, Clone}; - {timeout, TRef, timeout} -> - {migrated, Clone}; - Msg -> - Clone ! Msg, - relay_messages(MRef, TRef, Clone) - end. - -handle_msg(Msg, Parent, Name, StateName, StateData, Mod, _Time, - Limits, Queue, QueueLen) -> %No debug here - From = from(Msg), - case catch dispatch(Msg, Mod, StateName, StateData) of - {next_state, NStateName, NStateData} -> - loop(Parent, Name, NStateName, NStateData, - Mod, infinity, [], Limits, Queue, QueueLen); - {next_state, NStateName, NStateData, Time1} -> - loop(Parent, Name, NStateName, NStateData, Mod, Time1, [], - Limits, Queue, QueueLen); - {reply, Reply, NStateName, NStateData} when From =/= undefined -> - reply(From, Reply), - loop(Parent, Name, NStateName, NStateData, - Mod, infinity, [], Limits, Queue, QueueLen); - {reply, Reply, NStateName, NStateData, Time1} when From =/= undefined -> - reply(From, Reply), - loop(Parent, Name, NStateName, NStateData, Mod, Time1, [], - Limits, Queue, QueueLen); - {migrate, NStateData, {Node, M, F, A}, Time1} -> - Reason = case catch rpc:call(Node, M, F, A, 5000) of - {badrpc, _} = Err -> - {migration_error, Err}; - {'EXIT', _} = Err -> - {migration_error, Err}; - {error, _} = Err -> - {migration_error, Err}; - {ok, Clone} -> - process_flag(trap_exit, true), - MRef = erlang:monitor(process, Clone), - TRef = erlang:start_timer(Time1, self(), timeout), - relay_messages(MRef, TRef, Clone, Queue); - Reply -> - {migration_error, {bad_reply, Reply}} - end, - terminate(Reason, Name, Msg, Mod, StateName, NStateData, []); - {stop, Reason, NStateData} -> - terminate(Reason, Name, Msg, Mod, StateName, NStateData, []); - {stop, Reason, Reply, NStateData} when From =/= undefined -> - {'EXIT', R} = (catch terminate(Reason, Name, Msg, Mod, - StateName, NStateData, [])), - reply(From, Reply), - exit(R); - {'EXIT', What} -> - terminate(What, Name, Msg, Mod, StateName, StateData, []); - Reply -> - terminate({bad_return_value, Reply}, - Name, Msg, Mod, StateName, StateData, []) - end. - -handle_msg(Msg, Parent, Name, StateName, StateData, - Mod, _Time, Debug, Limits, Queue, QueueLen) -> - From = from(Msg), - case catch dispatch(Msg, Mod, StateName, StateData) of - {next_state, NStateName, NStateData} -> - Debug1 = sys:handle_debug(Debug, fun print_event/3, - {Name, NStateName}, return), - loop(Parent, Name, NStateName, NStateData, - Mod, infinity, Debug1, Limits, Queue, QueueLen); - {next_state, NStateName, NStateData, Time1} -> - Debug1 = sys:handle_debug(Debug, fun print_event/3, - {Name, NStateName}, return), - loop(Parent, Name, NStateName, NStateData, - Mod, Time1, Debug1, Limits, Queue, QueueLen); - {reply, Reply, NStateName, NStateData} when From =/= undefined -> - Debug1 = reply(Name, From, Reply, Debug, NStateName), - loop(Parent, Name, NStateName, NStateData, - Mod, infinity, Debug1, Limits, Queue, QueueLen); - {reply, Reply, NStateName, NStateData, Time1} when From =/= undefined -> - Debug1 = reply(Name, From, Reply, Debug, NStateName), - loop(Parent, Name, NStateName, NStateData, - Mod, Time1, Debug1, Limits, Queue, QueueLen); - {migrate, NStateData, {Node, M, F, A}, Time1} -> - Reason = case catch rpc:call(Node, M, F, A, Time1) of - {badrpc, R} -> - {migration_error, R}; - {'EXIT', R} -> - {migration_error, R}; - {error, R} -> - {migration_error, R}; - {ok, Clone} -> - process_flag(trap_exit, true), - MRef = erlang:monitor(process, Clone), - TRef = erlang:start_timer(Time1, self(), timeout), - relay_messages(MRef, TRef, Clone, Queue); - Reply -> - {migration_error, {bad_reply, Reply}} - end, - terminate(Reason, Name, Msg, Mod, StateName, NStateData, Debug); - {stop, Reason, NStateData} -> - terminate(Reason, Name, Msg, Mod, StateName, NStateData, Debug); - {stop, Reason, Reply, NStateData} when From =/= undefined -> - {'EXIT', R} = (catch terminate(Reason, Name, Msg, Mod, - StateName, NStateData, Debug)), - reply(Name, From, Reply, Debug, StateName), - exit(R); - {'EXIT', What} -> - terminate(What, Name, Msg, Mod, StateName, StateData, Debug); - Reply -> - terminate({bad_return_value, Reply}, - Name, Msg, Mod, StateName, StateData, Debug) - end. - -dispatch({'$gen_event', Event}, Mod, StateName, StateData) -> - Mod:StateName(Event, StateData); -dispatch({'$gen_all_state_event', Event}, Mod, StateName, StateData) -> - Mod:handle_event(Event, StateName, StateData); -dispatch({'$gen_sync_event', From, Event}, Mod, StateName, StateData) -> - Mod:StateName(Event, From, StateData); -dispatch({'$gen_sync_all_state_event', From, Event}, - Mod, StateName, StateData) -> - Mod:handle_sync_event(Event, From, StateName, StateData); -dispatch({timeout, Ref, {'$gen_timer', Msg}}, Mod, StateName, StateData) -> - Mod:StateName({timeout, Ref, Msg}, StateData); -dispatch({timeout, _Ref, {'$gen_event', Event}}, Mod, StateName, StateData) -> - Mod:StateName(Event, StateData); -dispatch(Info, Mod, StateName, StateData) -> - Mod:handle_info(Info, StateName, StateData). - -from({'$gen_sync_event', From, _Event}) -> From; -from({'$gen_sync_all_state_event', From, _Event}) -> From; -from(_) -> undefined. - -%% Send a reply to the client. -reply({To, Tag}, Reply) -> - catch To ! {Tag, Reply}. - -reply(Name, {To, Tag}, Reply, Debug, StateName) -> - reply({To, Tag}, Reply), - sys:handle_debug(Debug, fun print_event/3, Name, - {out, Reply, To, StateName}). - -%%% --------------------------------------------------- -%%% Terminate the server. -%%% --------------------------------------------------- - --spec terminate(term(), _, _, atom(), _, _, _) -> no_return(). - -terminate(Reason, Name, Msg, Mod, StateName, StateData, Debug) -> - case catch Mod:terminate(Reason, StateName, StateData) of - {'EXIT', R} -> - error_info(Mod, R, Name, Msg, StateName, StateData, Debug), - exit(R); - _ -> - case Reason of - normal -> - exit(normal); - shutdown -> - exit(shutdown); - priority_shutdown -> - %% Priority shutdown should be considered as - %% shutdown by SASL - exit(shutdown); - {process_limit, _Limit} -> - exit(Reason); - {migrated, _Clone} -> - exit(normal); - _ -> - error_info(Mod, Reason, Name, Msg, StateName, StateData, Debug), - exit(Reason) - end - end. - -error_info(Mod, Reason, Name, Msg, StateName, StateData, Debug) -> - Reason1 = - case Reason of - {undef,[{M,F,A}|MFAs]} -> - case code:is_loaded(M) of - false -> - {'module could not be loaded',[{M,F,A}|MFAs]}; - _ -> - case erlang:function_exported(M, F, length(A)) of - true -> - Reason; - false -> - {'function not exported',[{M,F,A}|MFAs]} - end - end; - _ -> - Reason - end, - StateToPrint = case erlang:function_exported(Mod, print_state, 1) of - true -> (catch Mod:print_state(StateData)); - false -> StateData - end, - Str = "** State machine ~p terminating \n" ++ - get_msg_str(Msg) ++ - "** When State == ~p~n" - "** Data == ~p~n" - "** Reason for termination = ~n** ~p~n", - format(Str, [Name, get_msg(Msg), StateName, StateToPrint, Reason1]), - sys:print_log(Debug), - ok. - -get_msg_str({'$gen_event', _Event}) -> - "** Last event in was ~p~n"; -get_msg_str({'$gen_sync_event', _Event}) -> - "** Last sync event in was ~p~n"; -get_msg_str({'$gen_all_state_event', _Event}) -> - "** Last event in was ~p (for all states)~n"; -get_msg_str({'$gen_sync_all_state_event', _Event}) -> - "** Last sync event in was ~p (for all states)~n"; -get_msg_str({timeout, _Ref, {'$gen_timer', _Msg}}) -> - "** Last timer event in was ~p~n"; -get_msg_str({timeout, _Ref, {'$gen_event', _Msg}}) -> - "** Last timer event in was ~p~n"; -get_msg_str(_Msg) -> - "** Last message in was ~p~n". - -get_msg({'$gen_event', Event}) -> Event; -get_msg({'$gen_sync_event', Event}) -> Event; -get_msg({'$gen_all_state_event', Event}) -> Event; -get_msg({'$gen_sync_all_state_event', Event}) -> Event; -get_msg({timeout, Ref, {'$gen_timer', Msg}}) -> {timeout, Ref, Msg}; -get_msg({timeout, _Ref, {'$gen_event', Event}}) -> Event; -get_msg(Msg) -> Msg. - -%%----------------------------------------------------------------- -%% Status information -%%----------------------------------------------------------------- -format_status(Opt, StatusData) -> - [PDict, SysState, Parent, Debug, [Name, StateName, StateData, Mod, _Time, _Limits, _Queue, _QueueLen]] = - StatusData, - NameTag = if is_pid(Name) -> - pid_to_list(Name); - is_atom(Name) -> - Name - end, - Header = lists:concat(["Status for state machine ", NameTag]), - Log = sys:get_debug(log, Debug, []), - Specfic = - case erlang:function_exported(Mod, format_status, 2) of - true -> - case catch Mod:format_status(Opt,[PDict,StateData]) of - {'EXIT', _} -> [{data, [{"StateData", StateData}]}]; - Else -> Else - end; - _ -> - [{data, [{"StateData", StateData}]}] - end, - [{header, Header}, - {data, [{"Status", SysState}, - {"Parent", Parent}, - {"Logged events", Log}, - {"StateName", StateName}]} | - Specfic]. - -%%----------------------------------------------------------------- -%% Resources limit management -%%----------------------------------------------------------------- -%% Extract know limit options -limit_options(Options) -> - limit_options(Options, #limits{}). -limit_options([], Limits) -> - Limits; -%% Maximum number of messages allowed in the process message queue -limit_options([{max_queue,N}|Options], Limits) - when is_integer(N) -> - NewLimits = Limits#limits{max_queue=N}, - limit_options(Options, NewLimits); -limit_options([_|Options], Limits) -> - limit_options(Options, Limits). - -%% Throw max_queue if we have reach the max queue size -%% Returns ok otherwise -message_queue_len(#limits{max_queue = undefined}, _QueueLen) -> - ok; -message_queue_len(#limits{max_queue = MaxQueue}, QueueLen) -> - Pid = self(), - case process_info(Pid, message_queue_len) of - {message_queue_len, N} when N + QueueLen > MaxQueue -> - throw({process_limit, {max_queue, N + QueueLen}}); - _ -> - ok - end. diff --git a/src/p1_mnesia.erl b/src/p1_mnesia.erl deleted file mode 100644 index b792472a6..000000000 --- a/src/p1_mnesia.erl +++ /dev/null @@ -1,49 +0,0 @@ -%%% ==================================================================== -%%% ``The contents of this file are subject to the Erlang Public License, -%%% Version 1.1, (the "License"); you may not use this file except in -%%% compliance with the License. You should have received a copy of the -%%% Erlang Public License along with this software. If not, it can be -%%% retrieved via the world wide web at http://www.erlang.org/. -%%% -%%% -%%% Software distributed under the License is distributed on an "AS IS" -%%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See -%%% the License for the specific language governing rights and limitations -%%% under the License. -%%% -%%% -%%% The Initial Developer of the Original Code is ProcessOne. -%%% Portions created by ProcessOne are Copyright 2006-2014, ProcessOne -%%% All Rights Reserved.'' -%%% -%%% This software is copyright 2006-2014, ProcessOne. - --module(p1_mnesia). - --author('mickael.remond@process-one.net'). - --export([count_records/2]). - -%% Return the number of records matching a given match expression. -%% This function is intended to be used inside a Mnesia transaction. -%% The count has been written to use the fewest possible memory by -%% getting the record by small increment and by using continuation. --define(BATCHSIZE, 100). - -count_records(Tab, MatchExpression) -> - case mnesia:select(Tab, [{MatchExpression, [], [[]]}], - ?BATCHSIZE, read) - of - {Result, Cont} -> - Count = length(Result), - count_records_cont(Cont, Count); - '$end_of_table' -> 0 - end. - -count_records_cont(Cont, Count) -> - case mnesia:select(Cont) of - {Result, Cont} -> - NewCount = Count + length(Result), - count_records_cont(Cont, NewCount); - '$end_of_table' -> Count - end. diff --git a/src/treap.erl b/src/treap.erl deleted file mode 100644 index 9d1d69fc9..000000000 --- a/src/treap.erl +++ /dev/null @@ -1,166 +0,0 @@ -%%%---------------------------------------------------------------------- -%%% File : treap.erl -%%% Author : Alexey Shchepin <alexey@process-one.net> -%%% Purpose : Treaps implementation -%%% Created : 22 Apr 2008 by Alexey Shchepin <alexey@process-one.net> -%%% -%%% -%%% ejabberd, Copyright (C) 2002-2014 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - --module(treap). - --export([empty/0, insert/4, delete/2, delete_root/1, - get_root/1, lookup/2, is_empty/1, fold/3, from_list/1, - to_list/1]). - --type hashkey() :: {non_neg_integer(), any()}. - --type treap() :: {hashkey(), any(), any(), treap(), treap()} | nil. - --export_type([treap/0]). - -empty() -> nil. - -insert(Key, Priority, Value, Tree) -> - HashKey = {erlang:phash2(Key), Key}, - insert1(Tree, HashKey, Priority, Value). - -insert1(nil, HashKey, Priority, Value) -> - {HashKey, Priority, Value, nil, nil}; -insert1({HashKey1, Priority1, Value1, Left, Right} = - Tree, - HashKey, Priority, Value) -> - if HashKey < HashKey1 -> - heapify({HashKey1, Priority1, Value1, - insert1(Left, HashKey, Priority, Value), Right}); - HashKey > HashKey1 -> - heapify({HashKey1, Priority1, Value1, Left, - insert1(Right, HashKey, Priority, Value)}); - Priority == Priority1 -> - {HashKey, Priority, Value, Left, Right}; - true -> - insert1(delete_root(Tree), HashKey, Priority, Value) - end. - -heapify({_HashKey, _Priority, _Value, nil, nil} = - Tree) -> - Tree; -heapify({HashKey, Priority, Value, nil = Left, - {HashKeyR, PriorityR, ValueR, LeftR, RightR}} = - Tree) -> - if PriorityR > Priority -> - {HashKeyR, PriorityR, ValueR, - {HashKey, Priority, Value, Left, LeftR}, RightR}; - true -> Tree - end; -heapify({HashKey, Priority, Value, - {HashKeyL, PriorityL, ValueL, LeftL, RightL}, - nil = Right} = - Tree) -> - if PriorityL > Priority -> - {HashKeyL, PriorityL, ValueL, LeftL, - {HashKey, Priority, Value, RightL, Right}}; - true -> Tree - end; -heapify({HashKey, Priority, Value, - {HashKeyL, PriorityL, ValueL, LeftL, RightL} = Left, - {HashKeyR, PriorityR, ValueR, LeftR, RightR} = Right} = - Tree) -> - if PriorityR > Priority -> - {HashKeyR, PriorityR, ValueR, - {HashKey, Priority, Value, Left, LeftR}, RightR}; - PriorityL > Priority -> - {HashKeyL, PriorityL, ValueL, LeftL, - {HashKey, Priority, Value, RightL, Right}}; - true -> Tree - end. - -delete(Key, Tree) -> - HashKey = {erlang:phash2(Key), Key}, - delete1(HashKey, Tree). - -delete1(_HashKey, nil) -> nil; -delete1(HashKey, - {HashKey1, Priority1, Value1, Left, Right} = Tree) -> - if HashKey < HashKey1 -> - {HashKey1, Priority1, Value1, delete1(HashKey, Left), - Right}; - HashKey > HashKey1 -> - {HashKey1, Priority1, Value1, Left, - delete1(HashKey, Right)}; - true -> delete_root(Tree) - end. - -delete_root({HashKey, Priority, Value, Left, Right}) -> - case {Left, Right} of - {nil, nil} -> nil; - {_, nil} -> Left; - {nil, _} -> Right; - {{HashKeyL, PriorityL, ValueL, LeftL, RightL}, - {HashKeyR, PriorityR, ValueR, LeftR, RightR}} -> - if PriorityL > PriorityR -> - {HashKeyL, PriorityL, ValueL, LeftL, - delete_root({HashKey, Priority, Value, RightL, Right})}; - true -> - {HashKeyR, PriorityR, ValueR, - delete_root({HashKey, Priority, Value, Left, LeftR}), - RightR} - end - end. - -is_empty(nil) -> true; -is_empty({_HashKey, _Priority, _Value, _Left, - _Right}) -> - false. - -get_root({{_Hash, Key}, Priority, Value, _Left, - _Right}) -> - {Key, Priority, Value}. - -lookup(Key, Tree) -> - HashKey = {erlang:phash2(Key), Key}, - lookup1(Tree, HashKey). - -lookup1(nil, _HashKey) -> error; -lookup1({HashKey1, Priority1, Value1, Left, Right}, - HashKey) -> - if HashKey < HashKey1 -> lookup1(Left, HashKey); - HashKey > HashKey1 -> lookup1(Right, HashKey); - true -> {ok, Priority1, Value1} - end. - -fold(_F, Acc, nil) -> Acc; -fold(F, Acc, - {{_Hash, Key}, Priority, Value, Left, Right}) -> - Acc1 = F({Key, Priority, Value}, Acc), - Acc2 = fold(F, Acc1, Left), - fold(F, Acc2, Right). - -to_list(Tree) -> to_list(Tree, []). - -to_list(nil, Acc) -> Acc; -to_list(Tree, Acc) -> - Root = get_root(Tree), - to_list(delete_root(Tree), [Root | Acc]). - -from_list(List) -> from_list(List, nil). - -from_list([{Key, Priority, Value} | Tail], Tree) -> - from_list(Tail, insert(Key, Priority, Value, Tree)); -from_list([], Tree) -> Tree. |