path: root/scripts/inpaint.py
Diffstat (limited to 'scripts/inpaint.py')
-rw-r--r--  scripts/inpaint.py  108
1 file changed, 108 insertions, 0 deletions
diff --git a/scripts/inpaint.py b/scripts/inpaint.py
new file mode 100644
index 0000000..d6e6387
--- /dev/null
+++ b/scripts/inpaint.py
@@ -0,0 +1,108 @@
+import argparse, os, sys, glob
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm
+import numpy as np
+import torch
+from main import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+def make_batch(image, mask, device):
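+    # load the RGB image as float32 in [0,1] and reorder HWC -> NCHW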
+ image = np.array(Image.open(image).convert("RGB"))
+ image = image.astype(np.float32)/255.0
+ image = image[None].transpose(0,3,1,2)
+ image = torch.from_numpy(image)
+
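+    # load the mask as grayscale in [0,1] and binarize it at 0.5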
+ mask = np.array(Image.open(mask).convert("L"))
+ mask = mask.astype(np.float32)/255.0
+ mask = mask[None,None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
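+    # zero out the pixels that are to be inpainted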
+ masked_image = (1-mask)*image
+
+ batch = {"image": image, "mask": mask, "masked_image": masked_image}
+ for k in batch:
+ batch[k] = batch[k].to(device=device)
+ batch[k] = batch[k]*2.0-1.0
+ return batch
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--indir",
+ type=str,
+ nargs="?",
+ help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
+ )
+ parser.add_argument(
+ "--outdir",
+ type=str,
+ nargs="?",
+ help="dir to write results to",
+ )
+ parser.add_argument(
+ "--steps",
+ type=int,
+ default=50,
+ help="number of ddim sampling steps",
+ )
+ opt = parser.parse_args()
+
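+    # collect mask files; each `*_mask.png` pairs with a `*.png` of the same stem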
+ masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
+ images = [x.replace("_mask.png", ".png") for x in masks]
+ print(f"Found {len(masks)} inputs.")
+
+    # instantiate the model and load weights (mapped to CPU so loading works without a GPU)
+    config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
+    model = instantiate_from_config(config.model)
+    model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt",
+                                     map_location="cpu")["state_dict"],
+                          strict=False)
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+ sampler = DDIMSampler(model)
+
+ os.makedirs(opt.outdir, exist_ok=True)
+ with torch.no_grad():
+ with model.ema_scope():
+            for image, mask in tqdm(zip(images, masks), total=len(images)):
+ outpath = os.path.join(opt.outdir, os.path.split(image)[1])
+ batch = make_batch(image, mask, device=device)
+
+ # encode masked image and concat downsampled mask
+ c = model.cond_stage_model.encode(batch["masked_image"])
+ cc = torch.nn.functional.interpolate(batch["mask"],
+ size=c.shape[-2:])
+ c = torch.cat((c, cc), dim=1)
+
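+                # sample in latent space; the shape drops the concatenated mask channel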
+ shape = (c.shape[1]-1,)+c.shape[2:]
+ samples_ddim, _ = sampler.sample(S=opt.steps,
+ conditioning=c,
+ batch_size=c.shape[0],
+ shape=shape,
+ verbose=False)
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
+
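+                # rescale image, mask, and prediction from [-1,1] back to [0,1]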
+ image = torch.clamp((batch["image"]+1.0)/2.0,
+ min=0.0, max=1.0)
+ mask = torch.clamp((batch["mask"]+1.0)/2.0,
+ min=0.0, max=1.0)
+ predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
+ min=0.0, max=1.0)
+
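+                # composite: keep original pixels outside the mask, prediction inside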
+ inpainted = (1-mask)*image+mask*predicted_image
+ inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
+ Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
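
For reference, a minimal invocation of the new script, assuming the hard-coded
config/checkpoint paths above ("models/ldm/inpainting_big/") are in place; the
input and output directory names here are only illustrative:

    python scripts/inpaint.py --indir data/inpainting_examples --outdir outputs/inpainting_results --steps 50

Each `example.png` in --indir must have a matching `example_mask.png`; results
are written to --outdir under the original image filenames.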