From d4a7e3efcc4aae1a6a43600245c208e5d829c932 Mon Sep 17 00:00:00 2001 From: Jordan Bracco Date: Sun, 5 Mar 2023 10:12:30 +0100 Subject: plugin: gpt --- lib/plugins/gpt.ex | 112 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 104 insertions(+), 8 deletions(-) diff --git a/lib/plugins/gpt.ex b/lib/plugins/gpt.ex index f89bec1..8ee8c10 100644 --- a/lib/plugins/gpt.ex +++ b/lib/plugins/gpt.ex @@ -142,13 +142,19 @@ defmodule Nola.Plugins.Gpt do "parent_run_id" => Map.get(run, "_id"), "openai_params" => Map.get(run, "request") |> Map.delete("prompt")} - continue_prompt = if prompt_string = Map.get(original_prompt, "continue_prompt") do - full_text = get_in(run, ~w(request prompt)) <> "\n" <> Map.get(run, "response") - continue_prompt - |> Map.put("prompt", prompt_string) - |> Map.put("prompt_format", "liquid") - |> Map.put("prompt_liquid_variables", %{"previous" => full_text}) - else + continue_prompt = case original_prompt do + %{"continue_prompt" => prompt_string} when is_binary(prompt_string) -> + full_text = get_in(run, ~w(request prompt)) <> "\n" <> Map.get(run, "response") + continue_prompt + |> Map.put("prompt", prompt_string) + |> Map.put("prompt_format", "liquid") + |> Map.put("prompt_liquid_variables", %{"previous" => full_text}) + %{"messages" => _} -> + continue_prompt + |> Map.put("prompt", "{{content}}") + |> Map.put("prompt_format", "liquid") + |> Map.put("messages", Map.get(run, "messages")) + _ -> prompt_content_tag = if content != "", do: " {{content}}", else: "" string = get_in(run, ~w(request prompt)) <> "\n" <> Map.get(run, "response") <> prompt_content_tag continue_prompt @@ -163,6 +169,96 @@ defmodule Nola.Plugins.Gpt do end end + # Chat prompt + # "prompt" is the template for the initial user message + # "messages" is original messages to be put before the initial user one + defp prompt(msg, prompt = %{"type" => "chat", "prompt" => prompt_template, "messages" => messages}, content, state) do + Logger.debug("gpt_plugin:prompt/4 (chat) #{inspect prompt}") + prompt_text = case Map.get(prompt, "prompt_format", "liquid") do + "liquid" -> Tmpl.render(prompt_template, msg, Map.merge(Map.get(prompt, "prompt_liquid_variables", %{}), %{"content" => content})) + "norender" -> prompt_template + end + + messages = Enum.map(messages, fn(%{"role" => role, "content" => text}) -> + text = case Map.get(prompt, "prompt_format", "liquid") do + "liquid" -> Tmpl.render(text, msg, Map.get(prompt, "prompt_liquid_variables", %{})) + "norender" -> text + end + %{"role" => role, "content" => text} + end) ++ [%{"role" => "user", "content" => prompt_text}] + + args = Map.get(prompt, "openai_params") + |> Map.put_new("model", "gpt-3.5-turbo") + |> Map.put("messages", messages) + |> Map.put("user", msg.account.id) + + {moderate?, moderation} = moderation(content, msg.account.id) + if moderate?, do: msg.replyfun.("⚠️ offensive input: #{Enum.join(moderation, ", ")}") + + Logger.debug("GPT: request #{inspect args}") + case OpenAi.post("/v1/chat/completions", args) do + {:ok, %{"choices" => [%{"message" => %{"content" => text}, "finish_reason" => finish_reason} | _], "usage" => usage, "id" => gpt_id, "created" => created}} -> + text = String.trim(text) + {o_moderate?, o_moderation} = moderation(text, msg.account.id) + if o_moderate?, do: msg.replyfun.("🚨 offensive output: #{Enum.join(o_moderation, ", ")}") + msg.replyfun.(text) + doc = %{"id" => FlakeId.get(), + "prompt_id" => Map.get(prompt, "_id"), + "prompt_rev" => Map.get(prompt, "_rev"), + "network" => msg.network, + "channel" => msg.channel, + "nick" => msg.sender.nick, + "account_id" => (if msg.account, do: msg.account.id), + "request" => args, + "messages" => messages ++ [%{"role" => "assistant", "content" => text}], + "message_at" => msg.at, + "reply_at" => DateTime.utc_now(), + "gpt_id" => gpt_id, + "gpt_at" => created, + "gpt_usage" => usage, + "type" => "chat", + "parent_run_id" => Map.get(prompt, "parent_run_id"), + "moderation" => %{"input" => %{flagged: moderate?, categories: moderation}, + "output" => %{flagged: o_moderate?, categories: o_moderation} + } + } + Logger.debug("Saving result to couch: #{inspect doc}") + {id, ref, temprefs} = case Couch.post(@couch_run_db, doc) do + {:ok, id, _rev} -> + {ref, temprefs} = put_temp_ref(id, state.temprefs) + {id, ref, temprefs} + error -> + Logger.error("Failed to save to Couch: #{inspect error}") + {nil, nil, state.temprefs} + end + stop = cond do + finish_reason == "stop" -> "" + finish_reason == "length" -> " — truncated" + true -> " — #{finish_reason}" + end + ref_and_prefix = if Map.get(usage, "completion_tokens", 0) == 0 do + "GPT had nothing else to say :( ↪ #{ref || "✗"}" + else + " ↪ #{ref || "✗"}" + end + msg.replyfun.(ref_and_prefix <> + stop <> + " — #{Map.get(usage, "total_tokens", 0)}" <> + " (#{Map.get(usage, "prompt_tokens", 0)}/#{Map.get(usage, "completion_tokens", 0)}) tokens" <> + " — #{id || "save failed"}") + %__MODULE__{state | temprefs: temprefs} + {:error, atom} when is_atom(atom) -> + Logger.error("gpt error: #{inspect atom}") + msg.replyfun.("gpt: ☠️ #{to_string(atom)}") + state + error -> + Logger.error("gpt error: #{inspect error}") + msg.replyfun.("gpt: ☠️ ") + state + end + end + + defp prompt(msg, prompt = %{"type" => "completions", "prompt" => prompt_template}, content, state) do Logger.debug("gpt:prompt/4 #{inspect prompt}") prompt_text = case Map.get(prompt, "prompt_format", "liquid") do @@ -231,7 +327,7 @@ defmodule Nola.Plugins.Gpt do %__MODULE__{state | temprefs: temprefs} {:error, atom} when is_atom(atom) -> Logger.error("gpt error: #{inspect atom}") - msg.replyfun.("gpt: ☠️ #{to_string(atom)}") + msg.replyfun.("gpt: ☠️ #{to_string(atom)}") state error -> Logger.error("gpt error: #{inspect error}") -- cgit v1.2.3