diff options
Diffstat (limited to 'lib/irc/sts.ex')
-rw-r--r-- | lib/irc/sts.ex | 169 |
1 files changed, 169 insertions, 0 deletions
diff --git a/lib/irc/sts.ex b/lib/irc/sts.ex new file mode 100644 index 0000000..9980f8c --- /dev/null +++ b/lib/irc/sts.ex @@ -0,0 +1,169 @@ +defmodule Irc.STS do + @moduledoc """ + # STS Store. + + When a connection encounters a STS policy, it signals it using `witness/4`. The store will consider the policy valid indefinitely as long as + any connection are still alive for that host/port pair. Once all connections or the system stops, the latest policy expiration date will be computed. + + By default, the store is not persistent. If you wish to enable persistance, set the `:irc, :sts_cache_file` app environment. + """ + + @ets __MODULE__.ETS + # tuple {{host,port}, tls_port, period, until | true, at} + + @doc "Lookup a STS entry" + @spec lookup(host :: String.t(), port :: Integer.t()) :: {enabled :: boolean, port :: Integer.t()} + def lookup(host,port) do + with \ + [{_, port, period, until}] <- :ets.lookup(@ets, {host,port}), + true <- verify_validity(period, until) + do + {true, port} + else + [] -> {false, port} + false -> + GenServer.cast(__MODULE__, {:expired, {host,port}}) + {false, port} + end + end + + @doc """ + Signal a STS policy. + + The STS cache will consider the policy as infinitely valid as long as the calling PID is alive, + or signal a new `witness/4`. + """ + def witness(host, port, sts_port, period) do + GenServer.call(__MODULE__, {:witness, host, port, sts_port, period, self()}) + end + + @doc """ + Revoke a STS policy. This is the same as calling `witness/4` with sts_port = nil and period = nil. + """ + def revoke(host, port) do + GenServer.call(__MODULE__, {:witness, {host,port,nil,nil,self()}}) + end + + @doc "Returns all entries in the STS store" + def all() do + fold = fn(el, acc) -> [el | acc] end + :ets.foldl(fold, @ets, []) + end + + def start_link() do + GenServer.start_link(__MODULE__, [], [name: __MODULE__]) + end + + def init(_) do + ets = :ets.new(@ets, [:named_table, :protected]) + cache_file = Application.get_env(:irc, :sts_cache_file) + dets = if cache_file do + {:ok, dets} = :dets.open_file(cache_file) + true = :ets.from_dets(ets, dets) + end + {:ok, %{ets: ets, dets: dets, map: %{}}} + end + + def handle_continue(nil, state) do + {:noreply, state} + end + def handle_continue(_, state = %{dets: nil}) do + {:noreply, state} + end + def handle_continue({:write, entry}, state = %{dets: dets}) do + :ok = :dets.insert(dets, entry) + :ok = :dets.sync(dets) + {:noreply, state} + end + def handle_continue({:remove, key}, state = %{dets: dets}) do + :ok = :dets.delete(dets, key) + :ok = :dets.sync(dets) + {:noreply, state} + end + + # Duration 0 -- Policy is removed + def handle_call({:witness, host, port, _tls_port, period, pid}, from, state) when period in ["0", 0, nil] do + state = remove({host,port}, state) + {:reply, :ok, state, {:handle_continue, {:remove, {host,port}}}} + end + + # Witnessed policy. + # As long as caller PID is alive, consider the policy always valid + def handle_call({:witness, host, port, tls_port, period, pid}, _, state) do + entry = {{host,port}, tls_port, period, true} + :ets.insert(@ets, entry) + mon = Process.monitor(pid) + state = %{state | map: Map.put(state.map, pid, {mon,{host,port}})} + {:reply, :ok, state, {:handle_continue, {:write, entry}}} + end + + # Caller side encountered an expired policy, check and remove it. + def handle_cast({:expired, key}, state) do + {state, continue} = case :ets.lookup(@ets, key) do + [{_, _, period, until}] -> + if !verify_validity(period, until) do + {remove(key, state), {:remove, key}} + else + {state, nil} + end + [] -> {state, nil} + end + {:noreply, state, {:handle_continue, continue}} + end + + # A connection disconnected + def handle_info({:DOWN, _, :process, pid, _}, state) do + key = Map.get(state.map, pid) + others = Enum.filter(state.map, fn({p, {_,k}}) -> k == key && p != pid end) + state = %{state | map: Map.delete(state.map, pid)} + if key && Enum.empty?(others) do + case :ets.lookup(@ets, key) do + [{key, tls_port, period, until}] -> + until = DateTime.utc_now() |> DateTime.add(period) + entry = {key, tls_port, period, until} + :ets.insert(@ets, entry) + {:noreply, state, {:handle_continue, {:write, entry}}} + [] -> + {:noreply, state} + end + else + {:noreply, state} + end + end + + # Calculate expiration periods from time of shutdown. + def terminate(_, state) do + if state.dets do + fold = fn + ({key, _tls_port, period, true}, acc) -> + until = DateTime.utc_now() |> DateTime.add(period) + [{key, until} | acc] + (_, acc) -> acc + end + for {key, until} <- :ets.foldl(fold, [], @ets) do + :ets.update_element(@ets, key, {4, until}) + end + :ets.to_dets(@ets, state.dets) + :dets.close(state.dets) + end + :ok + end + + # Remove an entry from ETS, demonitor related PIDs, and remove from map. + defp remove(key, state) do + :ets.delete(@ets, key) + pids = Enum.filter(state.map, fn({p, {mon,k}}) -> k == key end) |> Enum.map(fn({p, {mon,_}}) -> {p,mon} end) + for {_,mon} <- pids, do: Process.demonitor(mon, [:flush]) + map = Enum.reduce(pids, state.map, fn({p,_}, map) -> Map.delete(map, p) end) + %{state | map: map} + end + + defp verify_validity(_, true) do + true + end + + defp verify_validity(_period, until) do + DateTime.utc_now() >= until + end + +end |