summary refs log tree commit diff
path: root/priv/python/pyerlai/genservers/image_to_text.py
diff options
context:
space:
mode:
Diffstat (limited to 'priv/python/pyerlai/genservers/image_to_text.py')
-rw-r--r-- priv/python/pyerlai/genservers/image_to_text.py | 43
1 files changed, 43 insertions, 0 deletions
diff --git a/priv/python/pyerlai/genservers/image_to_text.py b/priv/python/pyerlai/genservers/image_to_text.py
new file mode 100644
index 0000000..a1abc74
--- /dev/null
+++ b/priv/python/pyerlai/genservers/image_to_text.py
@@ -0,0 +1,43 @@
+from term import Atom
+from pyrlang.gen.server import GenServer
+from pyrlang.gen.decorators import call, cast, info
+from PIL import Image
+from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
+import io
+
class ImageToTextViTGPT2(GenServer):
    """Pyrlang GenServer exposing the nlpconnect ViT-GPT2 image-captioning model.

    Registers itself under the Erlang name 'image_to_text_vit_gpt2'. The
    (slow) model download/load is deferred: __init__ sends the process a
    'register' message to itself so setup() runs asynchronously after init
    instead of blocking node startup. Until setup() completes, 'run' calls
    are answered with (error, not_ready).
    """

    # Single source of truth for the HuggingFace checkpoint; the original
    # repeated this string three times.
    CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

    def __init__(self, node, logger) -> None:
        super().__init__()
        node.register_name(self, Atom('image_to_text_vit_gpt2'))
        self.logger = logger
        # Model components are populated lazily in setup().
        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self.ready = False
        # Message ourselves so setup() runs after init returns; __init__
        # must not block on the model download.
        mypid = self.pid_
        node.send_nowait(mypid, mypid, "register")
        # NOTE: fixed copy-paste bug — the original logged
        # "text_to_image_vit_gpt2" for this image_to_text process.
        self.logger.info("initialized process: image_to_text_vit_gpt2.")

    @info(0, lambda msg: msg == 'register')
    def setup(self, msg):
        """Load model, tokenizer and image processor (slow; runs once).

        Triggered by the 'register' message sent from __init__.
        """
        self.logger.info("image_to_text_vit_gpt2: setup...")
        self.model = VisionEncoderDecoderModel.from_pretrained(self.CHECKPOINT)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(self.CHECKPOINT)
        self.image_processor = ViTImageProcessor.from_pretrained(self.CHECKPOINT)
        self.logger.info("image_to_text_vit_gpt2: setup finished.")
        # Flip readiness last so concurrent 'run' calls never observe
        # ready=True with a half-initialized model.
        self.ready = True

    # Length guard added: the original matcher indexed msg[0] on any tuple,
    # raising IndexError on an empty tuple, and a 1-tuple matched but then
    # crashed at msg[1] in the handler body.
    @call(1, lambda msg: isinstance(msg, tuple) and len(msg) >= 2 and msg[0] == Atom("run"))
    def run(self, msg):
        """Handle a {run, ImageBytes} call and return a caption.

        msg[1] must be raw image bytes in any format Pillow can decode.
        Returns (ok, caption_string) on success, or (error, not_ready)
        if setup() has not finished yet.
        """
        if not self.ready:
            return (Atom('error'), Atom('not_ready'))
        self.logger.info("image_to_text_vit_gpt2: inference")
        image = Image.open(io.BytesIO(msg[1])).convert('RGB')
        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
        # Cap generation length to keep per-call latency bounded; 40 tokens
        # is ample for a one-sentence caption.
        generated_ids = self.model.generate(pixel_values, max_new_tokens=40)
        generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return (Atom('ok'), generated_text)