diff --git a/generate_shapes_and_images.py b/generate_shapes_and_images.py index 970535b..4bcb92c 100644 --- a/generate_shapes_and_images.py +++ b/generate_shapes_and_images.py @@ -26,7 +26,8 @@ def generate(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent): g_ema.eval() - surface_g_ema.eval() + if not opt.no_surface_renderings: + surface_g_ema.eval() # set camera angles if opt.fixed_camera_angles: diff --git a/model.py b/model.py index 7cf453c..ade773a 100644 --- a/model.py +++ b/model.py @@ -683,13 +683,14 @@ def make_noise(self): def mean_latent(self, n_latent, device): latent_in = torch.randn(n_latent, self.style_dim, device=device) - renderer_latent = self.style(latent_in).mean(0, keepdim=True) + renderer_latent = self.style(latent_in) + renderer_latent_mean = renderer_latent.mean(0, keepdim=True) if self.full_pipeline: - decoder_latent = self.decoder.mean_latent(renderer_latent) + decoder_latent_mean = self.decoder.mean_latent(renderer_latent) else: - decoder_latent = None + decoder_latent_mean = None - return [renderer_latent, decoder_latent] + return [renderer_latent_mean, decoder_latent_mean] def get_latent(self, input): return self.style(input)