summaryrefslogblamecommitdiff
path: root/lib/plugins/gpt.ex
blob: f89bec1a641b97841f3bf30e6b22f8fcc4f771bf (plain) (tree)
1
2
3
4
5
6
7
8
                             
                
                                   




                  












                                                                                  



                                       
                                        





                                                          









                                

                                  
                                                                             
                                                 

     
                                                                                                                                      


                                                                                 
                                                              
                         
              
                                                               
                                          
                         
       

     
                                                                                                                         
                                             



                                                                                                 
              
                                                               




                                          
                                                                                                                                         














                                                                                                                        
                                                                                                                                            










                                                            
                                                                                                                                           
                                                              
                         
                                                                                                              
        
                                                                   




                             
                                                                                                                                
                         
                                                                                                                  
        
                                                                       






                                                        


                     



                                                   
                                                                              



                             



                                                        


                                                                                














                                                                                                          
        

                                                                                                              
       

     
                                                                                                       
                                                  


                                                                                                                                        


                                           
                                     
                                      
 

                                                                                             
 

                                                




                                                                                                                                                 

                                                      



























                                                                                               
                                       


                                                       





                                                                          








                                                                                                         


                                                   
             



















                                                                                      
defmodule Nola.Plugins.Gpt do
  @moduledoc """
  IRC plugin that runs OpenAI completion prompts.

  Prompt definitions are documents of the `bot-plugin-openai-prompts` CouchDB
  database. Each completed run (request, response, token usage and moderation
  results) is saved, best effort, to `bot-plugin-gpt-history`. A saved run also
  gets a short temporary reference (via `Nola.Plugins.TempRefHelper`), so it can
  be shown or continued from IRC.

  The process registers itself in `Nola.PubSub` under `"trigger:gpt"` and
  receives commands as `{:irc, :trigger, "gpt", message}` messages in
  `handle_info/2`.
  """
  require Logger
  import Nola.Plugins.TempRefHelper

  @couch_db "bot-plugin-openai-prompts"
  @couch_run_db "bot-plugin-gpt-history"
  @trigger "gpt"

  # `temprefs`: short reference -> run id mapping managed by TempRefHelper.
  defstruct [:temprefs]

  @doc "Returns the IRC help text (Markdown) for the `gpt` trigger."
  def irc_doc() do
    """
    # OpenAI GPT

    Uses OpenAI's GPT-3 API to bring natural language prompts to your IRC channel.

    _prompts_ are pre-defined prompts and parameters defined in the bot's CouchDB.

    _Runs_ (results of the inference of a _prompt_) are also stored in CouchDB and
    may be resumed.

    * **!gpt** list GPT prompts
    * **!gpt `[prompt]` `<prompt or args>`** run a prompt
    * **+gpt `[short ref|run id]` `<prompt or args>`** continue a prompt
    * **?gpt offensive `<content>`** is content offensive ?
    * **?gpt show `[short ref|run id]`** run information and web link
    * **?gpt `[prompt]`** prompt information and web link
    """
  end

  @doc "Starts the plugin process, registered locally under the module name."
  def start_link() do
    GenServer.start_link(__MODULE__, [], name: __MODULE__)
  end

  @doc "Fetches a saved run document from the runs database by id."
  def get_result(id), do: Couch.get(@couch_run_db, id)

  @doc "Fetches a prompt document from the prompts database by id."
  def get_prompt(id), do: Couch.get(@couch_db, id)

  # GenServer callback: subscribe to the `gpt` trigger and start with an empty
  # temporary-reference table.
  def init(_) do
    {:ok, _} = Registry.register(Nola.PubSub, "trigger:#{@trigger}", plugin: __MODULE__)
    {:ok, %__MODULE__{temprefs: new_temp_refs()}}
  end

  # `!gpt <prompt> [words…]`: run a stored prompt, the remaining words being the content.
  def handle_info({:irc, :trigger, @trigger, m = %Nola.Message{trigger: %Nola.Trigger{type: :bang, args: [prompt_id | args]}}}, state) do
    case Couch.get(@couch_db, prompt_id) do
      {:ok, prompt} ->
        {:noreply, prompt(m, prompt, Enum.join(args, " "), state)}

      {:error, :not_found} ->
        m.replyfun.("gpt: prompt '#{prompt_id}' does not exist")
        {:noreply, state}

      error ->
        Logger.info("gpt: prompt load error: #{inspect error}")
        m.replyfun.("gpt: database error")
        {:noreply, state}
    end
  end

  # `!gpt`: list the ids of all stored prompts.
  def handle_info({:irc, :trigger, @trigger, m = %Nola.Message{trigger: %Nola.Trigger{type: :bang, args: []}}}, state) do
    case Couch.get(@couch_db, "_all_docs") do
      {:ok, %{"rows" => []}} ->
        m.replyfun.("gpt: no prompts available")

      {:ok, %{"rows" => rows}} ->
        m.replyfun.("gpt: prompts: #{Enum.map_join(rows, ", ", &Map.get(&1, "id"))}")

      error ->
        Logger.info("gpt: prompt load error: #{inspect error}")
        m.replyfun.("gpt: database error")
    end

    {:noreply, state}
  end

  # `+gpt <short ref|run id> [words…]`: continue a previous run.
  def handle_info({:irc, :trigger, @trigger, m = %Nola.Message{trigger: %Nola.Trigger{type: :plus, args: [ref_or_id | args]}}}, state) do
    # Unknown short refs fall back to being used as a full run id.
    id = lookup_temp_ref(ref_or_id, state.temprefs, ref_or_id)

    case Couch.get(@couch_run_db, id) do
      {:ok, run} ->
        Logger.debug("+gpt run: #{inspect run}")
        {:noreply, continue_prompt(m, run, Enum.join(args, " "), state)}

      {:error, :not_found} ->
        m.replyfun.("gpt: ref or id not found or expired: #{inspect ref_or_id} (if using short ref, try using full id)")
        {:noreply, state}

      error ->
        Logger.info("+gpt: run load error: #{inspect error}")
        m.replyfun.("gpt: database error")
        {:noreply, state}
    end
  end

  # `?gpt offensive <text>`: run the text through the moderation endpoint.
  def handle_info({:irc, :trigger, @trigger, m = %Nola.Message{trigger: %Nola.Trigger{type: :query, args: ["offensive" | words]}}}, state) do
    reply =
      case moderation(Enum.join(words, " "), account_id(m)) do
        {true, categories} -> "⚠️ #{Enum.join(categories, ", ")}"
        {false, true} -> "👍"
        {false, false} -> "☠️ error"
      end

    m.replyfun.(reply)
    {:noreply, state}
  end

  # `?gpt show <short ref|run id>`: link to the web page of a run.
  def handle_info({:irc, :trigger, @trigger, m = %Nola.Message{trigger: %Nola.Trigger{type: :query, args: ["show", ref_or_id]}}}, state) do
    id = lookup_temp_ref(ref_or_id, state.temprefs, ref_or_id)
    m.replyfun.("→ #{web_url(m, :result, id)}")
    {:noreply, state}
  end

  # `?gpt <prompt>`: link to the web page of a prompt.
  def handle_info({:irc, :trigger, @trigger, m = %Nola.Message{trigger: %Nola.Trigger{type: :query, args: [prompt_id]}}}, state) do
    m.replyfun.("→ #{web_url(m, :prompt, prompt_id)}")
    {:noreply, state}
  end

  def handle_info(info, state) do
    Logger.debug("gpt: unhandled info: #{inspect info}")
    {:noreply, state}
  end

  # Builds a web URL for `action`, scoped to the network/channel when the
  # message came from a channel.
  defp web_url(m, action, id) do
    if m.channel do
      NolaWeb.Router.Helpers.gpt_url(NolaWeb.Endpoint, action, m.network, NolaWeb.format_chan(m.channel), id)
    else
      NolaWeb.Router.Helpers.gpt_url(NolaWeb.Endpoint, action, id)
    end
  end

  # Builds a new prompt from a saved run (same OpenAI parameters, linked through
  # "parent_run_id") and runs it with `content` appended. Returns the new state.
  defp continue_prompt(msg, run, content, state) do
    prompt_id = Map.get(run, "prompt_id")
    prompt_rev = Map.get(run, "prompt_rev")

    # The exact revision used by the run is loaded, so later prompt edits
    # do not alter continuations.
    case Couch.get(@couch_db, prompt_id, rev: prompt_rev) do
      {:ok, original_prompt} when not is_nil(original_prompt) ->
        previous_text = get_in(run, ~w(request prompt)) <> "\n" <> Map.get(run, "response")

        base = %{
          "_id" => prompt_id,
          "_rev" => prompt_rev,
          "type" => Map.get(original_prompt, "type"),
          "parent_run_id" => Map.get(run, "_id"),
          "openai_params" => run |> Map.get("request") |> Map.delete("prompt")
        }

        continuation = Map.merge(base, continuation_fields(original_prompt, previous_text, content))
        prompt(msg, continuation, content, state)

      _ ->
        msg.replyfun.("gpt: cannot continue this prompt: original prompt not found #{prompt_id}@v#{prompt_rev}")
        state
    end
  end

  # Prompt text/format fields for a continuation. When the original prompt
  # defines a "continue_prompt" Liquid template, the previous exchange is passed
  # to it as the `previous` variable. Otherwise the previous exchange is sent
  # verbatim ("norender"), followed by the new content. It is never re-rendered,
  # because it contains user input and model output that may include Liquid syntax.
  defp continuation_fields(original_prompt, previous_text, content) do
    if template = Map.get(original_prompt, "continue_prompt") do
      %{
        "prompt" => template,
        "prompt_format" => "liquid",
        "prompt_liquid_variables" => %{"previous" => previous_text}
      }
    else
      suffix = if content == "", do: "", else: " " <> content
      %{"prompt" => previous_text <> suffix, "prompt_format" => "norender"}
    end
  end

  # Runs a "completions" prompt and returns the new state. Any other prompt
  # shape gets an error reply instead of crashing the plugin process.
  defp prompt(msg, prompt = %{"type" => "completions", "prompt" => prompt_template}, content, state) do
    Logger.debug("gpt:prompt/4 #{inspect prompt}")

    case render_prompt(prompt, prompt_template, msg, content) do
      {:ok, prompt_text} ->
        run_completion(msg, prompt, prompt_text, content, state)

      {:error, {:unknown_prompt_format, format}} ->
        msg.replyfun.("gpt: ☠️ unknown prompt format: #{inspect format}")
        state
    end
  end

  defp prompt(msg, prompt, _content, state) do
    Logger.error("gpt: unsupported prompt: #{inspect prompt}")
    type = if is_map(prompt), do: Map.get(prompt, "type")
    msg.replyfun.("gpt: ☠️ unsupported prompt type: #{inspect type}")
    state
  end

  # Produces the final prompt text according to "prompt_format" ("liquid" when
  # absent). For Liquid, `content` is exposed as the `content` variable next to
  # the prompt's own "prompt_liquid_variables".
  defp render_prompt(prompt, template, msg, content) do
    case Map.get(prompt, "prompt_format", "liquid") do
      "liquid" ->
        variables =
          prompt
          |> Map.get("prompt_liquid_variables", %{})
          |> Map.put("content", content)

        {:ok, Tmpl.render(template, msg, variables)}

      "norender" ->
        {:ok, template}

      other ->
        {:error, {:unknown_prompt_format, other}}
    end
  end

  # Moderates the input, calls the completions endpoint, replies with the
  # (moderated) output, saves the run and replies with a summary line.
  defp run_completion(msg, prompt, prompt_text, content, state) do
    user_id = account_id(msg)

    args =
      prompt
      |> Map.get("openai_params", %{})
      |> Map.put("prompt", prompt_text)
      |> maybe_put("user", user_id)

    # Only the user-supplied content is moderated, not the whole prompt.
    {moderate?, moderation} = moderation(content, user_id)
    if moderate?, do: msg.replyfun.("⚠️ offensive input: #{Enum.join(moderation, ", ")}")

    Logger.debug("GPT: request #{inspect args}")

    case OpenAi.post("/v1/completions", args) do
      {:ok, %{"choices" => [%{"text" => text, "finish_reason" => finish_reason} | _], "usage" => usage, "id" => gpt_id, "created" => created}} ->
        text = String.trim(text)
        {o_moderate?, o_moderation} = moderation(text, user_id)
        if o_moderate?, do: msg.replyfun.("🚨 offensive output: #{Enum.join(o_moderation, ", ")}")
        msg.replyfun.(text)

        # NOTE(review): CouchDB document ids live in "_id"; this FlakeId is stored
        # in a plain "id" field, and the id returned by Couch.post is used
        # instead. Confirm whether "_id" was intended.
        doc = %{
          "id" => FlakeId.get(),
          "prompt_id" => Map.get(prompt, "_id"),
          "prompt_rev" => Map.get(prompt, "_rev"),
          "network" => msg.network,
          "channel" => msg.channel,
          "nick" => msg.sender.nick,
          "account_id" => user_id,
          "request" => args,
          "response" => text,
          "message_at" => msg.at,
          "reply_at" => DateTime.utc_now(),
          "gpt_id" => gpt_id,
          "gpt_at" => created,
          "gpt_usage" => usage,
          "type" => "completions",
          "parent_run_id" => Map.get(prompt, "parent_run_id"),
          "moderation" => %{
            "input" => %{flagged: moderate?, categories: moderation},
            "output" => %{flagged: o_moderate?, categories: o_moderation}
          }
        }

        {id, ref, temprefs} = save_run(doc, state.temprefs)
        msg.replyfun.(summary_line(ref, id, finish_reason, usage))
        %__MODULE__{state | temprefs: temprefs}

      {:error, atom} when is_atom(atom) ->
        Logger.error("gpt error: #{inspect atom}")
        msg.replyfun.("gpt: ☠️  #{to_string(atom)}")
        state

      error ->
        Logger.error("gpt error: #{inspect error}")
        msg.replyfun.("gpt: ☠️ ")
        state
    end
  end

  # Saves a run document. On success, it also allocates a short temporary
  # reference for it. Returns `{id, ref, temprefs}`; on failure, id and ref are nil
  # and the temprefs are unchanged.
  defp save_run(doc, temprefs) do
    Logger.debug("Saving result to couch: #{inspect doc}")

    case Couch.post(@couch_run_db, doc) do
      {:ok, id, _rev} ->
        {ref, temprefs} = put_temp_ref(id, temprefs)
        {id, ref, temprefs}

      error ->
        Logger.error("Failed to save to Couch: #{inspect error}")
        {nil, nil, temprefs}
    end
  end

  # Summary reply: short ref, finish reason, token usage and saved run id.
  defp summary_line(ref, id, finish_reason, usage) do
    completion_tokens = Map.get(usage, "completion_tokens", 0)

    prefix =
      if completion_tokens == 0,
        do: "GPT had nothing else to say :( ↪ #{ref || "✗"}",
        else: "   ↪ #{ref || "✗"}"

    stop =
      case finish_reason do
        "stop" -> ""
        "length" -> " — truncated"
        other -> " — #{other}"
      end

    prefix <>
      stop <>
      " — #{Map.get(usage, "total_tokens", 0)}" <>
      " (#{Map.get(usage, "prompt_tokens", 0)}/#{completion_tokens}) tokens" <>
      " — #{id || "save failed"}"
  end

  # Account id of the message sender, or nil when the message has no account.
  defp account_id(%{account: %{id: id}}), do: id
  defp account_id(_), do: nil

  # Adds `key` only when there is a value, so no `null` is sent to the API.
  defp maybe_put(map, _key, nil), do: map
  defp maybe_put(map, key, value), do: Map.put(map, key, value)

  # Calls the OpenAI moderation endpoint. Returns:
  #   * `{true, categories}` - flagged, with the names of the flagged categories;
  #   * `{false, true}`      - request succeeded, content not flagged;
  #   * `{false, false}`     - the moderation request failed.
  defp moderation(content, user_id) do
    params = maybe_put(%{"input" => content}, "user", user_id)

    case OpenAi.post("/v1/moderations", params) do
      {:ok, %{"results" => [%{"flagged" => true, "categories" => categories} | _]}} ->
        {true, for({category, flagged?} <- categories, flagged?, do: category)}

      {:ok, moderation} ->
        Logger.debug("gpt: moderation: not flagged, #{inspect moderation}")
        {false, true}

      error ->
        Logger.error("gpt: moderation error: #{inspect error}")
        {false, false}
    end
  end
end