Skip to content

Commit 0a85eed

Browse files
Authored by virginiafdez (Virginia Fernandez) and KumoLiu
Inferer modification - save_intermediates clashes with latent shape adjustment in latent diffusion inferers (#8343)
Fixes #8334 ### Description There was an if save_intermediates missing in the code that was trying to run crop of the latent spaces on the sample function of the Latent Diffusion Inferers (normal one and ControlNet one) even when intermediates aren't created. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 44add8d commit 0a85eed

File tree

3 files changed

+151
-10
lines changed

3 files changed

+151
-10
lines changed

monai/inferers/inferer.py

+10-7
Original file line number | Diff line number | Diff line change
@@ -1202,15 +1202,16 @@ def sample( # type: ignore[override]
12021202

12031203
if self.autoencoder_latent_shape is not None:
12041204
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1205-
latent_intermediates = [
1206-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1207-
]
1205+
if save_intermediates:
1206+
latent_intermediates = [
1207+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1208+
for l in latent_intermediates
1209+
]
12081210

12091211
decode = autoencoder_model.decode_stage_2_outputs
12101212
if isinstance(autoencoder_model, SPADEAutoencoderKL):
12111213
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
12121214
image = decode(latent / self.scale_factor)
1213-
12141215
if save_intermediates:
12151216
intermediates = []
12161217
for latent_intermediate in latent_intermediates:
@@ -1727,9 +1728,11 @@ def sample( # type: ignore[override]
17271728

17281729
if self.autoencoder_latent_shape is not None:
17291730
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1730-
latent_intermediates = [
1731-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1732-
]
1731+
if save_intermediates:
1732+
latent_intermediates = [
1733+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1734+
for l in latent_intermediates
1735+
]
17331736

17341737
decode = autoencoder_model.decode_stage_2_outputs
17351738
if isinstance(autoencoder_model, SPADEAutoencoderKL):

tests/inferers/test_controlnet_inferers.py

+80-2
Original file line number | Diff line number | Diff line change
@@ -722,7 +722,7 @@ def test_prediction_shape(
722722

723723
@parameterized.expand(LATENT_CNDM_TEST_CASES)
724724
@skipUnless(has_einops, "Requires einops")
725-
def test_sample_shape(
725+
def test_pred_shape(
726726
self,
727727
ae_model_type,
728728
autoencoder_params,
@@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat(
11651165

11661166
@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
11671167
@skipUnless(has_einops, "Requires einops")
1168-
def test_sample_shape_different_latents(
1168+
def test_shape_different_latents(
11691169
self,
11701170
ae_model_type,
11711171
autoencoder_params,
@@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents(
12421242
)
12431243
self.assertEqual(prediction.shape, latent_shape)
12441244

1245+
@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
1246+
@skipUnless(has_einops, "Requires einops")
1247+
def test_sample_shape_different_latents(
1248+
self,
1249+
ae_model_type,
1250+
autoencoder_params,
1251+
dm_model_type,
1252+
stage_2_params,
1253+
controlnet_params,
1254+
input_shape,
1255+
latent_shape,
1256+
):
1257+
stage_1 = None
1258+
1259+
if ae_model_type == "AutoencoderKL":
1260+
stage_1 = AutoencoderKL(**autoencoder_params)
1261+
if ae_model_type == "VQVAE":
1262+
stage_1 = VQVAE(**autoencoder_params)
1263+
if ae_model_type == "SPADEAutoencoderKL":
1264+
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
1265+
if dm_model_type == "SPADEDiffusionModelUNet":
1266+
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
1267+
else:
1268+
stage_2 = DiffusionModelUNet(**stage_2_params)
1269+
controlnet = ControlNet(**controlnet_params)
1270+
1271+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
1272+
stage_1.to(device)
1273+
stage_2.to(device)
1274+
controlnet.to(device)
1275+
stage_1.eval()
1276+
stage_2.eval()
1277+
controlnet.eval()
1278+
1279+
noise = torch.randn(latent_shape).to(device)
1280+
mask = torch.randn(input_shape).to(device)
1281+
scheduler = DDPMScheduler(num_train_timesteps=10)
1282+
# We infer the VAE shape
1283+
if ae_model_type == "VQVAE":
1284+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
1285+
else:
1286+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
1287+
1288+
inferer = ControlNetLatentDiffusionInferer(
1289+
scheduler=scheduler,
1290+
scale_factor=1.0,
1291+
ldm_latent_shape=list(latent_shape[2:]),
1292+
autoencoder_latent_shape=autoencoder_latent_shape,
1293+
)
1294+
scheduler.set_timesteps(num_inference_steps=10)
1295+
1296+
if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
1297+
input_shape_seg = list(input_shape)
1298+
if "label_nc" in stage_2_params.keys():
1299+
input_shape_seg[1] = stage_2_params["label_nc"]
1300+
else:
1301+
input_shape_seg[1] = autoencoder_params["label_nc"]
1302+
input_seg = torch.randn(input_shape_seg).to(device)
1303+
prediction, _ = inferer.sample(
1304+
autoencoder_model=stage_1,
1305+
diffusion_model=stage_2,
1306+
controlnet=controlnet,
1307+
cn_cond=mask,
1308+
input_noise=noise,
1309+
seg=input_seg,
1310+
save_intermediates=True,
1311+
)
1312+
else:
1313+
prediction = inferer.sample(
1314+
autoencoder_model=stage_1,
1315+
diffusion_model=stage_2,
1316+
input_noise=noise,
1317+
controlnet=controlnet,
1318+
cn_cond=mask,
1319+
save_intermediates=False,
1320+
)
1321+
self.assertEqual(prediction.shape, input_shape)
1322+
12451323
@skipUnless(has_einops, "Requires einops")
12461324
def test_incompatible_spade_setup(self):
12471325
stage_1 = SPADEAutoencoderKL(

tests/inferers/test_latent_diffusion_inferer.py

+61-1
Original file line number | Diff line number | Diff line change
@@ -714,7 +714,7 @@ def test_sample_shape_conditioned_concat(
714714

715715
@parameterized.expand(TEST_CASES_DIFF_SHAPES)
716716
@skipUnless(has_einops, "Requires einops")
717-
def test_sample_shape_different_latents(
717+
def test_shape_different_latents(
718718
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
719719
):
720720
stage_1 = None
@@ -772,6 +772,66 @@ def test_sample_shape_different_latents(
772772
)
773773
self.assertEqual(prediction.shape, latent_shape)
774774

775+
@parameterized.expand(TEST_CASES_DIFF_SHAPES)
776+
@skipUnless(has_einops, "Requires einops")
777+
def test_sample_shape_different_latents(
778+
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
779+
):
780+
stage_1 = None
781+
782+
if ae_model_type == "AutoencoderKL":
783+
stage_1 = AutoencoderKL(**autoencoder_params)
784+
if ae_model_type == "VQVAE":
785+
stage_1 = VQVAE(**autoencoder_params)
786+
if ae_model_type == "SPADEAutoencoderKL":
787+
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
788+
if dm_model_type == "SPADEDiffusionModelUNet":
789+
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
790+
else:
791+
stage_2 = DiffusionModelUNet(**stage_2_params)
792+
793+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
794+
stage_1.to(device)
795+
stage_2.to(device)
796+
stage_1.eval()
797+
stage_2.eval()
798+
799+
noise = torch.randn(latent_shape).to(device)
800+
scheduler = DDPMScheduler(num_train_timesteps=10)
801+
# We infer the VAE shape
802+
if ae_model_type == "VQVAE":
803+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
804+
else:
805+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
806+
807+
inferer = LatentDiffusionInferer(
808+
scheduler=scheduler,
809+
scale_factor=1.0,
810+
ldm_latent_shape=list(latent_shape[2:]),
811+
autoencoder_latent_shape=autoencoder_latent_shape,
812+
)
813+
scheduler.set_timesteps(num_inference_steps=10)
814+
815+
if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
816+
input_shape_seg = list(input_shape)
817+
if "label_nc" in stage_2_params.keys():
818+
input_shape_seg[1] = stage_2_params["label_nc"]
819+
else:
820+
input_shape_seg[1] = autoencoder_params["label_nc"]
821+
input_seg = torch.randn(input_shape_seg).to(device)
822+
prediction, _ = inferer.sample(
823+
autoencoder_model=stage_1,
824+
diffusion_model=stage_2,
825+
input_noise=noise,
826+
save_intermediates=True,
827+
seg=input_seg,
828+
)
829+
else:
830+
prediction = inferer.sample(
831+
autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
832+
)
833+
self.assertEqual(prediction.shape, input_shape)
834+
775835
@skipUnless(has_einops, "Requires einops")
776836
def test_incompatible_spade_setup(self):
777837
stage_1 = SPADEAutoencoderKL(

0 commit comments

Comments (0)