diff options
author | XavierXiao <xiaozhisheng950@gmail.com> | 2022-09-06 00:00:21 -0700 |
---|---|---|
committer | XavierXiao <xiaozhisheng950@gmail.com> | 2022-09-06 00:00:21 -0700 |
commit | 8f22429d7406ad450e681c9940c00461b1e3adf9 (patch) | |
tree | 0459d4e4be49ac2945a7b1263856b2313b6847e1 /scripts/sample_diffusion.py |
initial commit
Diffstat (limited to 'scripts/sample_diffusion.py')
-rw-r--r-- | scripts/sample_diffusion.py | 313 |
1 files changed, 313 insertions, 0 deletions
# Reconstructed from a diff-mangled git web export of scripts/sample_diffusion.py.
# Draws unconditional samples from a trained (latent) diffusion model and writes
# them out both as individual PNGs and as one ADM-style .npz batch.

import argparse
import datetime
import glob
import os
import sys
import time

import numpy as np
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image
from tqdm import trange

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config

# Map a tensor from [-1, 1] (model output range) to [0, 1].
rescale = lambda x: (x + 1.) / 2.


def custom_to_pil(x):
    """Convert a single CHW tensor in [-1, 1] to an RGB PIL image."""
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x


def custom_to_np(x):
    """Convert a BCHW batch in [-1, 1] to a uint8 BHWC tensor.

    Saves the batch in adm style as in
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
    """
    sample = x.detach().cpu()
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous()
    return sample


def logs2pil(logs, keys=["sample"]):
    """Best-effort conversion of a log dict of tensors to PIL images.

    4D entries are treated as batches (first element taken); 3D entries as
    single images. Entries that cannot be converted map to ``None``.
    Note: ``keys`` is unused here (kept for interface compatibility).
    """
    imgs = dict()
    for k in logs:
        try:
            if len(logs[k].shape) == 4:
                img = custom_to_pil(logs[k][0, ...])
            elif len(logs[k].shape) == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        # Fix: was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit. Keep the deliberate best-effort behavior, but only for
        # ordinary errors.
        except Exception:
            img = None
        imgs[k] = img
    return imgs


@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
               verbose=True,
               make_prog_row=False):
    """Vanilla DDPM sampling; optionally via progressive denoising.

    Returns (sample, intermediates) in either case.
    """
    if not make_prog_row:
        return model.p_sample_loop(None, shape,
                                   return_intermediates=return_intermediates,
                                   verbose=verbose)
    else:
        return model.progressive_denoising(
            None, shape, verbose=True
        )


@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0
                    ):
    """DDIM sampling with `steps` steps; `shape` is [bs, C, H, W]."""
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]  # DDIMSampler expects the per-sample shape only
    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape,
                                         eta=eta, verbose=False,)
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
    """Sample one batch and decode it to image space.

    Returns a log dict with keys "sample" (decoded images), "time" (seconds)
    and "throughput" (samples/second).
    """
    log = dict()

    # Latent-space shape, taken from the UNet config of the wrapped model.
    shape = [batch_size,
             model.model.diffusion_model.in_channels,
             model.model.diffusion_model.image_size,
             model.model.diffusion_model.image_size]

    # Sample under the EMA weights for better quality.
    with model.ema_scope("Plotting"):
        t0 = time.time()
        if vanilla:
            sample, progrow = convsample(model, shape,
                                         make_prog_row=True)
        else:
            sample, intermediates = convsample_ddim(model, steps=custom_steps,
                                                    shape=shape, eta=eta)
        t1 = time.time()

    # Decode latents back to pixel space with the first-stage autoencoder.
    x_sample = model.decode_first_stage(sample)

    log["sample"] = x_sample
    log["time"] = t1 - t0
    log['throughput'] = sample.shape[0] / (t1 - t0)
    print(f'Throughput for this batch: {log["throughput"]}')
    return log


