aboutsummaryrefslogblamecommitdiff
path: root/src/ejabberd_shaper.erl
blob: 4bd5229fa8829610370a87a128db27d11831ba9f (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
                                                                         
                                                  









                                                                      
   


                                                                           
   
                                                                         
                         
                       
 


                                                                  


                                                             
 
                       
 
                                        
                                           




                                                                                 
 


                                                                      


                                                             
 




























                                                                                              
                                         
































                                                                        
           

















                                                                                        

                     


                                                         

                     
                                      
                             
                                                                            
 
                                                            


                                      



































                                                                       
 




                                                                           
        
 














































                                                                                                   
                                                                           





                                             
%%%----------------------------------------------------------------------
%%% ejabberd, Copyright (C) 2002-2022   ProcessOne
%%%
%%% This program is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU General Public License as
%%% published by the Free Software Foundation; either version 2 of the
%%% License, or (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
%%% General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License along
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
%%%----------------------------------------------------------------------
-module(ejabberd_shaper).
-behaviour(gen_server).

-export([start_link/0, new/1, update/2, match/3, get_max_rate/1]).
-export([reload_from_config/0]).
-export([validator/1, shaper_rules_validator/0]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
	 terminate/2, code_change/3]).

-include("logger.hrl").

%% gen_server state: the host list the shaper rules were last loaded for;
%% it is handed to load_from_config/2 as the "old" host list on reload.
-type state() :: #{hosts := [binary()]}.
%% A traffic shaper as built by new/1; `none' is used when no shaping applies.
-type shaper() :: none | p1_shaper:state().
%% A shaper rate: `{Rate, BurstSize}', a bare positive rate, or `infinity'.
-type shaper_rate() :: {pos_integer(), pos_integer()} | pos_integer() | infinity.
%% An unresolved rule entry: shaper name (or literal rate) plus access rules.
-type shaper_rule() :: {atom() | pos_integer(), [acl:access_rule()]}.
%% A rule entry whose shaper name has been replaced by its configured rate.
-type shaper_rate_rule() :: {shaper_rate(), [acl:access_rule()]}.

-export_type([shaper/0, shaper_rule/0, shaper_rate/0]).

%%%===================================================================
%%% API
%%%===================================================================
%% @doc Start the shaper manager, registered locally under the module name
%% (reload_from_config/0 addresses it by that name).
-spec start_link() -> {ok, pid()} | {error, any()}.
start_link() ->
    ServerName = {local, ?MODULE},
    InitArgs = [],
    StartOpts = [],
    gen_server:start_link(ServerName, ?MODULE, InitArgs, StartOpts).

%% @doc Find the shaper rate that applies to a client on `Host'.
%%
%% `Shaper' is either the name of a shaper rule (looked up for `Host' in
%% the `shaper_rules' table) or an explicit list of rules. The atoms `none'
%% and `infinity' short-circuit and are returned unchanged.
%% The client may be given as an ACL match map, as an IP address tuple
%% (4 or 8 elements) or as a JID; the last two are turned into a match map.
%% `none' is the default passed to acl:match_rules/4, i.e. the result when
%% no rule matches.
-spec match(global | binary(), atom() | [shaper_rule()],
	    jid:jid() | jid:ljid() | inet:ip_address() | acl:match()) -> none | shaper_rate().
match(_Host, none, _Who) ->
    none;
match(_Host, infinity, _Who) ->
    infinity;
match(Host, Shaper, #{} = MatchMap) ->
    Rules = case is_atom(Shaper) of
                true -> read_shaper_rules(Shaper, Host);
                false -> Shaper
            end,
    Matched = acl:match_rules(Host, Rules, MatchMap, none),
    read_shaper(Matched);
match(Host, Shaper, Addr) when tuple_size(Addr) =:= 4; tuple_size(Addr) =:= 8 ->
    match(Host, Shaper, #{ip => Addr});
match(Host, Shaper, JID) ->
    match(Host, Shaper, #{usr => jid:tolower(JID)}).

%% @doc Extract the maximum rate from a shaper rate specification: the
%% rate part of a `{Rate, Burst}' pair, or a positive integer rate itself.
%% Any other value (such as `none' or `infinity') yields `none'.
-spec get_max_rate(none | shaper_rate()) -> none | pos_integer().
get_max_rate(Spec) ->
    case Spec of
        {MaxRate, _Burst} ->
            MaxRate;
        N when is_integer(N) andalso N > 0 ->
            N;
        _ ->
            none
    end.

%% @doc Build a traffic shaper from a rate specification.
%% `{Rate, Burst}' and positive integer rates are handed to p1_shaper; any
%% other value (such as `none' or `infinity') means "no shaper" (`none').
-spec new(none | shaper_rate()) -> shaper().
new(Spec) ->
    case Spec of
        {Rate, Burst} ->
            p1_shaper:new(Rate, Burst);
        MaxRate when is_integer(MaxRate), MaxRate > 0 ->
            p1_shaper:new(MaxRate);
        _ ->
            none
    end.

%% @doc Account `Size' units of traffic against a shaper.
%% The `none' shaper yields `{none, 0}'. For a real shaper, whatever
%% p1_shaper:update/2 returns is passed straight back to the caller; per
%% this spec it is `{NewShaper, N}' (N presumably a pause/delay — confirm
%% against p1_shaper).
-spec update(shaper(), non_neg_integer()) -> {shaper(), non_neg_integer()}.
update(none, _Size) ->
    {none, 0};
update(OldShaper, Size) ->
    Result = p1_shaper:update(OldShaper, Size),
    ?DEBUG("Shaper update:~n~ts =>~n~ts",
           [p1_shaper:pp(OldShaper), p1_shaper:pp(Result)]),
    Result.

%% @doc Return the econf validator for the top-level `shaper' or
%% `shaper_rules' option. Names listed by reserved/0 are disallowed and
%% names must be unique; the `shaper' option is returned as a map.
-spec validator(shaper | shaper_rules) -> econf:validator().
validator(shaper) ->
    Opts = [{disallowed, reserved()}, {return, map}, unique],
    econf:options(#{'_' => shaper_validator()}, Opts);
validator(shaper_rules) ->
    Opts = [{disallowed, reserved()}, unique],
    econf:options(#{'_' => shaper_rules_validator()}, Opts).

%% @doc Validator for the value of one `shaper_rules' entry.
%%
%% A list value is flattened; each item is either a `{Shaper, Access}' pair
%% (shaper name/rate plus access rules) or a bare shaper name/rate, which
%% is granted to everyone with `[{acl, all}]'. A non-list value is treated
%% as a single bare shaper name/rate. The result is always a list of
%% `{Shaper, AccessRules}' tuples.
-spec shaper_rules_validator() -> econf:validator().
shaper_rules_validator() ->
    fun(Value) when is_list(Value) ->
            [case Item of
                 {Name, Access} ->
                     {(shaper_name())(Name), (acl:access_validator())(Access)};
                 Name ->
                     {(shaper_name())(Name), [{acl, all}]}
             end || Item <- lists:flatten(Value)];
       (Value) ->
            [{(shaper_name())(Value), [{acl, all}]}]
    end.

%% @doc Ask the running shaper manager to re-read its configuration.
%% Also installed as the `config_reloaded' hook by init/1. Blocks for at
%% most one minute.
-spec reload_from_config() -> ok.
reload_from_config() ->
    Timeout = timer:seconds(60),
    gen_server:call(?MODULE, reload_from_config, Timeout).

%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
%% Create the ETS tables, load shapers for every configured host and hook
%% into configuration reloads. The host list is remembered in the state and
%% given to load_from_config/2 as the old host list on the next reload.
init([]) ->
    create_tabs(),
    ServedHosts = ejabberd_option:hosts(),
    load_from_config([], ServedHosts),
    ejabberd_hooks:add(config_reloaded, ?MODULE, reload_from_config, 20),
    InitialState = #{hosts => ServedHosts},
    {ok, InitialState}.

%% `reload_from_config' reloads shapers for the current host list and
%% records that list in the state. Any other request is logged and never
%% answered, so its caller waits until its call timeout.
-spec handle_call(term(), term(), state()) -> {reply, ok, state()} | {noreply, state()}.
handle_call(reload_from_config, _From, #{hosts := PrevHosts} = State) ->
    CurHosts = ejabberd_option:hosts(),
    load_from_config(PrevHosts, CurHosts),
    NewState = State#{hosts => CurHosts},
    {reply, ok, NewState};
handle_call(Unexpected, Caller, State) ->
    ?WARNING_MSG("Unexpected call from ~p: ~p", [Caller, Unexpected]),
    {noreply, State}.

%% No casts are part of this server's protocol: log and drop them.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast(Unexpected, State) ->
    ?WARNING_MSG("Unexpected cast: ~p", [Unexpected]),
    {noreply, State}.

%% No plain messages are expected either: log and drop them so they do not
%% pile up in the mailbox.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info(Unexpected, State) ->
    ?WARNING_MSG("Unexpected info: ~p", [Unexpected]),
    {noreply, State}.

%% Remove the `config_reloaded' hook that init/1 registered.
-spec terminate(any(), state()) -> ok.
terminate(_Why, _LastState) ->
    ejabberd_hooks:delete(config_reloaded, ?MODULE,
                          reload_from_config, 20).

%% No state migration is performed on code upgrade; the state is kept as is.
-spec code_change(term(), state(), term()) -> {ok, state()}.
code_change(_FromVsn, CurrentState, _Extra) ->
    {ok, CurrentState}.

%%%===================================================================
%%% Internal functions
%%%===================================================================
%%%===================================================================
%%% Table management
%%%===================================================================
%% Load the `shaper' and `shaper_rules' options into their ETS tables.
%%
%% Both tables are made to mirror the current configuration exactly (see
%% sync_table/2): a shaper or shaper rule removed from the configuration,
%% globally or for a host that is still served, stops being applied after a
%% reload instead of lingering in the table. The same sweep removes every
%% rule of a host that is no longer served, so the old host list is not
%% needed any more; the parameter is kept only so callers are unchanged.
-spec load_from_config([binary()], [binary()]) -> ok.
load_from_config(_OldHosts, NewHosts) ->
    ?DEBUG("Loading shaper rules from config", []),
    Shapers = ejabberd_option:shaper(),
    sync_table(shaper, maps:to_list(Shapers)),
    sync_table(
      shaper_rules,
      lists:flatmap(
        fun(Host) ->
                lists:flatmap(
                  fun({Name, List}) ->
                          %% A rule whose list resolves to nothing is not
                          %% stored; any older copy is swept by sync_table/2.
                          case resolve_shapers(Name, List, Shapers) of
                              [] -> [];
                              List1 ->
                                  [{{Name, Host}, List1}]
                          end
                  end, ejabberd_option:shaper_rules(Host))
        end, [global|NewHosts])),
    ?DEBUG("Shaper rules loaded successfully", []).

%% Make `Tab' contain exactly `Objects' (keyed on their first element).
%% All objects are inserted before any stale key is deleted, so a key that
%% is still configured is never absent, even briefly, for processes reading
%% the table while a reload is in progress.
-spec sync_table(shaper | shaper_rules, [tuple()]) -> ok.
sync_table(Tab, Objects) ->
    ets:insert(Tab, Objects),
    Keep = maps:from_list([{element(1, Obj), true} || Obj <- Objects]),
    %% Collect stale keys first, then delete, rather than mutating the
    %% table while folding over it.
    Stale = ets:foldl(
              fun(Obj, Acc) ->
                      Key = element(1, Obj),
                      case maps:is_key(Key, Keep) of
                          true -> Acc;
                          false -> [Key|Acc]
                      end
              end, [], Tab),
    lists:foreach(fun(Key) -> ets:delete(Tab, Key) end, Stale).

%% Create the module's named ETS tables, both tuned for concurrent reads:
%% `shaper' (shaper name -> rate) and `shaper_rules' ({RuleName, Host} ->
%% resolved rules). A Mnesia table named `shaper' is deleted beforehand and
%% the outcome ignored (presumably a leftover from older releases — confirm).
-spec create_tabs() -> ok.
create_tabs() ->
    _ = mnesia:delete_table(shaper),
    TabOpts = [named_table, {read_concurrency, true}],
    lists:foreach(fun(Tab) -> _ = ets:new(Tab, TabOpts) end,
                  [shaper, shaper_rules]),
    ok.

%% Fetch the resolved rules stored for rule `Name' on `Host' (or `global');
%% `[]' when nothing is stored under that key.
-spec read_shaper_rules(atom(), global | binary()) -> [shaper_rate_rule()].
read_shaper_rules(Name, Host) ->
    Key = {Name, Host},
    case ets:lookup(shaper_rules, Key) of
        [] -> [];
        [{_StoredKey, Rules}] -> Rules
    end.

%% Translate a shaper name into its configured rate using the `shaper'
%% table; unknown names give `none'. `none', `infinity' and any non-atom
%% value (an integer rate or `{Rate, Burst}') are returned unchanged.
-spec read_shaper(atom() | shaper_rate()) -> none | shaper_rate().
read_shaper(none) ->
    none;
read_shaper(infinity) ->
    infinity;
read_shaper(Name) when is_atom(Name) ->
    case ets:lookup(shaper, Name) of
        [] -> none;
        [{_Name, Rate}] -> Rate
    end;
read_shaper(Rate) ->
    Rate.

%%%===================================================================
%%% Validators
%%%===================================================================
%% Validator for a shaper reference: an atom, where `infinite' and
%% `unlimited' are accepted as aliases of `infinity', or a positive integer.
shaper_name() ->
    Normalize = fun(Name) when Name =:= infinite; Name =:= unlimited ->
                        infinity;
                   (Name) ->
                        Name
                end,
    econf:either(econf:and_then(econf:atom(), Normalize),
                 econf:pos_int()).

%% Validator for one `shaper' definition. It accepts either:
%%   * a map with a required `rate' and an optional `burst_size' (both
%%     positive integers), converted to `{Rate, BurstSize}', where the burst
%%     size defaults to the rate; or
%%   * whatever econf:pos_int(infinity) accepts (presumably a positive
%%     integer or the atom `infinity' — see econf).
shaper_validator() ->
    MapForm = econf:options(#{rate => econf:pos_int(),
                              burst_size => econf:pos_int()},
                            [unique, {required, [rate]}, {return, map}]),
    ToRate = fun(#{rate := Rate} = Opts) ->
                     {Rate, maps:get(burst_size, Opts, Rate)}
             end,
    econf:either(econf:and_then(MapForm, ToRate),
                 econf:pos_int(infinity)).

%%%===================================================================
%%% Aux
%%%===================================================================
%% Names that may not be used as `shaper' / `shaper_rules' option keys:
%% `none' and `infinity' are special values in match/3, and `infinite' and
%% `unlimited' are mapped to `infinity' by shaper_name/0.
-spec reserved() -> [atom()].
reserved() ->
    InfinityAliases = [infinite, unlimited],
    [none | InfinityAliases] ++ [infinity].

%% Replace shaper names in the entries of rule `ShaperRule' with the rates
%% defined in `Shapers'. An entry naming an unknown shaper is dropped with
%% a warning. Entries that are not a plain shaper name (integer rates,
%% `none', `infinity', ...) are kept unchanged.
-spec resolve_shapers(atom(), [shaper_rule()], #{atom() => shaper_rate()}) -> [shaper_rate_rule()].
resolve_shapers(ShaperRule, Rules, Shapers) ->
    Resolve =
        fun({Name, Access}) when is_atom(Name), Name /= none, Name /= infinity ->
                case maps:find(Name, Shapers) of
                    {ok, Rate} ->
                        {true, {Rate, Access}};
                    error ->
                        ?WARNING_MSG(
                           "Shaper rule '~ts' refers to unknown shaper: ~ts",
                           [ShaperRule, Name]),
                        false
                end;
           (_Other) ->
                true
        end,
    lists:filtermap(Resolve, Rules).