diff options
author | XavierXiao <xiaozhisheng950@gmail.com> | 2022-09-06 00:00:21 -0700 |
---|---|---|
committer | XavierXiao <xiaozhisheng950@gmail.com> | 2022-09-06 00:00:21 -0700 |
commit | 8f22429d7406ad450e681c9940c00461b1e3adf9 (patch) | |
tree | 0459d4e4be49ac2945a7b1263856b2313b6847e1 /scripts/evaluate_model.py |
initial commit
Diffstat (limited to 'scripts/evaluate_model.py')
-rw-r--r-- | scripts/evaluate_model.py | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/scripts/evaluate_model.py b/scripts/evaluate_model.py new file mode 100644 index 0000000..9cc2905 --- /dev/null +++ b/scripts/evaluate_model.py @@ -0,0 +1,89 @@ +import argparse, os, sys, glob + +sys.path.append(os.path.join(sys.path[0], '..')) + +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from einops import rearrange +from torchvision.utils import make_grid + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.data.personalized import PersonalizedBase +from evaluation.clip_eval import LDMCLIPEvaluator + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--ckpt_path", + type=str, + default="/data/pretrained_models/ldm/text2img-large/model.ckpt", + help="Path to pretrained ldm text2img model") + + parser.add_argument( + "--embedding_path", + type=str, + help="Path to a pre-trained embedding manager checkpoint") + + parser.add_argument( + "--data_dir", + type=str, + help="Path to directory with images used to train the embedding vectors" + ) + + opt = parser.parse_args() + + + config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml") # TODO: Optionally download from same location as ckpt and chnage this logic + model = load_model_from_config(config, opt.ckpt_path) # TODO: check path + model.embedding_manager.load(opt.embedding_path) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + evaluator = LDMCLIPEvaluator(device) + + prompt = opt.prompt + + data_loader = PersonalizedBase(opt.data_dir, size=256, flip_p=0.0) + + images = [torch.from_numpy(data_loader[i]["image"]).permute(2, 0, 1) for i in range(data_loader.num_images)] + images = torch.stack(images, axis=0) + + sim_img, sim_text = evaluator.evaluate(model, images, opt.prompt) + + output_dir = os.path.join(opt.out_dir, prompt.replace(" ", "-")) + + print("Image similarity: ", sim_img) + print("Text similarity: ", sim_text)
\ No newline at end of file |