defmodule Irc.STS do
  @moduledoc """
  # STS Store.

  When a connection encounters a STS policy, it signals it using `witness/2`.

  The store considers the policy valid indefinitely as long as any connection that witnessed
  it for that host is still alive. Once all of those connections close or the system stops,
  the policy's expiration date is computed from its advertised duration.

  The store is shared by all users of the IRC app on a given node, but STS can be disabled
  on a per-connection basis.

  By default, the store is persisted to disk. To configure persistence, set the
  `:irc, :sts_store_file` app environment:

  * `nil` to disable,
  * `{:priv, "filename.dets"}` to store in the irc app priv directory,
  * `{:priv, app, "filename.dets"}` to store in another app priv directory,
  * any file path as string.
  """

  use GenServer
  require Logger

  @ets __MODULE__.ETS

  # ETS/DETS rows are tuples:
  #   {host,      # hostname
  #    tls_port,  # port to use
  #    period,    # advertised period in seconds. use `until` instead
  #    until,     # DateTime until which the entry is valid, or `true` while connections
  #               # that witnessed the policy are still alive
  #    at}        # last witness time

  @doc "Lookup a STS entry."
  @spec lookup(host :: String.t()) :: {:ok, port :: term()} | nil
  def lookup(host) do
    case :ets.lookup(@ets, host) do
      [{_host, port, _period, until, _at}] ->
        if valid?(until) do
          {:ok, port}
        else
          # The table is protected: only the server may delete, so ask it to purge the entry.
          GenServer.cast(__MODULE__, {:expired, host})
          nil
        end

      [] ->
        nil
    end
  end

  @doc """
  Signal a STS policy.

  The STS cache will consider the policy as infinitely valid as long as the calling PID is
  alive, or until a new `witness/2` is signaled. A policy whose `"duration"` is `0`, `"0"` or
  `nil` removes the stored entry for `host`.

  Returns `:ok`, or `{:error, :invalid_policy}` when the policy has no `"port"` or its
  `"duration"` is not a non-negative integer (or integer string).
  """
  @spec witness(host :: String.t(), policy :: map()) :: :ok | {:error, :invalid_policy}
  def witness(host, policy) do
    GenServer.call(__MODULE__, {:witness, host, policy, self()})
  end

  @doc "Returns all entries in the STS store, keyed by host."
  @spec all() :: %{optional(String.t()) => map()}
  def all() do
    fold = fn {host, port, period, until, at}, acc ->
      Map.put(acc, host, %{port: port, period: period, until: until, at: at})
    end

    :ets.foldl(fold, %{}, @ets)
  end

  def start_link(args) do
    GenServer.start_link(__MODULE__, args, name: __MODULE__)
  end

  @impl true
  def init(_args) do
    # Trap exits so `terminate/2` runs when the supervisor shuts us down; that is where
    # open-ended (`true`) entries receive their final expiration date.
    Process.flag(:trap_exit, true)
    ets = :ets.new(@ets, [:named_table, :protected])

    store_file =
      parse_env_store_file(
        Application.get_env(:irc, :sts_store_file, {:priv, "sts_policies.dets"})
      )

    dets = if store_file, do: open_store(ets, store_file)
    {:ok, %{ets: ets, dets: dets, map: %{}}}
  end

  @impl true
  def handle_continue(_action, %{dets: nil} = state) do
    {:noreply, state}
  end

  def handle_continue({:write, entry}, %{dets: dets} = state) do
    :ok = :dets.insert(dets, entry)
    :ok = :dets.sync(dets)
    {:noreply, state}
  end

  def handle_continue({:remove, key}, %{dets: dets} = state) do
    :ok = :dets.delete(dets, key)
    :ok = :dets.sync(dets)
    {:noreply, state}
  end

  # Duration 0 -- the policy is removed.
  @impl true
  def handle_call({:witness, host, %{"duration" => period}, _pid}, _from, state)
      when period in ["0", 0, nil] do
    {:reply, :ok, remove(host, state), {:continue, {:remove, host}}}
  end

  # Witnessed policy.
  # As long as the caller PID is alive, consider the policy always valid.
  def handle_call({:witness, host, %{"port" => tls_port, "duration" => duration}, pid}, _from, state) do
    case to_seconds(duration) do
      {:ok, period} ->
        entry = {host, tls_port, period, true, DateTime.utc_now()}
        :ets.insert(@ets, entry)

        # Reuse an existing monitor so a connection witnessing several times does not
        # accumulate monitors (and duplicate :DOWN messages).
        mon =
          case Map.fetch(state.map, pid) do
            {:ok, {mon, _host}} -> mon
            :error -> Process.monitor(pid)
          end

        state = %{state | map: Map.put(state.map, pid, {mon, host})}
        {:reply, :ok, state, {:continue, {:write, entry}}}

      :error ->
        {:reply, {:error, :invalid_policy}, state}
    end
  end

  # Malformed policy (e.g. no port): reject it instead of crashing the shared store.
  def handle_call({:witness, _host, _policy, _pid}, _from, state) do
    {:reply, {:error, :invalid_policy}, state}
  end

  # Caller side encountered an expired policy, check and remove it.
  @impl true
  def handle_cast({:expired, host}, state) do
    case :ets.lookup(@ets, host) do
      [{_host, _port, _period, until, _at}] ->
        # Re-check: the policy may have been witnessed again since the caller looked it up.
        if valid?(until) do
          {:noreply, state}
        else
          {:noreply, remove(host, state), {:continue, {:remove, host}}}
        end

      [] ->
        {:noreply, state}
    end
  end

  # A connection disconnected. If it was the last one holding its host's policy, the policy
  # now expires `period` seconds from now.
  @impl true
  def handle_info({:DOWN, _ref, :process, pid, _reason}, state) do
    # The map stores `{monitor_ref, host}` per PID.
    {tracked, map} = Map.pop(state.map, pid)
    state = %{state | map: map}

    with {_mon, host} <- tracked,
         false <- Enum.any?(map, fn {_pid, {_, other}} -> other == host end),
         [{^host, tls_port, period, _until, _at}] <- :ets.lookup(@ets, host) do
      now = DateTime.utc_now()
      entry = {host, tls_port, period, DateTime.add(now, period), now}
      :ets.insert(@ets, entry)
      {:noreply, state, {:continue, {:write, entry}}}
    else
      _ -> {:noreply, state}
    end
  end

  # Exit signals from linked processes arrive as messages because we trap exits (the
  # parent's shutdown signal is handled by GenServer itself and leads to terminate/2).
  def handle_info({:EXIT, _pid, :normal}, state) do
    {:noreply, state}
  end

  def handle_info({:EXIT, _pid, reason}, state) do
    {:stop, reason, state}
  end

  def handle_info(msg, state) do
    Logger.debug("Irc.STS: unexpected message #{inspect(msg)}")
    {:noreply, state}
  end

  # Calculate expiration periods from time of shutdown.
  @impl true
  def terminate(_reason, %{dets: nil}) do
    :ok
  end

  def terminate(_reason, %{dets: dets}) do
    now = DateTime.utc_now()
    close_open_entries(fn _row -> now end)
    :ets.to_dets(@ets, dets)
    :dets.close(dets)
    :ok
  end

  # Load persisted entries into ETS, give entries left open-ended by an unclean shutdown an
  # expiration date based on their last witness time, then write the result back to disk.
  defp open_store(ets, store_file) do
    {:ok, dets} = :dets.open_file(store_file, [])
    true = :ets.from_dets(ets, dets)
    close_open_entries(fn {_host, _port, _period, _until, at} -> at end)
    :ets.to_dets(@ets, dets)
    :dets.sync(dets)
    dets
  end

  # Give every open-ended (`until == true`) entry a concrete expiration date of
  # `base + period`, where `base` is computed from the row by `base_fun`.
  defp close_open_entries(base_fun) do
    closed =
      :ets.foldl(
        fn
          {host, _port, period, true, _at} = row, acc ->
            [{host, DateTime.add(base_fun.(row), period)} | acc]

          _row, acc ->
            acc
        end,
        [],
        @ets
      )

    # Updates are applied after the fold so the table is not modified while traversed.
    Enum.each(closed, fn {host, until} -> :ets.update_element(@ets, host, {4, until}) end)
  end

  # Remove an entry from ETS, demonitor the PIDs tracking it, and drop them from the map.
  defp remove(host, state) do
    :ets.delete(@ets, host)

    {tracking, others} = Enum.split_with(state.map, fn {_pid, {_, h}} -> h == host end)
    Enum.each(tracking, fn {_pid, {mon, _}} -> Process.demonitor(mon, [:flush]) end)
    %{state | map: Map.new(others)}
  end

  # `true` means connections still hold the policy; otherwise it is valid until `until`.
  defp valid?(true), do: true
  defp valid?(%DateTime{} = until), do: DateTime.compare(DateTime.utc_now(), until) == :lt
  defp valid?(_corrupt), do: false

  # Durations may arrive as integers or as strings (the removal clause accepts both `0` and
  # `"0"`); store them as integer seconds so `DateTime.add/2` can use them.
  defp to_seconds(value) when is_integer(value) and value >= 0, do: {:ok, value}

  defp to_seconds(value) when is_binary(value) do
    case Integer.parse(value) do
      {seconds, ""} when seconds >= 0 -> {:ok, seconds}
      _ -> :error
    end
  end

  defp to_seconds(_value), do: :error

  defp parse_env_store_file({:priv, file}) do
    parse_env_store_file({:priv, :irc, file})
  end

  defp parse_env_store_file({:priv, app, file}) do
    case :code.priv_dir(app) do
      {:error, reason} ->
        Logger.error(
          "Irc.STS: priv directory of #{inspect(app)} unavailable (#{inspect(reason)}), " <>
            "permanent store NOT ENABLED"
        )

        nil

      dir ->
        dir ++ ~c"/" ++ String.to_charlist(file)
    end
  end

  defp parse_env_store_file(string) when is_binary(string) do
    String.to_charlist(string)
  end

  defp parse_env_store_file(nil) do
    Logger.info("Irc.STS: Permanent cache NOT ENABLED")
    nil
  end

  defp parse_env_store_file(_invalid) do
    Logger.error("Irc.STS: Invalid cache file configuration, permanent store NOT ENABLED")
    nil
  end
end