from term import Atom
from pyrlang.gen.server import GenServer
from pyrlang.gen.decorators import call, info
from PIL import Image
import io
from transformers import CLIPProcessor, CLIPModel

# Candidate labels for CLIP zero-shot classification; the model scores an
# incoming image against every prompt in this list.
PROMPTS = [
    "photo",
    "dog photo",
    "cat photo",
    "food photo",
    "meme",
    "painting",
    "drawing",
    "selfie",
    "portrait photography",
    "tv capture",
    "screenshot",
    "terminal/ssh/console screenshot",
    "twitter screenshot",
    "chat log",
    "4chan screenshot",
    "scanned document",
    "book picture",
]


class ClipAsk(GenServer):
    """GenServer that classifies images against PROMPTS with OpenAI's CLIP."""

    def __init__(self, node, logger) -> None:
        super().__init__()
        node.register_name(self, Atom('clip_ask'))
        self.logger = logger
        self.model = None
        self.processor = None
        self.ready = False
        print("clipask: starting")
        # Defer the slow model download/load out of __init__: send ourselves
        # a 'register' message, which setup() below handles once the
        # process's mailbox loop is running.
        mypid = self.pid_
        node.send_nowait(mypid, mypid, "register")
        self.logger.info("initialized process: clip_ask.")

    @info(0, lambda msg: msg == 'register')
    def setup(self, msg):
        print("clipask: doing setup")
        self.logger.info("clip_ask: setup...")
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.logger.info("clip_ask: setup finished.")
        self.ready = True
        print("clipask: ready")

    @call(1, lambda msg: isinstance(msg, tuple) and msg[0] == Atom("run"))
    def run(self, msg):
        if not self.ready:
            return (Atom('error'), Atom('not_ready'))
        self.logger.info("clip_ask: inference")
        # msg is {run, ImageBytes}: decode the raw bytes into a PIL image.
        image = Image.open(io.BytesIO(msg[1]))
        inputs = self.processor(text=PROMPTS, images=image,
                                return_tensors="pt", padding=True)
        outputs = self.model(**inputs)
        # logits_per_image has shape (1, len(PROMPTS)); softmax turns the
        # image-text similarity scores into a probability distribution.
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)
        labels_with_probs = dict(zip(PROMPTS, probs.detach().numpy()[0]))
        # Sort labels by descending probability and return plain floats.
        results = dict(sorted(labels_with_probs.items(),
                              key=lambda item: item[1], reverse=True))
        return (Atom('ok'), {k: v.item() for k, v in results.items()})
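

# --- Usage sketch (assumption; not part of the original module). ---
# A minimal way to host ClipAsk on a Pyrlang node, using Pyrlang's
# documented Node API. The node name, cookie, and logger name below are
# hypothetical placeholders. From the Erlang/Elixir side the server would
# then be reachable as {clip_ask, 'py@127.0.0.1'} with a gen_server:call
# carrying a {run, ImageBytes} request.
if __name__ == "__main__":
    import logging
    from pyrlang import Node

    logging.basicConfig(level=logging.INFO)
    node = Node(node_name="py@127.0.0.1", cookie="COOKIE")  # hypothetical values
    ClipAsk(node, logging.getLogger("clip_ask"))
    node.run()  # blocks, serving messages until the node shuts down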