author     Jordan Bracco <href@random.sh>    2023-03-05 10:12:23 +0100
committer  Jordan Bracco <href@random.sh>    2023-03-05 10:12:23 +0100
commit     36cc5b221eb4f362081328a91ac762f3133bbe38 (patch)
tree       e29655e43541de632290eefb54147812d0d3475f
parent     connection: ignore snoonet channels on join (diff)
plugin: image
-rw-r--r--  lib/nola/plugins.ex  |   1
-rw-r--r--  lib/plugins/image.ex | 246
2 files changed, 247 insertions, 0 deletions
diff --git a/lib/nola/plugins.ex b/lib/nola/plugins.ex
index 7872cd6..ac94736 100644
--- a/lib/nola/plugins.ex
+++ b/lib/nola/plugins.ex
@@ -14,6 +14,7 @@ defmodule Nola.Plugins do
Nola.Plugins.Dice,
Nola.Plugins.Finance,
Nola.Plugins.Gpt,
+ Nola.Plugins.Image,
Nola.Plugins.KickRoulette,
Nola.Plugins.LastFm,
Nola.Plugins.Link,
diff --git a/lib/plugins/image.ex b/lib/plugins/image.ex
new file mode 100644
index 0000000..446cb49
--- /dev/null
+++ b/lib/plugins/image.ex
@@ -0,0 +1,246 @@
+defmodule Nola.Plugins.Image do
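+  # Image generation plugin: !d2 uses the OpenAI DALL-E 2 images API, !sd runs Stable Diffusion models on RunPod.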
+ require Logger
+ import Nola.Plugins.TempRefHelper
+
+ def irc_doc() do
+ """
+ # Image Generation
+
+  * **`!d2 [-n 1..10] [-g 256, 512, 1024] <prompt>`** generate image(s) using OpenAI DALL-E 2
+ * **`!sd [options] <prompt>`** generate image(s) using Stable Diffusion models (see below)
+
+ ## !sd
+
+  * `-m X` (sd2) Model (sd2: Stable Diffusion v2, sd1: Stable Diffusion v1.5, any: Anything v3, any4: Anything v4, oj: OpenJourney)
+  * `-w X, -h X` (512) width and height (allowed: 128, 256, 384, 448, 512, 576, 640, 704, 768)
+ * `-n 1..10` (1) number of images to generate
+ * `-s X` (null) Seed
+  * `-S 0..500` (30) denoising steps
+ * `-X X` (KLMS) scheduler (DDIM, K_EULER, DPMSolverMultistep, K_EULER_ANCESTRAL, PNDM, KLMS)
+ * `-g 1..20` (7.5) guidance scale
+ * `-P 0.0..1.0` (0.8) prompt strength
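+
+  Example: `!sd -m oj -n 2 -w 640 -h 512 an astronaut riding a horse`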
+ """
+ end
+
+ def start_link() do
+ GenServer.start_link(__MODULE__, [], name: __MODULE__)
+ end
+
+ defstruct [:temprefs]
+
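+  # Subscribe to the !d2 and !sd bang triggers on the Nola.PubSub registry; trigger events arrive as handle_info/2 messages.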
+ def init(_) do
+ regopts = [plugin: __MODULE__]
+ {:ok, _} = Registry.register(Nola.PubSub, "trigger:d2", regopts)
+ {:ok, _} = Registry.register(Nola.PubSub, "trigger:sd", regopts)
+ {:ok, %__MODULE__{temprefs: new_temp_refs()}}
+ end
+
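+  # !sd: extract -m/--model (default "sd2"), require a prompt, then hand off to process_sd/4.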
+ def handle_info({:irc, :trigger, "sd", msg = %Nola.Message{trigger: %Nola.Trigger{type: :bang, args: args}}}, state) do
+    state = case OptionParser.parse(args, aliases: [m: :model], strict: [model: :string]) do
+      {_, [], _} ->
+        msg.replyfun.("#{msg.sender.nick}: sd: missing prompt")
+        state
+      {opts, prompt, _} ->
+        process_sd(Keyword.get(opts, :model, "sd2"), Enum.join(prompt, " "), msg, state)
+    end
+    {:noreply, state}
+ end
+
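+  # !d2: parse -n (image count) and -g (square size), call the OpenAI image generation endpoint,
+  # decode the base64 results, upload them to S3 and reply with the public URLs.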
+ def handle_info({:irc, :trigger, "d2", msg = %Nola.Message{trigger: %Nola.Trigger{type: :bang, args: args}}}, state) do
+ opts = OptionParser.parse(args,
+ aliases: [n: :n, g: :geometry],
+ strict: [n: :integer, geometry: :integer]
+ )
+ case opts do
+ {_opts, [], _} ->
+ msg.replyfun.("#{msg.sender.nick}: d2: missing prompt")
+ {:noreply, state}
+ {opts, prompts, _} ->
+ prompt = Enum.join(prompts, " ")
+ geom = Keyword.get(opts, :geometry, 256)
+ request = %{
+ "prompt" => prompt,
+ "n" => Keyword.get(opts, :n, 1),
+ "size" => "#{geom}x#{geom}",
+ "response_format" => "b64_json",
+ "user" => msg.account.id,
+ }
+
+ id = FlakeId.get()
+
+ state = case OpenAi.post("/v1/images/generations", request) do
+ {:ok, %{"data" => data}} ->
+ urls = for {%{"b64_json" => b64}, idx} <- Enum.with_index(data) do
+ with {:ok, body} <- Base.decode64(b64),
+ <<smol_body::binary-size(20), _::binary>> = body,
+ {:ok, magic} <- GenMagic.Pool.perform(Nola.GenMagic, {:bytes, smol_body}),
+ bucket = Application.get_env(:nola, :s3, []) |> Keyword.get(:bucket),
+ s3path = "#{msg.account.id}/iD2#{id}#{idx}.png",
+ s3req = ExAws.S3.put_object(bucket, s3path, body, acl: :public_read, content_type: magic.mime_type),
+ {:ok, _} <- ExAws.request(s3req),
+ path = NolaWeb.Router.Helpers.url(NolaWeb.Endpoint) <> "/files/#{s3path}"
+ do
+ {:ok, path}
+ end
+ end
+
+ urls = for {:ok, path} <- urls, do: path
+ msg.replyfun.("#{msg.sender.nick}: #{Enum.join(urls, " ")}")
+ state
+ {:error, atom} when is_atom(atom) ->
+ Logger.error("dalle2: #{inspect atom}")
+ msg.replyfun.("#{msg.sender.nick}: dalle2: ☠️ #{to_string(atom)}")
+ state
+ error ->
+ Logger.error("dalle2: #{inspect error}")
+ msg.replyfun.("#{msg.sender.nick}: dalle2: ☠️ ")
+ state
+ end
+ {:noreply, state}
+ end
+ end
+
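+  # Parse the options shared by every model (-n/-w/-h), build the model-specific RunPod request via
+  # sd_model/4, submit the job to the RunPod API and spawn a process that polls for the result.
+  # Errors are reported back to the sender; the state is returned unchanged.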
+ defp process_sd(model, prompt, msg, state) do
+ {general_opts, _, _} = OptionParser.parse(msg.trigger.args,
+ aliases: [n: :number, w: :width, h: :height],
+ strict: [number: :integer, width: :integer, height: :integer]
+ )
+
+ general_opts = general_opts
+ |> Keyword.put_new(:number, 1)
+
+ case sd_model(model, prompt, general_opts, msg.trigger.args) do
+ {:ok, env} ->
+ base_url = "https://api.runpod.ai/v1/#{env.name}"
+ {headers, options} = runpod_headers(env, state)
+ result = with {:ok, json} <- Poison.encode(%{"input" => env.request}),
+ {:ok, %HTTPoison.Response{status_code: 200, body: body}} <- HTTPoison.post("#{base_url}/run", json, headers, options),
+ {:ok, %{"id" => id} = data} <- Poison.decode(body) do
+ Logger.debug("runpod: started job #{id}: #{inspect data}")
+ spawn(fn() -> runpod_result_loop("#{base_url}/status/#{id}", env, msg, state) end)
+ :ok
+ else
+ {:ok, %HTTPoison.Response{status_code: code}} -> {:error, Plug.Conn.Status.reason_atom(code)}
+ {:error, %HTTPoison.Error{reason: reason}} -> {:error, reason}
+ end
+
+ case result do
+ {:error, reason} ->
+ Logger.error("runpod: http error for #{base_url}/run: #{inspect reason}")
+ msg.replyfun.("#{msg.sender.nick}: sd: runpod failed: #{inspect reason}")
+ _ -> :ok
+ end
+ {:error, error} ->
+ msg.replyfun.("#{msg.sender.nick}: sd: #{error}")
+ end
+
+ state
+ end
+
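+  # Poll the RunPod status endpoint until the job is COMPLETED (or FAILED), then upload each
+  # returned image to S3 in parallel tasks and reply with the resulting URLs.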
+ defp runpod_result_loop(url, env, msg, state) do
+ Logger.debug("runpod_result_loop: new")
+ {headers, options} = runpod_headers(env, state)
+ with {:ok, %HTTPoison.Response{status_code: 200, body: body}} <- HTTPoison.get(url, headers ++ [{"content-type", "application/json"}], options),
+ {:ok, %{"status" => "COMPLETED"} = data} <- Poison.decode(body) do
+ id = FlakeId.get()
+ tasks = for {%{"image" => url, "seed" => seed}, idx} <- Enum.with_index(Map.get(data, "output", [])) do
+ Task.async(fn() ->
+          with {:ok, %HTTPoison.Response{status_code: 200, body: body}} <- HTTPoison.get(url, [], options),
+               bucket = Application.get_env(:nola, :s3, []) |> Keyword.get(:bucket),
+               s3path = "#{msg.account.id}/iR#{env.nick}#{id}#{idx}-#{seed}.png",
+               s3req = ExAws.S3.put_object(bucket, s3path, body, acl: :public_read, content_type: "image/png"),
+               {:ok, _} <- ExAws.request(s3req),
+               path = NolaWeb.Router.Helpers.url(NolaWeb.Endpoint) <> "/files/#{s3path}"
+          do
+            {:ok, path}
+          else
+            error ->
+              Logger.error("runpod_result: error while uploading #{url}: #{inspect error}")
+              {:error, error}
+          end
+        end)
+ end
+ |> Task.yield_many(5000)
+ |> Enum.map(fn {task, res} ->
+ res || Task.shutdown(task, :brutal_kill)
+ end)
+
+ results = for({:ok, {:ok, url}} <- tasks, do: url)
+
+ msg.replyfun.("#{msg.sender.nick}: #{Enum.join(results, " ")}")
+ else
+ {:ok, %{"status" => "FAILED"} = data} ->
+ Logger.error("runpod_result_loop: job FAILED: #{inspect data}")
+ msg.replyfun.("#{msg.sender.nick}: sd: job failed: #{Map.get(data, "error", "error")}")
+ {:ok, %{"status" => _} = data} ->
+ Logger.debug("runpod_result_loop: not completed: #{inspect data}")
+ :timer.sleep(:timer.seconds(1))
+ runpod_result_loop(url, env, msg, state)
+ {:ok, %HTTPoison.Response{status_code: 403}} ->
+ msg.replyfun.("#{msg.sender.nick}: sd: runpod failure: unauthorized")
+ error ->
+ Logger.warning("image: sd: runpod http error: #{inspect error}")
+ :timer.sleep(:timer.seconds(2))
+ runpod_result_loop(url, env, msg, state)
+ end
+ end
+
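+  # Authorization and user-agent headers from the :runpod app config, plus generous HTTPoison timeouts.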
+ defp runpod_headers(_env, _state) do
+ config = Application.get_env(:nola, :runpod, [])
+ headers = [{"user-agent", "nola.lol bot, href@random.sh"},
+ {"authorization", "Bearer " <> Keyword.get(config, :key, "unset-api-key")}]
+ options = [timeout: :timer.seconds(180), recv_timeout: :timer.seconds(180)]
+ {headers, options}
+ end
+
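+  # Per-model option parsing: map the short model name to its RunPod endpoint name, apply defaults
+  # and build the request payload (dimensions, seed, steps, guidance, scheduler, prompt strength,
+  # optional negative prompt).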
+ defp sd_model(name, _, general_opts, opts) when name in ~w(sd2 sd1 oj any any4) do
+ {opts, prompt, _} = OptionParser.parse(opts, [
+ aliases: [P: :strength, s: :seed, S: :steps, g: :guidance, X: :scheduler, q: :negative],
+ strict: [strength: :float, steps: :integer, guidance: :float, scheduler: :string, seed: :integer, negative: :keep]
+ ])
+ opts = general_opts ++ opts
+ prompt = Enum.join(prompt, " ")
+
+ negative = case Keyword.get_values(opts, :negative) do
+ [] -> nil
+ list -> Enum.join(list, " ")
+ end
+
+ full_name = case name do
+ "sd2" -> "stable-diffusion-v2"
+ "sd1" -> "stable-diffusion-v1"
+ "oj" -> "sd-openjourney"
+ "any" -> "sd-anything-v3"
+ "any4" -> "sd-anything-v4"
+ end
+
+ default_scheduler = case name do
+ "sd2" -> "KLMS"
+ _ -> "K-LMS"
+ end
+
+ request = %{
+ "prompt" => prompt,
+ "num_outputs" => general_opts[:number],
+ "width" => opts[:width] || 512,
+ "height" => opts[:height] || 512,
+ "prompt_strength" => opts[:strength] || 0.8,
+ "num_inference_steps" => opts[:steps] || 30,
+ "guidance_scale" => opts[:guidance] || 7.5,
+ "scheduler" => opts[:scheduler] || default_scheduler,
+ "seed" => opts[:seed] || :rand.uniform(100_000_00)
+ }
+
+ request = if negative do
+ Map.put(request, "negative_prompt", negative)
+ else
+ request
+ end
+
+ {:ok, %{name: full_name, nick: name, request: request}}
+ end
+
+ defp sd_model(name, _, _, _) do
+ {:error, "unsupported model: \"#{name}\""}
+ end
+
+end