author     Jordan Bracco <href@random.sh>  2023-03-05 10:12:30 +0100
committer  Jordan Bracco <href@random.sh>  2023-03-05 10:12:30 +0100
commit     d4a7e3efcc4aae1a6a43600245c208e5d829c932 (patch)
tree       308d9de72af6795ade7a89707fd87d42f1e807ab
parent     plugin: image (diff)
plugin: gpt
-rw-r--r--  lib/plugins/gpt.ex  112
1 file changed, 104 insertions, 8 deletions
diff --git a/lib/plugins/gpt.ex b/lib/plugins/gpt.ex
index f89bec1..8ee8c10 100644
--- a/lib/plugins/gpt.ex
+++ b/lib/plugins/gpt.ex
@@ -142,13 +142,19 @@ defmodule Nola.Plugins.Gpt do
"parent_run_id" => Map.get(run, "_id"),
"openai_params" => Map.get(run, "request") |> Map.delete("prompt")}
- continue_prompt = if prompt_string = Map.get(original_prompt, "continue_prompt") do
- full_text = get_in(run, ~w(request prompt)) <> "\n" <> Map.get(run, "response")
- continue_prompt
- |> Map.put("prompt", prompt_string)
- |> Map.put("prompt_format", "liquid")
- |> Map.put("prompt_liquid_variables", %{"previous" => full_text})
- else
+ continue_prompt = case original_prompt do
+ %{"continue_prompt" => prompt_string} when is_binary(prompt_string) ->
+ full_text = get_in(run, ~w(request prompt)) <> "\n" <> Map.get(run, "response")
+ continue_prompt
+ |> Map.put("prompt", prompt_string)
+ |> Map.put("prompt_format", "liquid")
+ |> Map.put("prompt_liquid_variables", %{"previous" => full_text})
+ %{"messages" => _} ->
+ continue_prompt
+ |> Map.put("prompt", "{{content}}")
+ |> Map.put("prompt_format", "liquid")
+ |> Map.put("messages", Map.get(run, "messages"))
+ _ ->
prompt_content_tag = if content != "", do: " {{content}}", else: ""
string = get_in(run, ~w(request prompt)) <> "\n" <> Map.get(run, "response") <> prompt_content_tag
continue_prompt
@@ -163,6 +169,96 @@ defmodule Nola.Plugins.Gpt do
end
end
+ # Chat prompt
+ # "prompt" is the template for the initial user message
+ # "messages" is original messages to be put before the initial user one
+ defp prompt(msg, prompt = %{"type" => "chat", "prompt" => prompt_template, "messages" => messages}, content, state) do
+ Logger.debug("gpt_plugin:prompt/4 (chat) #{inspect prompt}")
+ prompt_text = case Map.get(prompt, "prompt_format", "liquid") do
+ "liquid" -> Tmpl.render(prompt_template, msg, Map.merge(Map.get(prompt, "prompt_liquid_variables", %{}), %{"content" => content}))
+ "norender" -> prompt_template
+ end
+
+ messages = Enum.map(messages, fn(%{"role" => role, "content" => text}) ->
+ text = case Map.get(prompt, "prompt_format", "liquid") do
+ "liquid" -> Tmpl.render(text, msg, Map.get(prompt, "prompt_liquid_variables", %{}))
+ "norender" -> text
+ end
+ %{"role" => role, "content" => text}
+ end) ++ [%{"role" => "user", "content" => prompt_text}]
+
+ args = Map.get(prompt, "openai_params")
+ |> Map.put_new("model", "gpt-3.5-turbo")
+ |> Map.put("messages", messages)
+ |> Map.put("user", msg.account.id)
+
+ {moderate?, moderation} = moderation(content, msg.account.id)
+ if moderate?, do: msg.replyfun.("⚠️ offensive input: #{Enum.join(moderation, ", ")}")
+
+ Logger.debug("GPT: request #{inspect args}")
+ case OpenAi.post("/v1/chat/completions", args) do
+ {:ok, %{"choices" => [%{"message" => %{"content" => text}, "finish_reason" => finish_reason} | _], "usage" => usage, "id" => gpt_id, "created" => created}} ->
+ text = String.trim(text)
+ {o_moderate?, o_moderation} = moderation(text, msg.account.id)
+ if o_moderate?, do: msg.replyfun.("🚨 offensive output: #{Enum.join(o_moderation, ", ")}")
+ msg.replyfun.(text)
+ doc = %{"id" => FlakeId.get(),
+ "prompt_id" => Map.get(prompt, "_id"),
+ "prompt_rev" => Map.get(prompt, "_rev"),
+ "network" => msg.network,
+ "channel" => msg.channel,
+ "nick" => msg.sender.nick,
+ "account_id" => (if msg.account, do: msg.account.id),
+ "request" => args,
+ "messages" => messages ++ [%{"role" => "assistant", "content" => text}],
+ "message_at" => msg.at,
+ "reply_at" => DateTime.utc_now(),
+ "gpt_id" => gpt_id,
+ "gpt_at" => created,
+ "gpt_usage" => usage,
+ "type" => "chat",
+ "parent_run_id" => Map.get(prompt, "parent_run_id"),
+ "moderation" => %{"input" => %{flagged: moderate?, categories: moderation},
+ "output" => %{flagged: o_moderate?, categories: o_moderation}
+ }
+ }
+ Logger.debug("Saving result to couch: #{inspect doc}")
+ {id, ref, temprefs} = case Couch.post(@couch_run_db, doc) do
+ {:ok, id, _rev} ->
+ {ref, temprefs} = put_temp_ref(id, state.temprefs)
+ {id, ref, temprefs}
+ error ->
+ Logger.error("Failed to save to Couch: #{inspect error}")
+ {nil, nil, state.temprefs}
+ end
+ stop = cond do
+ finish_reason == "stop" -> ""
+ finish_reason == "length" -> " — truncated"
+ true -> " — #{finish_reason}"
+ end
+ ref_and_prefix = if Map.get(usage, "completion_tokens", 0) == 0 do
+ "GPT had nothing else to say :( ↪ #{ref || "✗"}"
+ else
+ " ↪ #{ref || "✗"}"
+ end
+ msg.replyfun.(ref_and_prefix <>
+ stop <>
+ " — #{Map.get(usage, "total_tokens", 0)}" <>
+ " (#{Map.get(usage, "prompt_tokens", 0)}/#{Map.get(usage, "completion_tokens", 0)}) tokens" <>
+ " — #{id || "save failed"}")
+ %__MODULE__{state | temprefs: temprefs}
+ {:error, atom} when is_atom(atom) ->
+ Logger.error("gpt error: #{inspect atom}")
+ msg.replyfun.("gpt: ☠️ #{to_string(atom)}")
+ state
+ error ->
+ Logger.error("gpt error: #{inspect error}")
+ msg.replyfun.("gpt: ☠️ ")
+ state
+ end
+ end
+
+
defp prompt(msg, prompt = %{"type" => "completions", "prompt" => prompt_template}, content, state) do
Logger.debug("gpt:prompt/4 #{inspect prompt}")
prompt_text = case Map.get(prompt, "prompt_format", "liquid") do
@@ -231,7 +327,7 @@ defmodule Nola.Plugins.Gpt do
%__MODULE__{state | temprefs: temprefs}
{:error, atom} when is_atom(atom) ->
Logger.error("gpt error: #{inspect atom}")
- msg.replyfun.("gpt: ☠️ #{to_string(atom)}")
+ msg.replyfun.("gpt: ☠️ #{to_string(atom)}")
state
error ->
Logger.error("gpt error: #{inspect error}")