aboutsummaryrefslogtreecommitdiff
path: root/scripts/sample_diffusion.py
diff options
context:
space:
mode:
author XavierXiao <xiaozhisheng950@gmail.com> 2022-09-06 00:00:21 -0700
committer XavierXiao <xiaozhisheng950@gmail.com> 2022-09-06 00:00:21 -0700
commit 8f22429d7406ad450e681c9940c00461b1e3adf9 (patch)
tree 0459d4e4be49ac2945a7b1263856b2313b6847e1 /scripts/sample_diffusion.py
initial commit
Diffstat (limited to 'scripts/sample_diffusion.py')
-rw-r--r--scripts/sample_diffusion.py313
1 files changed, 313 insertions, 0 deletions
diff --git a/scripts/sample_diffusion.py b/scripts/sample_diffusion.py
new file mode 100644
index 0000000..876fe3c
--- /dev/null
+++ b/scripts/sample_diffusion.py
@@ -0,0 +1,313 @@
+import argparse, os, sys, glob, datetime, yaml
+import torch
+import time
+import numpy as np
+from tqdm import trange
+
+from omegaconf import OmegaConf
+from PIL import Image
+
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.util import instantiate_from_config
+
def rescale(x):
    """Shift and scale ``x`` by ``(x + 1) / 2`` (maps -1 to 0 and 1 to 1)."""
    return (x + 1.) / 2.
+
def custom_to_pil(x):
    """Turn a single image tensor into an RGB ``PIL.Image``.

    The tensor is detached, moved to the CPU, clamped to [-1, 1], rescaled
    to [0, 1], permuted with ``(1, 2, 0)`` and cast to ``uint8`` before it
    is handed to ``Image.fromarray``. Non-RGB results are converted to RGB.
    """
    tensor = torch.clamp(x.detach().cpu(), -1., 1.)
    array = ((tensor + 1.) / 2.).permute(1, 2, 0).numpy()
    image = Image.fromarray((255 * array).astype(np.uint8))
    return image if image.mode == "RGB" else image.convert("RGB")
+
+
def custom_to_np(x):
    """Quantise a batch tensor to ``uint8`` and move axis 1 to the end.

    Saves the batch in adm style as in
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py

    Returns a contiguous CPU ``torch.uint8`` tensor permuted with
    ``(0, 2, 3, 1)``; values are ``(x + 1) * 127.5`` clamped to [0, 255].
    """
    batch = x.detach().cpu()
    quantised = ((batch + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    return quantised.permute(0, 2, 3, 1).contiguous()
+
+
def logs2pil(logs, keys=("sample",)):
    """Convert every entry of ``logs`` to a PIL image on a best-effort basis.

    Args:
        logs: mapping from key to a value exposing ``.shape``. Entries with
            four dimensions contribute their first element, entries with
            three dimensions are converted directly.
        keys: accepted for backward compatibility; not used by this function.

    Returns:
        dict mapping each key of ``logs`` to a PIL image, or ``None`` when
        the entry has an unsupported number of dimensions or its conversion
        raised an exception.
    """
    imgs = dict()
    for k in logs:
        try:
            ndim = len(logs[k].shape)
            if ndim == 4:
                img = custom_to_pil(logs[k][0, ...])
            elif ndim == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception:
            # Best effort: entries such as plain floats have no ``.shape``.
            # ``Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt / SystemExit propagate.
            img = None
        imgs[k] = img
    return imgs
+
+
@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
               verbose=True,
               make_prog_row=False):
    """Draw samples from ``model`` without gradient tracking.

    Args:
        model: object providing ``p_sample_loop`` and
            ``progressive_denoising``.
        shape: shape forwarded to the model's sampling method.
        return_intermediates: forwarded to ``p_sample_loop`` only.
        verbose: forwarded to whichever sampling method is used.
        make_prog_row: when true, call ``progressive_denoising`` instead of
            ``p_sample_loop``.

    Returns:
        Whatever the selected model method returns.
    """
    if make_prog_row:
        # Honour the caller's ``verbose`` instead of forcing it to True.
        return model.progressive_denoising(None, shape, verbose=verbose)
    return model.p_sample_loop(None, shape,
                               return_intermediates=return_intermediates,
                               verbose=verbose)
+
+
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0):
    """Sample from ``model`` with a ``DDIMSampler`` without gradient tracking.

    ``shape[0]`` is passed as the batch size and ``shape[1:]`` as the
    per-sample shape. Returns the ``(samples, intermediates)`` pair produced
    by ``DDIMSampler.sample``.
    """
    sampler = DDIMSampler(model)
    batch_size, sample_shape = shape[0], shape[1:]
    samples, intermediates = sampler.sample(
        steps, batch_size=batch_size, shape=sample_shape, eta=eta, verbose=False,
    )
    return samples, intermediates
