aboutsummaryrefslogtreecommitdiff
path: root/scripts/evaluate_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/evaluate_model.py')
-rw-r--r--scripts/evaluate_model.py89
1 files changed, 89 insertions, 0 deletions
diff --git a/scripts/evaluate_model.py b/scripts/evaluate_model.py
new file mode 100644
index 0000000..9cc2905
--- /dev/null
+++ b/scripts/evaluate_model.py
@@ -0,0 +1,89 @@
+import argparse, os, sys, glob
+
+sys.path.append(os.path.join(sys.path[0], '..'))
+
+import torch
+import numpy as np
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from einops import rearrange
+from torchvision.utils import make_grid
+
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.data.personalized import PersonalizedBase
+from evaluation.clip_eval import LDMCLIPEvaluator
+
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` section is understood by
            ``instantiate_from_config``.
        ckpt: Path to a PyTorch Lightning checkpoint containing a
            ``state_dict`` entry.
        verbose: If True, print the missing/unexpected state-dict keys.

    Returns:
        The model in eval mode, moved to the GPU when one is available.
    """
    print(f"Loading model from {ckpt}")
    # Load on CPU first so reading a large checkpoint never needs GPU memory.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: embedding-manager weights are loaded separately by the caller.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    # Bug fix: the original called model.cuda() unconditionally, which crashes
    # on CPU-only machines even though the caller checks availability later.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
+
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )

    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/data/pretrained_models/ldm/text2img-large/model.ckpt",
        help="Path to pretrained ldm text2img model")

    parser.add_argument(
        "--embedding_path",
        type=str,
        help="Path to a pre-trained embedding manager checkpoint")

    parser.add_argument(
        "--data_dir",
        type=str,
        help="Path to directory with images used to train the embedding vectors"
    )

    # Bug fix: opt.out_dir is read below but the argument was never declared,
    # which raised AttributeError at runtime.
    parser.add_argument(
        "--out_dir",
        type=str,
        default="outputs/evaluation",
        help="Directory under which evaluation outputs are grouped per prompt"
    )

    opt = parser.parse_args()


    config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml")  # TODO: Optionally download from same location as ckpt and change this logic
    model = load_model_from_config(config, opt.ckpt_path)  # TODO: check path
    # Personalized token embeddings are loaded on top of the base weights.
    model.embedding_manager.load(opt.embedding_path)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    evaluator = LDMCLIPEvaluator(device)

    prompt = opt.prompt

    # flip_p=0.0 disables random flips so the evaluation is deterministic.
    data_loader = PersonalizedBase(opt.data_dir, size=256, flip_p=0.0)

    # HWC numpy images -> CHW tensors, stacked into an (N, C, H, W) batch.
    images = [torch.from_numpy(data_loader[i]["image"]).permute(2, 0, 1)
              for i in range(data_loader.num_images)]
    images = torch.stack(images, dim=0)  # torch.stack takes dim=, not numpy's axis=

    sim_img, sim_text = evaluator.evaluate(model, images, opt.prompt)

    # NOTE(review): output_dir is computed but never used — presumably results
    # were meant to be written here; confirm intended behavior.
    output_dir = os.path.join(opt.out_dir, prompt.replace(" ", "-"))

    print("Image similarity: ", sim_img)
    print("Text similarity: ", sim_text)