summaryrefslogtreecommitdiff
path: root/lib/irc/sts.ex
diff options
context:
space:
mode:
Diffstat (limited to 'lib/irc/sts.ex')
-rw-r--r--lib/irc/sts.ex169
1 files changed, 169 insertions, 0 deletions
diff --git a/lib/irc/sts.ex b/lib/irc/sts.ex
new file mode 100644
index 0000000..9980f8c
--- /dev/null
+++ b/lib/irc/sts.ex
@@ -0,0 +1,169 @@
+defmodule Irc.STS do
+ @moduledoc """
+ # STS Store.
+
+ When a connection encounters a STS policy, it signals it using `witness/4`. The store will consider the policy valid indefinitely as long as
+ any connection are still alive for that host/port pair. Once all connections or the system stops, the latest policy expiration date will be computed.
+
+ By default, the store is not persistent. If you wish to enable persistance, set the `:irc, :sts_cache_file` app environment.
+ """
+
+ @ets __MODULE__.ETS
+ # tuple {{host,port}, tls_port, period, until | true, at}
+
+ @doc "Lookup a STS entry"
+ @spec lookup(host :: String.t(), port :: Integer.t()) :: {enabled :: boolean, port :: Integer.t()}
+ def lookup(host,port) do
+ with \
+ [{_, port, period, until}] <- :ets.lookup(@ets, {host,port}),
+ true <- verify_validity(period, until)
+ do
+ {true, port}
+ else
+ [] -> {false, port}
+ false ->
+ GenServer.cast(__MODULE__, {:expired, {host,port}})
+ {false, port}
+ end
+ end
+
+ @doc """
+ Signal a STS policy.
+
+ The STS cache will consider the policy as infinitely valid as long as the calling PID is alive,
+ or signal a new `witness/4`.
+ """
+ def witness(host, port, sts_port, period) do
+ GenServer.call(__MODULE__, {:witness, host, port, sts_port, period, self()})
+ end
+
+ @doc """
+ Revoke a STS policy. This is the same as calling `witness/4` with sts_port = nil and period = nil.
+ """
+ def revoke(host, port) do
+ GenServer.call(__MODULE__, {:witness, {host,port,nil,nil,self()}})
+ end
+
+ @doc "Returns all entries in the STS store"
+ def all() do
+ fold = fn(el, acc) -> [el | acc] end
+ :ets.foldl(fold, @ets, [])
+ end
+
+ def start_link() do
+ GenServer.start_link(__MODULE__, [], [name: __MODULE__])
+ end
+
+ def init(_) do
+ ets = :ets.new(@ets, [:named_table, :protected])
+ cache_file = Application.get_env(:irc, :sts_cache_file)
+ dets = if cache_file do
+ {:ok, dets} = :dets.open_file(cache_file)
+ true = :ets.from_dets(ets, dets)
+ end
+ {:ok, %{ets: ets, dets: dets, map: %{}}}
+ end
+
+ def handle_continue(nil, state) do
+ {:noreply, state}
+ end
+ def handle_continue(_, state = %{dets: nil}) do
+ {:noreply, state}
+ end
+ def handle_continue({:write, entry}, state = %{dets: dets}) do
+ :ok = :dets.insert(dets, entry)
+ :ok = :dets.sync(dets)
+ {:noreply, state}
+ end
+ def handle_continue({:remove, key}, state = %{dets: dets}) do
+ :ok = :dets.delete(dets, key)
+ :ok = :dets.sync(dets)
+ {:noreply, state}
+ end
+
+ # Duration 0 -- Policy is removed
+ def handle_call({:witness, host, port, _tls_port, period, pid}, from, state) when period in ["0", 0, nil] do
+ state = remove({host,port}, state)
+ {:reply, :ok, state, {:handle_continue, {:remove, {host,port}}}}
+ end
+
+ # Witnessed policy.
+ # As long as caller PID is alive, consider the policy always valid
+ def handle_call({:witness, host, port, tls_port, period, pid}, _, state) do
+ entry = {{host,port}, tls_port, period, true}
+ :ets.insert(@ets, entry)
+ mon = Process.monitor(pid)
+ state = %{state | map: Map.put(state.map, pid, {mon,{host,port}})}
+ {:reply, :ok, state, {:handle_continue, {:write, entry}}}
+ end
+
+ # Caller side encountered an expired policy, check and remove it.
+ def handle_cast({:expired, key}, state) do
+ {state, continue} = case :ets.lookup(@ets, key) do
+ [{_, _, period, until}] ->
+ if !verify_validity(period, until) do
+ {remove(key, state), {:remove, key}}
+ else
+ {state, nil}
+ end
+ [] -> {state, nil}
+ end
+ {:noreply, state, {:handle_continue, continue}}
+ end
+
+ # A connection disconnected
+ def handle_info({:DOWN, _, :process, pid, _}, state) do
+ key = Map.get(state.map, pid)
+ others = Enum.filter(state.map, fn({p, {_,k}}) -> k == key && p != pid end)
+ state = %{state | map: Map.delete(state.map, pid)}
+ if key && Enum.empty?(others) do
+ case :ets.lookup(@ets, key) do
+ [{key, tls_port, period, until}] ->
+ until = DateTime.utc_now() |> DateTime.add(period)
+ entry = {key, tls_port, period, until}
+ :ets.insert(@ets, entry)
+ {:noreply, state, {:handle_continue, {:write, entry}}}
+ [] ->
+ {:noreply, state}
+ end
+ else
+ {:noreply, state}
+ end
+ end
+
+ # Calculate expiration periods from time of shutdown.
+ def terminate(_, state) do
+ if state.dets do
+ fold = fn
+ ({key, _tls_port, period, true}, acc) ->
+ until = DateTime.utc_now() |> DateTime.add(period)
+ [{key, until} | acc]
+ (_, acc) -> acc
+ end
+ for {key, until} <- :ets.foldl(fold, [], @ets) do
+ :ets.update_element(@ets, key, {4, until})
+ end
+ :ets.to_dets(@ets, state.dets)
+ :dets.close(state.dets)
+ end
+ :ok
+ end
+
+ # Remove an entry from ETS, demonitor related PIDs, and remove from map.
+ defp remove(key, state) do
+ :ets.delete(@ets, key)
+ pids = Enum.filter(state.map, fn({p, {mon,k}}) -> k == key end) |> Enum.map(fn({p, {mon,_}}) -> {p,mon} end)
+ for {_,mon} <- pids, do: Process.demonitor(mon, [:flush])
+ map = Enum.reduce(pids, state.map, fn({p,_}, map) -> Map.delete(map, p) end)
+ %{state | map: map}
+ end
+
+ defp verify_validity(_, true) do
+ true
+ end
+
+ defp verify_validity(_period, until) do
+ DateTime.utc_now() >= until
+ end
+
+end