+
+
@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
    """Sample one batch, decode it and report timing.

    The sampling shape is built from ``batch_size`` and the ``in_channels``
    / ``image_size`` attributes of ``model.model.diffusion_model``. With
    ``vanilla`` set, ``convsample`` is used with ``make_prog_row=True``;
    otherwise ``convsample_ddim`` is used with ``custom_steps`` and ``eta``.

    Returns:
        dict with keys ``"sample"`` (output of ``decode_first_stage``),
        ``"time"`` (seconds spent sampling) and ``"throughput"``
        (samples per second).
    """
    net = model.model.diffusion_model
    shape = [batch_size, net.in_channels, net.image_size, net.image_size]

    # NOTE(review): indentation was lost in this copy of the file; the
    # decode step is placed after the EMA scope -- confirm against upstream.
    with model.ema_scope("Plotting"):
        started = time.time()
        if vanilla:
            sample, _ = convsample(model, shape, make_prog_row=True)
        else:
            sample, _ = convsample_ddim(model, steps=custom_steps, shape=shape,
                                        eta=eta)
        finished = time.time()

    x_sample = model.decode_first_stage(sample)

    elapsed = finished - started
    log = {
        "sample": x_sample,
        "time": elapsed,
        "throughput": sample.shape[0] / elapsed,
    }
    print(f'Throughput for this batch: {log["throughput"]}')
    return log
+
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
    """Sample ``n_samples`` images from an unconditional model and save them.

    Each batch is written as PNG files into ``logdir`` via ``save_logs``;
    the batches are also collected and stored as a single ``.npz`` archive
    in ``nplog``.

    Args:
        model: model to sample from; ``model.cond_stage_model`` must be None.
        logdir: directory receiving the PNG files.
        batch_size: number of samples per batch (the last batch is smaller
            when ``n_samples`` is not a multiple of ``batch_size``).
        vanilla: forwarded to ``make_convolutional_sample``.
        custom_steps: forwarded to ``make_convolutional_sample``.
        eta: forwarded to ``make_convolutional_sample``.
        n_samples: total number of samples to draw.
        nplog: directory for the ``.npz`` archive; when None no archive is
            written.

    Raises:
        NotImplementedError: if the model has a conditioning stage.
    """
    if vanilla:
        print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
    else:
        print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')

    tstart = time.time()
    # Continue numbering after PNGs already present. (The previous "- 1"
    # made the first file of an empty directory "sample_-00001.png".)
    n_saved = len(glob.glob(os.path.join(logdir, '*.png')))

    if model.cond_stage_model is not None:
        raise NotImplementedError('Currently only sampling for unconditional models supported.')

    all_images = []
    print(f"Running unconditional sampling for {n_samples} samples")

    # Ceiling division so a trailing partial batch is not silently dropped.
    n_batches = -(-n_samples // batch_size)
    n_generated = 0
    for _ in trange(n_batches, desc="Sampling Batches (unconditional)"):
        current_bs = min(batch_size, n_samples - n_generated)
        logs = make_convolutional_sample(model, batch_size=current_bs,
                                         vanilla=vanilla, custom_steps=custom_steps,
                                         eta=eta)
        n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
        all_images.append(custom_to_np(logs["sample"]))
        n_generated += current_bs
        if n_saved >= n_samples:
            print(f'Finish after generating {n_saved} samples')
            break

    # np.concatenate raises on an empty list, and os.path.join on None.
    if all_images and nplog is not None:
        all_img = np.concatenate(all_images, axis=0)
        all_img = all_img[:n_samples]
        shape_str = "x".join([str(x) for x in all_img.shape])
        nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img)

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
+
+
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Persist the batch stored under ``key`` in ``logs``.

    With ``np_path`` unset, every element of the batch is written to
    ``path`` as ``{key}_{index:06}.png``; otherwise the whole batch is
    written to ``np_path`` as one ``.npz`` file.

    Returns:
        The updated running count of saved samples (``n_saved`` unchanged
        when ``key`` is not present in ``logs``).
    """
    if key not in logs:
        return n_saved

    batch = logs[key]
    if np_path is not None:
        npbatch = custom_to_np(batch)
        dims = "x".join(str(d) for d in npbatch.shape)
        target = os.path.join(np_path, f"{n_saved}-{dims}-samples.npz")
        np.savez(target, npbatch)
        return n_saved + npbatch.shape[0]

    for item in batch:
        picture = custom_to_pil(item)
        picture.save(os.path.join(path, f"{key}_{n_saved:06}.png"))
        n_saved += 1
    return n_saved
+
+
def get_parser():
    """Build the command-line parser for the sampling script.

    Options: ``-r/--resume``, ``-n/--n_samples`` (default 50000),
    ``-e/--eta`` (default 1.0), ``-v/--vanilla_sample`` (flag),
    ``-l/--logdir`` (default ``"none"``), ``-c/--custom_steps`` (default 50)
    and ``--batch_size`` (default 10).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-r", "--resume", type=str, nargs="?",
                     help="load from logdir or checkpoint in logdir")
    cli.add_argument("-n", "--n_samples", type=int, nargs="?",
                     help="number of samples to draw", default=50000)
    cli.add_argument("-e", "--eta", type=float, nargs="?",
                     help="eta for ddim sampling (0.0 yields deterministic sampling)",
                     default=1.0)
    cli.add_argument("-v", "--vanilla_sample", default=False, action='store_true',
                     help="vanilla sampling (default option is DDIM sampling)?")
    cli.add_argument("-l", "--logdir", type=str, nargs="?",
                     help="extra logdir", default="none")
    cli.add_argument("-c", "--custom_steps", type=int, nargs="?",
                     help="number of steps for ddim and fastdpm sampling", default=50)
    cli.add_argument("--batch_size", type=int, nargs="?",
                     help="the bs", default=10)
    return cli
