aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXavier <xiaozhisheng950@gmail.com>2022-09-20 21:48:16 -0700
committerGitHub <noreply@github.com>2022-09-20 21:48:16 -0700
commita22c61800db1c82c4cab0c76f8dd62537d99a435 (patch)
tree06ab14935d752d57abdc1152cae3649ab3bbacac
parentUpdate README.md (diff)
Update util.py
-rw-r--r--ldm/modules/diffusionmodules/util.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
index 6b5b9dc..3d25919 100644
--- a/ldm/modules/diffusionmodules/util.py
+++ b/ldm/modules/diffusionmodules/util.py
@@ -109,7 +109,7 @@ def checkpoint(func, inputs, params, flag):
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
- if False: # disabled checkpointing to allow requires_grad = False for main model
+ if flag: # disabled checkpointing to allow requires_grad = False for main model
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
@@ -264,4 +264,4 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
- return repeat_noise() if repeat else noise() \ No newline at end of file
+ return repeat_noise() if repeat else noise()