diff options
Diffstat (limited to 'priv/python/pyerlai/genservers/image_to_text.py')
-rw-r--r-- | priv/python/pyerlai/genservers/image_to_text.py | 43 |
1 files changed, 43 insertions, 0 deletions
import io

from term import Atom
from pyrlang.gen.server import GenServer
from pyrlang.gen.decorators import call, cast, info
from PIL import Image
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel

# Hugging Face model id shared by the model, tokenizer and image processor.
_MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"


class ImageToTextViTGPT2(GenServer):
    """Pyrlang GenServer that captions images with the ViT-GPT2 model.

    Registers itself on the node under the Erlang name
    ``image_to_text_vit_gpt2``. The (slow) model download/load is deferred
    to a self-sent ``"register"`` info message so ``__init__`` returns
    quickly; until that completes, ``run`` answers ``{error, not_ready}``.
    """

    def __init__(self, node, logger) -> None:
        """Register the process on *node* and schedule lazy model setup.

        :param node: Pyrlang node used for name registration and messaging.
        :param logger: logger used for lifecycle/inference messages.
        """
        super().__init__()
        node.register_name(self, Atom('image_to_text_vit_gpt2'))
        self.logger = logger
        # Model components are loaded lazily in setup(); None until then.
        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self.ready = False  # flips to True once setup() has loaded everything
        # Send ourselves a 'register' message so the heavy from_pretrained()
        # calls run inside the GenServer loop instead of blocking __init__.
        mypid = self.pid_
        node.send_nowait(mypid, mypid, "register")
        self.logger.info("initialized process: image_to_text_vit_gpt2.")

    @info(0, lambda msg: msg == 'register')
    def setup(self, msg):
        """Load model, tokenizer and image processor (handles 'register').

        Marks the server ready only after all three components loaded.
        """
        self.logger.info("image_to_text_vit_gpt2: setup...")
        self.model = VisionEncoderDecoderModel.from_pretrained(_MODEL_NAME)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(_MODEL_NAME)
        self.image_processor = ViTImageProcessor.from_pretrained(_MODEL_NAME)
        self.ready = True
        self.logger.info("image_to_text_vit_gpt2: setup finished.")

    # Guard against empty tuples: msg[0] on () would raise IndexError.
    @call(1, lambda msg: isinstance(msg, tuple) and len(msg) > 0 and msg[0] == Atom("run"))
    def run(self, msg):
        """Handle a ``{run, ImageBytes}`` call and caption the image.

        :param msg: tuple ``(Atom('run'), image_bytes)`` — ``msg[1]`` must be
            raw image bytes decodable by Pillow.
        :returns: ``(Atom('ok'), caption)`` on success, or
            ``(Atom('error'), Atom('not_ready'))`` while setup is pending.
        """
        if not self.ready:
            return (Atom('error'), Atom('not_ready'))
        self.logger.info("image_to_text_vit_gpt2: inference")
        image = Image.open(io.BytesIO(msg[1])).convert('RGB')
        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
        generated_ids = self.model.generate(pixel_values, max_new_tokens=40)
        caption = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return (Atom('ok'), caption)