+
+
def load_model_from_config(config, sd):
    """Instantiate a model from ``config`` and prepare it for inference.

    Args:
        config: configuration passed to ``instantiate_from_config``.
        sd: state dict to load non-strictly, or None to keep the freshly
            initialised weights (``load_model`` passes None when it is
            called without a checkpoint).

    Returns:
        The model after ``.cuda()`` and ``.eval()``.
    """
    model = instantiate_from_config(config)
    if sd is not None:
        # strict=False: missing / unexpected keys are tolerated.
        model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model
+
+
def load_model(config, ckpt, gpu, eval_mode):
    """Build the model described by ``config.model``, optionally from ``ckpt``.

    Args:
        config: configuration object with a ``model`` entry.
        ckpt: checkpoint path; when falsy no checkpoint is read.
        gpu: accepted but not used by this function.
        eval_mode: accepted but not used by this function.

    Returns:
        ``(model, global_step)``; ``global_step`` is read from the
        checkpoint and is None when no checkpoint was given.
    """
    pl_sd = {"state_dict": None}
    global_step = None
    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE(review): torch.load may unpickle arbitrary objects -- only
        # point this at trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]

    model = load_model_from_config(config.model, pl_sd["state_dict"])
    return model, global_step
+
+
if __name__ == "__main__":
    # Script entry point: resolve the run directory and checkpoint from
    # --resume, load the config and model, create the output directories,
    # record the sampling options and run the sampler.
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())

    parser = get_parser()
    # Unrecognised arguments are treated as OmegaConf dotlist overrides.
    opt, unknown = parser.parse_known_args()

    if opt.resume is None:
        # Without this check os.path.exists(None) fails with a TypeError.
        parser.error("-r/--resume is required (a logdir or a checkpoint inside a logdir)")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))

    if os.path.isfile(opt.resume):
        # A checkpoint file was given: its directory is the logdir.
        logdir = os.path.dirname(opt.resume)
        print(f'Logdir is {logdir}')
        ckpt = opt.resume
    else:
        if not os.path.isdir(opt.resume):
            raise ValueError(f"{opt.resume} is not a directory")
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")

    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    opt.base = base_configs

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    if "model" not in config:
        raise ValueError(
            f"No 'model' section found: expected a config.yaml in '{logdir}' "
            "or model.* overrides on the command line"
        )

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        # Last path component of the run directory, ignoring trailing separators.
        locallog = os.path.basename(logdir.rstrip(os.sep))
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    # load_model can return None for the step; formatting None with ":08" raises.
    step_dir = f"{global_step:08}" if global_step is not None else "unknown"
    logdir = os.path.join(logdir, "samples", step_dir, now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # write config out
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, 'w') as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(model, imglogdir, eta=opt.eta,
        vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
        batch_size=opt.batch_size, nplog=numpylogdir)

    print("done.")