From 36cc5b221eb4f362081328a91ac762f3133bbe38 Mon Sep 17 00:00:00 2001 From: Jordan Bracco Date: Sun, 5 Mar 2023 10:12:23 +0100 Subject: plugin: image --- lib/nola/plugins.ex | 1 + lib/plugins/image.ex | 246 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 247 insertions(+) create mode 100644 lib/plugins/image.ex diff --git a/lib/nola/plugins.ex b/lib/nola/plugins.ex index 7872cd6..ac94736 100644 --- a/lib/nola/plugins.ex +++ b/lib/nola/plugins.ex @@ -14,6 +14,7 @@ defmodule Nola.Plugins do Nola.Plugins.Dice, Nola.Plugins.Finance, Nola.Plugins.Gpt, + Nola.Plugins.Image, Nola.Plugins.KickRoulette, Nola.Plugins.LastFm, Nola.Plugins.Link, diff --git a/lib/plugins/image.ex b/lib/plugins/image.ex new file mode 100644 index 0000000..446cb49 --- /dev/null +++ b/lib/plugins/image.ex @@ -0,0 +1,246 @@ +defmodule Nola.Plugins.Image do + require Logger + import Nola.Plugins.TempRefHelper + + def irc_doc() do + """ + # Image Generation + + * **`!d2 [-n 1..10] [-g 256, 512, 1024] `** generate image(s) using OpenAI Dall-E 2 + * **`!sd [options] `** generate image(s) using Stable Diffusion models (see below) + + ## !sd + + * `-m X` (sd2) Model (sd2: Stable Diffusion v2, sd1: Stable Diffusion v1.5, any3: Anything v3, any4: Anything v4, oj: OpenJourney) + * `-w X, -h X` (512) width and height. (128, 256, 384, 448, 512, 576, 640, 704, 768) + * `-n 1..10` (1) number of images to generate + * `-s X` (null) Seed + * `-S 0..500` (50) denoising steps + * `-X X` (KLMS) scheduler (DDIM, K_EULER, DPMSolverMultistep, K_EULER_ANCESTRAL, PNDM, KLMS) + * `-g 1..20` (7.5) guidance scale + * `-P 0.0..1.0` (0.8) prompt strength + """ + end + + def start_link() do + GenServer.start_link(__MODULE__, [], name: __MODULE__) + end + + defstruct [:temprefs] + + def init(_) do + regopts = [plugin: __MODULE__] + {:ok, _} = Registry.register(Nola.PubSub, "trigger:d2", regopts) + {:ok, _} = Registry.register(Nola.PubSub, "trigger:sd", regopts) + {:ok, %__MODULE__{temprefs: new_temp_refs()}} + end + + def handle_info({:irc, :trigger, "sd", msg = %Nola.Message{trigger: %Nola.Trigger{type: :bang, args: args}}}, state) do + {:noreply, case OptionParser.parse(args, aliases: [m: :model], strict: [model: :string]) do + {_, [], _} -> + msg.replyfun.("#{msg.sender.nick}: sd: missing prompt") + state + {opts, prompt, _} -> + process_sd(Keyword.get(opts, :model, "sd2"), Enum.join(prompt, " "), msg, state) + end} + end + + def handle_info({:irc, :trigger, "d2", msg = %Nola.Message{trigger: %Nola.Trigger{type: :bang, args: args}}}, state) do + opts = OptionParser.parse(args, + aliases: [n: :n, g: :geometry], + strict: [n: :integer, geometry: :integer] + ) + case opts do + {_opts, [], _} -> + msg.replyfun.("#{msg.sender.nick}: d2: missing prompt") + {:noreply, state} + {opts, prompts, _} -> + prompt = Enum.join(prompts, " ") + geom = Keyword.get(opts, :geometry, 256) + request = %{ + "prompt" => prompt, + "n" => Keyword.get(opts, :n, 1), + "size" => "#{geom}x#{geom}", + "response_format" => "b64_json", + "user" => msg.account.id, + } + + id = FlakeId.get() + + state = case OpenAi.post("/v1/images/generations", request) do + {:ok, %{"data" => data}} -> + urls = for {%{"b64_json" => b64}, idx} <- Enum.with_index(data) do + with {:ok, body} <- Base.decode64(b64), + <> = body, + {:ok, magic} <- GenMagic.Pool.perform(Nola.GenMagic, {:bytes, smol_body}), + bucket = Application.get_env(:nola, :s3, []) |> Keyword.get(:bucket), + s3path = "#{msg.account.id}/iD2#{id}#{idx}.png", + s3req = ExAws.S3.put_object(bucket, s3path, body, acl: :public_read, content_type: magic.mime_type), + {:ok, _} <- ExAws.request(s3req), + path = NolaWeb.Router.Helpers.url(NolaWeb.Endpoint) <> "/files/#{s3path}" + do + {:ok, path} + end + end + + urls = for {:ok, path} <- urls, do: path + msg.replyfun.("#{msg.sender.nick}: #{Enum.join(urls, " ")}") + state + {:error, atom} when is_atom(atom) -> + Logger.error("dalle2: #{inspect atom}") + msg.replyfun.("#{msg.sender.nick}: dalle2: ☠️ #{to_string(atom)}") + state + error -> + Logger.error("dalle2: #{inspect error}") + msg.replyfun.("#{msg.sender.nick}: dalle2: ☠️ ") + state + end + {:noreply, state} + end + end + + defp process_sd(model, prompt, msg, state) do + {general_opts, _, _} = OptionParser.parse(msg.trigger.args, + aliases: [n: :number, w: :width, h: :height], + strict: [number: :integer, width: :integer, height: :integer] + ) + + general_opts = general_opts + |> Keyword.put_new(:number, 1) + + case sd_model(model, prompt, general_opts, msg.trigger.args) do + {:ok, env} -> + base_url = "https://api.runpod.ai/v1/#{env.name}" + {headers, options} = runpod_headers(env, state) + result = with {:ok, json} <- Poison.encode(%{"input" => env.request}), + {:ok, %HTTPoison.Response{status_code: 200, body: body}} <- HTTPoison.post("#{base_url}/run", json, headers, options), + {:ok, %{"id" => id} = data} <- Poison.decode(body) do + Logger.debug("runpod: started job #{id}: #{inspect data}") + spawn(fn() -> runpod_result_loop("#{base_url}/status/#{id}", env, msg, state) end) + :ok + else + {:ok, %HTTPoison.Response{status_code: code}} -> {:error, Plug.Conn.Status.reason_atom(code)} + {:error, %HTTPoison.Error{reason: reason}} -> {:error, reason} + end + + case result do + {:error, reason} -> + Logger.error("runpod: http error for #{base_url}/run: #{inspect reason}") + msg.replyfun.("#{msg.sender.nick}: sd: runpod failed: #{inspect reason}") + _ -> :ok + end + {:error, error} -> + msg.replyfun.("#{msg.sender.nick}: sd: #{error}") + end + + state + end + + defp runpod_result_loop(url, env, msg, state) do + Logger.debug("runpod_result_loop: new") + {headers, options} = runpod_headers(env, state) + with {:ok, %HTTPoison.Response{status_code: 200, body: body}} <- HTTPoison.get(url, headers ++ [{"content-type", "application/json"}], options), + {:ok, %{"status" => "COMPLETED"} = data} <- Poison.decode(body) do + id = FlakeId.get() + tasks = for {%{"image" => url, "seed" => seed}, idx} <- Enum.with_index(Map.get(data, "output", [])) do + Task.async(fn() -> +with {:ok, %HTTPoison.Response{status_code: 200, body: body}} <- HTTPoison.get(url, [], options), +bucket = Application.get_env(:nola, :s3, []) |> Keyword.get(:bucket), +s3path = "#{msg.account.id}/iR#{env.nick}#{id}#{idx}-#{seed}.png", +s3req = ExAws.S3.put_object(bucket, s3path, body, acl: :public_read, content_type: "image/png"), +{:ok, _} <- ExAws.request(s3req), +path = NolaWeb.Router.Helpers.url(NolaWeb.Endpoint) <> "/files/#{s3path}" +do + {:ok, path} +else + error -> + Logger.error("runpod_result: error while uploading #{url}: #{inspect error}") + {:error, error} +end + end) + end + |> Task.yield_many(5000) + |> Enum.map(fn {task, res} -> + res || Task.shutdown(task, :brutal_kill) + end) + + results = for({:ok, {:ok, url}} <- tasks, do: url) + + msg.replyfun.("#{msg.sender.nick}: #{Enum.join(results, " ")}") + else + {:ok, %{"status" => "FAILED"} = data} -> + Logger.error("runpod_result_loop: job FAILED: #{inspect data}") + msg.replyfun.("#{msg.sender.nick}: sd: job failed: #{Map.get(data, "error", "error")}") + {:ok, %{"status" => _} = data} -> + Logger.debug("runpod_result_loop: not completed: #{inspect data}") + :timer.sleep(:timer.seconds(1)) + runpod_result_loop(url, env, msg, state) + {:ok, %HTTPoison.Response{status_code: 403}} -> + msg.replyfun.("#{msg.sender.nick}: sd: runpod failure: unauthorized") + error -> + Logger.warning("image: sd: runpod http error: #{inspect error}") + :timer.sleep(:timer.seconds(2)) + runpod_result_loop(url, env, msg, state) + end + end + + defp runpod_headers(_env, _state) do + config = Application.get_env(:nola, :runpod, []) + headers = [{"user-agent", "nola.lol bot, href@random.sh"}, + {"authorization", "Bearer " <> Keyword.get(config, :key, "unset-api-key")}] + options = [timeout: :timer.seconds(180), recv_timeout: :timer.seconds(180)] + {headers, options} + end + + defp sd_model(name, _, general_opts, opts) when name in ~w(sd2 sd1 oj any any4) do + {opts, prompt, _} = OptionParser.parse(opts, [ + aliases: [P: :strength, s: :seed, S: :steps, g: :guidance, X: :scheduler, q: :negative], + strict: [strength: :float, steps: :integer, guidance: :float, scheduler: :string, seed: :integer, negative: :keep] + ]) + opts = general_opts ++ opts + prompt = Enum.join(prompt, " ") + + negative = case Keyword.get_values(opts, :negative) do + [] -> nil + list -> Enum.join(list, " ") + end + + full_name = case name do + "sd2" -> "stable-diffusion-v2" + "sd1" -> "stable-diffusion-v1" + "oj" -> "sd-openjourney" + "any" -> "sd-anything-v3" + "any4" -> "sd-anything-v4" + end + + default_scheduler = case name do + "sd2" -> "KLMS" + _ -> "K-LMS" + end + + request = %{ + "prompt" => prompt, + "num_outputs" => general_opts[:number], + "width" => opts[:width] || 512, + "height" => opts[:height] || 512, + "prompt_strength" => opts[:strength] || 0.8, + "num_inference_steps" => opts[:steps] || 30, + "guidance_scale" => opts[:guidance] || 7.5, + "scheduler" => opts[:scheduler] || default_scheduler, + "seed" => opts[:seed] || :rand.uniform(100_000_00) + } + + request = if negative do + Map.put(request, "negative_prompt", negative) + else + request + end + + {:ok, %{name: full_name, nick: name, request: request}} + end + + defp sd_model(name, _, _, _) do + {:error, "unsupported model: \"#{name}\""} + end + +end -- cgit v1.2.3