def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None,
        n_samples=50000, nplog=None):
    """Sample `n_samples` images batch-wise; save PNGs to `logdir` and one
    concatenated .npz to `nplog`.

    Only unconditional models are supported. Note: n_samples should be a
    multiple of batch_size — the loop runs n_samples // batch_size batches.
    """
    if vanilla:
        print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
    else:
        print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')

    tstart = time.time()
    # Resume the PNG counter from any samples already in logdir.
    # NOTE(review): the -1 makes the first filename of a fresh run end in
    # "-00001" only after the second save (counter starts at -1) — kept as-is
    # to preserve existing numbering.
    n_saved = len(glob.glob(os.path.join(logdir, '*.png'))) - 1
    if model.cond_stage_model is None:
        all_images = []

        print(f"Running unconditional sampling for {n_samples} samples")
        for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
            logs = make_convolutional_sample(model, batch_size=batch_size,
                                             vanilla=vanilla, custom_steps=custom_steps,
                                             eta=eta)
            n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
            all_images.extend([custom_to_np(logs["sample"])])
            if n_saved >= n_samples:
                print(f'Finish after generating {n_saved} samples')
                break
        all_img = np.concatenate(all_images, axis=0)
        all_img = all_img[:n_samples]
        shape_str = "x".join([str(x) for x in all_img.shape])
        nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img)

    else:
        raise NotImplementedError('Currently only sampling for unconditional models supported.')

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")


def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Save the batch under `key` either as individual PNGs (np_path is None)
    or as one .npz file. Returns the updated saved-image counter."""
    for k in logs:
        if k == key:
            batch = logs[key]
            if np_path is None:
                for x in batch:
                    img = custom_to_pil(x)
                    imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
                    img.save(imgpath)
                    n_saved += 1
            else:
                npbatch = custom_to_np(batch)
                shape_str = "x".join([str(x) for x in npbatch.shape])
                nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
                np.savez(nppath, npbatch)
                n_saved += npbatch.shape[0]
    return n_saved


def get_parser():
    """Build the CLI argument parser for this sampling script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-n",
        "--n_samples",
        type=int,
        nargs="?",
        help="number of samples to draw",
        default=50000
    )
    parser.add_argument(
        "-e",
        "--eta",
        type=float,
        nargs="?",
        help="eta for ddim sampling (0.0 yields deterministic sampling)",
        default=1.0
    )
    parser.add_argument(
        "-v",
        "--vanilla_sample",
        default=False,
        action='store_true',
        help="vanilla sampling (default option is DDIM sampling)?",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        nargs="?",
        help="extra logdir",
        default="none"
    )
    parser.add_argument(
        "-c",
        "--custom_steps",
        type=int,
        nargs="?",
        help="number of steps for ddim and fastdpm sampling",
        default=50
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        nargs="?",
        help="the bs",
        default=10
    )
    return parser


def load_model_from_config(config, sd):
    """Instantiate the model from its OmegaConf config and (optionally) load
    a state dict; moves the model to CUDA and switches to eval mode."""
    model = instantiate_from_config(config)
    # Fix: sd is None when no checkpoint was provided (see load_model);
    # load_state_dict(None) would raise TypeError.
    if sd is not None:
        model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model


def load_model(config, ckpt, gpu, eval_mode):
    """Load the model and its global step from a Lightning checkpoint.

    `gpu` and `eval_mode` are currently unused (kept for interface
    compatibility); the model always ends up on CUDA in eval mode.
    Returns (model, global_step) with global_step None if no checkpoint.
    """
    if ckpt:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model,
                                   pl_sd["state_dict"])

    return model, global_step


if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())
    command = " ".join(sys.argv)

    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    ckpt = None

    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))
    if os.path.isfile(opt.resume):
        # A checkpoint file was given: the logdir is its parent directory.
        try:
            logdir = '/'.join(opt.resume.split('/')[:-1])
            print(f'Logdir is {logdir}')
        except ValueError:
            paths = opt.resume.split("/")
            idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
        ckpt = opt.resume
    else:
        # A logdir was given: expect model.ckpt inside it.
        assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")

    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    opt.base = base_configs

    # Merge the training config(s) with any extra dotlist overrides from the CLI.
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    eval_mode = True

    # Optionally redirect the output under a user-supplied logdir, keeping the
    # original run's directory name as a subfolder.
    if opt.logdir != "none":
        locallog = logdir.split(os.sep)[-1]
        if locallog == "":
            locallog = logdir.split(os.sep)[-2]
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    # Intentionally fails if the directories already exist — a given
    # (global_step, timestamp) run must not overwrite previous samples.
    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # Write the sampling configuration next to the samples for reproducibility.
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, 'w') as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(model, imglogdir, eta=opt.eta,
        vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
        batch_size=opt.batch_size, nplog=numpylogdir)

    print("done.")