aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorPaweł Chmielowski <pawel@process-one.net>2021-11-16 19:39:59 +0100
committerPaweł Chmielowski <pawel@process-one.net>2021-11-17 11:32:42 +0100
commit405a5172d5bda5fd40b6a580b87f3fab1ecdd47c (patch)
treef2c0fd501a4111e224c59f2247c4da09c48ae1e2 /src
parentBetter version of dialyzer fix (diff)
Improve mod_multicast
Diffstat (limited to 'src')
-rw-r--r--src/mod_multicast.erl426
1 files changed, 136 insertions, 290 deletions
diff --git a/src/mod_multicast.erl b/src/mod_multicast.erl
index 161d3a4c4..fa076da70 100644
--- a/src/mod_multicast.erl
+++ b/src/mod_multicast.erl
@@ -35,7 +35,7 @@
%% API
-export([start/2, stop/1, reload/3,
- user_send_packet/1]).
+ user_send_packet/1]).
%% gen_server callbacks
-export([init/1, handle_info/2, handle_call/3,
@@ -51,11 +51,6 @@
response,
ts :: integer()}).
--record(dest, {jid_string :: binary() | none,
- jid_jid :: jid() | undefined,
- type :: bcc | cc | noreply | ofrom | replyroom | replyto | to,
- address :: address()}).
-
-type limit_value() :: {default | custom, integer()}.
-record(limits, {message :: limit_value(),
presence :: limit_value()}).
@@ -63,14 +58,6 @@
-record(service_limits, {local :: #limits{},
remote :: #limits{}}).
--type routing() :: route_single | {route_multicast, binary(), #service_limits{}}.
-
--record(group, {server :: binary(),
- dests :: [#dest{}],
- multicast :: routing() | undefined,
- others :: [address()],
- addresses :: [address()]}).
-
-record(state, {lserver :: binary(),
lservice :: binary(),
access :: atom(),
@@ -117,7 +104,7 @@ reload(LServerS, NewOpts, OldOpts) ->
user_send_packet({#presence{} = Packet, C2SState} = Acc) ->
case xmpp:get_subtag(Packet, #addresses{}) of
#addresses{list = Addresses} ->
- {ToDeliver, _Delivereds} = split_addresses_todeliver(Addresses),
+ {CC, BCC, _Invalid, _Delivered} = partition_addresses(Addresses),
NewState =
lists:foldl(
fun(Address, St) ->
@@ -138,7 +125,7 @@ user_send_packet({#presence{} = Packet, C2SState} = Acc) ->
undefined ->
St
end
- end, C2SState, ToDeliver),
+ end, C2SState, CC ++ BCC),
{Packet, NewState};
false ->
Acc
@@ -308,19 +295,10 @@ iq_vcard(Lang, State) ->
%%%-------------------------
-spec route_trusted(binary(), binary(), jid(), [jid()], stanza()) -> 'ok'.
-route_trusted(LServiceS, LServerS, FromJID,
- Destinations, Packet) ->
- Packet_stripped = Packet,
- Delivereds = [],
- Dests2 = lists:map(
- fun(D) ->
- #dest{jid_string = jid:encode(D),
- jid_jid = D, type = bcc,
- address = #address{type = bcc, jid = D}}
- end, Destinations),
- Groups = group_dests(Dests2),
- route_common(LServerS, LServiceS, FromJID, Groups,
- Delivereds, Packet_stripped).
+route_trusted(LServiceS, LServerS, FromJID, Destinations, Packet) ->
+ Addresses = [#address{type = bcc, jid = D} || D <- Destinations],
+ Groups = group_by_destinations(Addresses, #{}),
+ route_grouped(LServerS, LServiceS, FromJID, Groups, [], Packet).
-spec route_untrusted(binary(), binary(), atom(), #service_limits{}, stanza()) -> 'ok'.
route_untrusted(LServiceS, LServerS, Access, SLimits, Packet) ->
@@ -356,50 +334,88 @@ route_untrusted(LServiceS, LServerS, Access, SLimits, Packet) ->
route_untrusted2(LServiceS, LServerS, Access, SLimits, Packet) ->
FromJID = xmpp:get_from(Packet),
ok = check_access(LServerS, Access, FromJID),
- {ok, Packet_stripped, Addresses} = strip_addresses_element(Packet),
- {To_deliver, Delivereds} = split_addresses_todeliver(Addresses),
- Dests = convert_dest_record(To_deliver),
- {Dests2, Not_jids} = split_dests_jid(Dests),
- report_not_jid(FromJID, Packet, Not_jids),
- ok = check_limit_dests(SLimits, FromJID, Packet, Dests2),
- Groups = group_dests(Dests2),
+ {ok, PacketStripped, Addresses} = strip_addresses_element(Packet),
+ {CC, BCC, NotJids, Rest} = partition_addresses(Addresses),
+ report_not_jid(FromJID, Packet, NotJids),
+ ok = check_limit_dests(SLimits, FromJID, Packet, length(CC) + length(BCC)),
+ Groups0 = group_by_destinations(CC, #{}),
+ Groups = group_by_destinations(BCC, Groups0),
ok = check_relay(FromJID#jid.server, LServerS, Groups),
- route_common(LServerS, LServiceS, FromJID, Groups,
- Delivereds, Packet_stripped).
-
--spec route_common(binary(), binary(), jid(), [#group{}],
- [address()], stanza()) -> 'ok'.
-route_common(LServerS, LServiceS, FromJID, Groups,
- Delivereds, Packet_stripped) ->
- Groups2 = look_cached_servers(LServerS, LServiceS, Groups),
- Groups3 = build_others_xml(Groups2),
- Groups4 = add_addresses(Delivereds, Groups3),
- AGroups = decide_action_groups(Groups4),
- act_groups(FromJID, Packet_stripped, LServiceS,
- AGroups).
-
--spec act_groups(jid(), stanza(), binary(), [{routing(), #group{}}]) -> 'ok'.
-act_groups(FromJID, Packet_stripped, LServiceS, AGroups) ->
+ route_grouped(LServerS, LServiceS, FromJID, Groups, Rest, PacketStripped).
+
+-spec mark_as_delivered([address()]) -> [address()].
+mark_as_delivered(Addresses) ->
+ [A#address{delivered = true} || A <- Addresses].
+
+-spec route_individual(jid(), [address()], [address()], [address()], stanza()) -> ok.
+route_individual(From, CC, BCC, Other, Packet) ->
+ CCDelivered = mark_as_delivered(CC),
+ Addresses = CCDelivered ++ Other,
+ PacketWithAddresses = xmpp:append_subtags(Packet, [#addresses{list = Addresses}]),
lists:foreach(
- fun(AGroup) ->
- perform(FromJID, Packet_stripped, LServiceS,
- AGroup)
- end, AGroups).
-
--spec perform(jid(), stanza(), binary(),
- {routing(), #group{}}) -> 'ok'.
-perform(From, Packet, _,
- {route_single, Group}) ->
+ fun(#address{jid = To}) ->
+ ejabberd_router:route(xmpp:set_from_to(PacketWithAddresses, From, To))
+ end, CC),
lists:foreach(
- fun(ToUser) ->
- Group_others = strip_other_bcc(ToUser, Group#group.others),
- route_packet(From, ToUser, Packet,
- Group_others, Group#group.addresses)
- end, Group#group.dests);
-perform(From, Packet, _,
- {{route_multicast, JID, RLimits}, Group}) ->
- route_packet_multicast(From, JID, Packet,
- Group#group.dests, Group#group.addresses, RLimits).
+ fun(#address{jid = To} = Address) ->
+ Packet2 = case Addresses of
+ [] ->
+ Packet;
+ _ ->
+ xmpp:append_subtags(Packet, [#addresses{list = [Address | Addresses]}])
+ end,
+ ejabberd_router:route(xmpp:set_from_to(Packet2, From, To))
+ end, BCC).
+
+-spec route_chunk(jid(), jid(), stanza(), [address()]) -> ok.
+route_chunk(From, To, Packet, Addresses) ->
+ PacketWithAddresses = xmpp:append_subtags(Packet, [#addresses{list = Addresses}]),
+ ejabberd_router:route(xmpp:set_from_to(PacketWithAddresses, From, To)).
+
+-spec route_in_chunks(jid(), jid(), stanza(), integer(), [address()], [address()], [address()]) -> ok.
+route_in_chunks(_From, _To, _Packet, _Limit, [], [], _) ->
+ ok;
+route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses) when length(CC) > Limit ->
+ {Chunk, Rest} = lists:split(Limit, CC),
+ route_chunk(From, To, Packet, Chunk ++ RestOfAddresses),
+ route_in_chunks(From, To, Packet, Limit, Rest, BCC, RestOfAddresses);
+route_in_chunks(From, To, Packet, Limit, [], BCC, RestOfAddresses) when length(BCC) > Limit ->
+ {Chunk, Rest} = lists:split(Limit, BCC),
+ route_chunk(From, To, Packet, Chunk ++ RestOfAddresses),
+ route_in_chunks(From, To, Packet, Limit, [], Rest, RestOfAddresses);
+route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses) when length(BCC) + length(CC) > Limit ->
+ {Chunk, Rest} = lists:split(Limit - length(CC), BCC),
+ route_chunk(From, To, Packet, CC ++ Chunk ++ RestOfAddresses),
+ route_in_chunks(From, To, Packet, Limit, [], Rest, RestOfAddresses);
+route_in_chunks(From, To, Packet, _Limit, CC, BCC, RestOfAddresses) ->
+ route_chunk(From, To, Packet, CC ++ BCC ++ RestOfAddresses).
+
+-spec route_multicast(jid(), jid(), [address()], [address()], [address()], stanza(), #limits{}) -> ok.
+route_multicast(From, To, CC, BCC, RestOfAddresses, Packet, Limits) ->
+ {_Type, Limit} = get_limit_number(element(1, Packet),
+ Limits),
+ route_in_chunks(From, To, Packet, Limit, CC, BCC, RestOfAddresses).
+
+-spec route_grouped(binary(), binary(), jid(), #{}, [address()], stanza()) -> ok.
+route_grouped(LServer, LService, From, Groups, RestOfAddresses, Packet) ->
+ maps:fold(
+ fun(Server, {CC, BCC}, _) ->
+ OtherCC = maps:fold(
+ fun(Server2, _, Res) when Server2 == Server ->
+ Res;
+ (_, {CC2, _}, Res) ->
+ mark_as_delivered(CC2) ++ Res
+ end, [], Groups),
+ case search_server_on_cache(Server,
+ LServer, LService,
+ {?MAXTIME_CACHE_POSITIVE,
+ ?MAXTIME_CACHE_NEGATIVE}) of
+ route_single ->
+ route_individual(From, CC, BCC, OtherCC ++ RestOfAddresses, Packet);
+ {route_multicast, Service, Limits} ->
+ route_multicast(From, Service, CC, BCC, OtherCC ++ RestOfAddresses, Packet, Limits)
+ end
+ end, ok, Groups).
%%%-------------------------
%%% Check access permission
@@ -426,244 +442,88 @@ strip_addresses_element(Packet) ->
end.
%%%-------------------------
-%%% Strip third-party bcc 'addresses'
-%%%-------------------------
-
-strip_other_bcc(#dest{jid_jid = ToUserJid}, Group_others) ->
- lists:filter(
- fun(#address{jid = JID, type = Type}) ->
- case {JID, Type} of
- {ToUserJid, bcc} -> true;
- {_, bcc} -> false;
- _ -> true
- end
- end,
- Group_others).
-
-%%%-------------------------
%%% Split Addresses
%%%-------------------------
--spec split_addresses_todeliver([address()]) -> {[address()], [address()]}.
-split_addresses_todeliver(Addresses) ->
- lists:partition(
- fun(#address{delivered = true}) ->
- false;
- (#address{type = Type}) ->
- case Type of
- to -> true;
- cc -> true;
- bcc -> true;
- _ -> false
- end
- end, Addresses).
+partition_addresses(Addresses) ->
+ lists:foldl(
+ fun(#address{delivered = true} = A, {C, B, I, D}) ->
+ {C, B, I, [A | D]};
+ (#address{type = T, jid = undefined} = A, {C, B, I, D})
+ when T == to; T == cc; T == bcc ->
+ {C, B, [A | I], D};
+ (#address{type = T} = A, {C, B, I, D})
+ when T == to; T == cc ->
+ {[A | C], B, I, D};
+ (#address{type = bcc} = A, {C, B, I, D}) ->
+ {C, [A | B], I, D};
+ (A, {C, B, I, D}) ->
+ {C, B, I, [A | D]}
+ end, {[], [], [], []}, Addresses).
%%%-------------------------
%%% Check does not exceed limit of destinations
%%%-------------------------
--spec check_limit_dests(#service_limits{}, jid(), stanza(), [address()]) -> ok.
-check_limit_dests(SLimits, FromJID, Packet,
- Addresses) ->
+-spec check_limit_dests(#service_limits{}, jid(), stanza(), integer()) -> ok.
+check_limit_dests(SLimits, FromJID, Packet, NumOfAddresses) ->
SenderT = sender_type(FromJID),
Limits = get_slimit_group(SenderT, SLimits),
- Type_of_stanza = type_of_stanza(Packet),
- {_Type, Limit_number} = get_limit_number(Type_of_stanza,
- Limits),
- case length(Addresses) > Limit_number of
+ StanzaType = type_of_stanza(Packet),
+ {_Type, Limit} = get_limit_number(StanzaType,
+ Limits),
+ case NumOfAddresses > Limit of
false -> ok;
true -> throw(etoorec)
end.
-%%%-------------------------
-%%% Convert Destination XML to record
-%%%-------------------------
-
--spec convert_dest_record([address()]) -> [#dest{}].
-convert_dest_record(Addrs) ->
- lists:map(
- fun(#address{jid = undefined, type = Type} = Addr) ->
- #dest{jid_string = none,
- type = Type, address = Addr};
- (#address{jid = JID, type = Type} = Addr) ->
- #dest{jid_string = jid:encode(JID), jid_jid = JID,
- type = Type, address = Addr}
- end, Addrs).
-
-%%%-------------------------
-%%% Split destinations by existence of JID
-%%% and send error messages for other dests
-%%%-------------------------
--spec split_dests_jid([#dest{}]) -> {[#dest{}], [#dest{}]}.
-split_dests_jid(Dests) ->
- lists:partition(fun (Dest) ->
- case Dest#dest.jid_string of
- none -> false;
- _ -> true
- end
- end,
- Dests).
-
--spec report_not_jid(jid(), stanza(), [#dest{}]) -> any().
-report_not_jid(From, Packet, Dests) ->
- Dests2 = [fxml:element_to_binary(xmpp:encode(Dest#dest.address))
- || Dest <- Dests],
- [route_error(
- xmpp:set_from_to(Packet, From, From), jid_malformed,
- str:format(?T("This service can not process the address: ~s"), [D]))
- || D <- Dests2].
+-spec report_not_jid(jid(), stanza(), [address()]) -> any().
+report_not_jid(From, Packet, Addresses) ->
+ lists:foreach(
+ fun(Address) ->
+ route_error(
+ xmpp:set_from_to(Packet, From, From), jid_malformed,
+ str:format(?T("This service can not process the address: ~s"),
+ [fxml:element_to_binary(xmpp:encode(Address))]))
+ end, Addresses).
%%%-------------------------
%%% Group destinations by their servers
%%%-------------------------
--spec group_dests([#dest{}]) -> [#group{}].
-group_dests(Dests) ->
- D = lists:foldl(fun (Dest, Dict) ->
- ServerS = (Dest#dest.jid_jid)#jid.server,
- dict:append(ServerS, Dest, Dict)
- end,
- dict:new(), Dests),
- Keys = dict:fetch_keys(D),
- [#group{server = Key, dests = dict:fetch(Key, D),
- addresses = [], others = []}
- || Key <- Keys].
-
-%%%-------------------------
-%%% Look for cached responses
-%%%-------------------------
-
-look_cached_servers(LServerS, LServiceS, Groups) ->
- [look_cached(LServerS, LServiceS, Group) || Group <- Groups].
-
-look_cached(LServerS, LServiceS, G) ->
- Maxtime_positive = (?MAXTIME_CACHE_POSITIVE),
- Maxtime_negative = (?MAXTIME_CACHE_NEGATIVE),
- Cached_response = search_server_on_cache(G#group.server,
- LServerS, LServiceS,
- {Maxtime_positive,
- Maxtime_negative}),
- G#group{multicast = Cached_response}.
-
-%%%-------------------------
-%%% Build delivered XML element
-%%%-------------------------
-
-build_others_xml(Groups) ->
- [Group#group{others =
- build_other_xml(Group#group.dests)}
- || Group <- Groups].
-
-build_other_xml(Dests) ->
- lists:foldl(fun (Dest, R) ->
- XML = Dest#dest.address,
- case Dest#dest.type of
- to -> [add_delivered(XML) | R];
- cc -> [add_delivered(XML) | R];
- _ -> [XML | R]
- end
- end,
- [], Dests).
-
--spec add_delivered(address()) -> address().
-add_delivered(Addr) ->
- Addr#address{delivered = true}.
-
-%%%-------------------------
-%%% Add preliminary packets
-%%%-------------------------
-
-add_addresses(Delivereds, Groups) ->
- Ps = [Group#group.others || Group <- Groups],
- add_addresses2(Delivereds, Groups, [], [], Ps).
-
-add_addresses2(_, [], Res, _, []) -> Res;
-add_addresses2(Delivereds, [Group | Groups], Res, Pa,
- [Pi | Pz]) ->
- Addresses = lists:append([Delivereds] ++ Pa ++ Pz),
- Group2 = Group#group{addresses = Addresses},
- add_addresses2(Delivereds, Groups, [Group2 | Res],
- [Pi | Pa], Pz).
-
-%%%-------------------------
-%%% Decide action groups
-%%%-------------------------
-
--spec decide_action_groups([#group{}]) -> [{routing(), #group{}}].
-decide_action_groups(Groups) ->
- [{Group#group.multicast, Group}
- || Group <- Groups].
+group_by_destinations(Addrs, Map) ->
+ lists:foldl(
+ fun
+ (#address{type = Type, jid = #jid{lserver = Server}} = Addr, Map2) when Type == to; Type == cc ->
+ maps:update_with(Server,
+ fun({CC, BCC}) ->
+ {[Addr | CC], BCC}
+ end, {[Addr], []}, Map2);
+ (#address{type = bcc, jid = #jid{lserver = Server}} = Addr, Map2) ->
+ maps:update_with(Server,
+ fun({CC, BCC}) ->
+ {CC, [Addr | BCC]}
+ end, {[], [Addr]}, Map2)
+ end, Map, Addrs).
%%%-------------------------
%%% Route packet
%%%-------------------------
--spec route_packet(jid(), #dest{}, stanza(), [addresses()], [addresses()]) -> 'ok'.
-route_packet(From, ToDest, Packet, Others, Addresses) ->
- Dests = case ToDest#dest.type of
- bcc -> [];
- _ -> [ToDest]
- end,
- route_packet2(From, ToDest#dest.jid_string, Dests,
- Packet, {Others, Addresses}).
-
--spec route_packet_multicast(jid(), binary(), stanza(), [#dest{}], [address()], #limits{}) -> 'ok'.
-route_packet_multicast(From, ToS, Packet, Dests,
- Addresses, Limits) ->
- Type_of_stanza = type_of_stanza(Packet),
- {_Type, Limit_number} = get_limit_number(Type_of_stanza,
- Limits),
- Fragmented_dests = fragment_dests(Dests, Limit_number),
- lists:foreach(fun(DFragment) ->
- route_packet2(From, ToS, DFragment, Packet,
- Addresses)
- end, Fragmented_dests).
-
--spec route_packet2(jid(), binary(), [#dest{}], stanza(), {[address()], [address()]} | [address()]) -> 'ok'.
-route_packet2(From, ToS, Dests, Packet, Addresses) ->
- Els = case append_dests(Dests, Addresses) of
- [] ->
- xmpp:get_els(Packet);
- ACs ->
- [#addresses{list = ACs}|xmpp:get_els(Packet)]
- end,
- Packet2 = xmpp:set_els(Packet, Els),
- ToJID = stj(ToS),
- ejabberd_router:route(xmpp:set_from_to(Packet2, From, ToJID)).
-
--spec append_dests([#dest{}], {[address()], [address()]} | [address()]) -> [address()].
-append_dests(_Dests, {Others, Addresses}) ->
- Addresses ++ Others;
-append_dests([], Addresses) -> Addresses;
-append_dests([Dest | Dests], Addresses) ->
- append_dests(Dests, [Dest#dest.address | Addresses]).
-
%%%-------------------------
%%% Check relay
%%%-------------------------
--spec check_relay(binary(), binary(), [#group{}]) -> ok.
+-spec check_relay(binary(), binary(), #{}) -> ok.
check_relay(RS, LS, Gs) ->
- case check_relay_required(RS, LS, Gs) of
- false -> ok;
- true -> throw(edrelay)
- end.
-
--spec check_relay_required(binary(), binary(), [#group{}]) -> boolean().
-check_relay_required(RServer, LServerS, Groups) ->
- case lists:suffix(str:tokens(LServerS, <<".">>),
- str:tokens(RServer, <<".">>)) of
- true -> false;
- false -> check_relay_required(LServerS, Groups)
+ case lists:suffix(str:tokens(LS, <<".">>),
+ str:tokens(RS, <<".">>)) orelse
+ (maps:is_key(LS, Gs) andalso maps:size(Gs) == 1) of
+ true -> ok;
+ _ -> throw(edrelay)
end.
--spec check_relay_required(binary(), [#group{}]) -> boolean().
-check_relay_required(LServerS, Groups) ->
- lists:any(fun (Group) -> Group#group.server /= LServerS
- end,
- Groups).
-
%%%-------------------------
%%% Check protocol support: Send request
%%%-------------------------
@@ -1060,20 +920,6 @@ get_slimit_group(local, SLimits) ->
get_slimit_group(remote, SLimits) ->
SLimits#service_limits.remote.
-fragment_dests(Dests, Limit_number) ->
- {R, _} = lists:foldl(fun (Dest, {Res, Count}) ->
- case Count of
- Limit_number ->
- Head2 = [Dest], {[Head2 | Res], 0};
- _ ->
- [Head | Tail] = Res,
- Head2 = [Dest | Head],
- {[Head2 | Tail], Count + 1}
- end
- end,
- {[[]], 0}, Dests),
- R.
-
%%%-------------------------
%%% Limits: XEP-0128 Service Discovery Extensions
%%%-------------------------