diff --git a/mmedit/evaluation/metrics/ppl.py b/mmedit/evaluation/metrics/ppl.py index ba6500de3a..7e22ef1ec3 100644 --- a/mmedit/evaluation/metrics/ppl.py +++ b/mmedit/evaluation/metrics/ppl.py @@ -240,6 +240,7 @@ def __next__(self): if self.idx >= len(self.batch_sizes): raise StopIteration batch = self.batch_sizes[self.idx] + injected_noise = self.generator.make_injected_noise() inputs = torch.randn([batch * 2, self.latent_dim], device=self.device) if self.sampling == 'full': @@ -270,6 +271,7 @@ def __next__(self): inputs=dict( noise=latent_e, sample_kwargs=dict( + injected_noise=injected_noise, input_is_latent=(self.space == 'W')))) ppl_sampler = PPLSampler( diff --git a/mmedit/utils/io_utils.py b/mmedit/utils/io_utils.py index de965c1908..bd6128f707 100644 --- a/mmedit/utils/io_utils.py +++ b/mmedit/utils/io_utils.py @@ -3,7 +3,7 @@ import os import click -import mmcv +import mmengine import requests import torch.distributed as dist from mmengine.dist import get_dist_info @@ -67,7 +67,7 @@ def download_from_url(url, if rank == 0: # mkdir _dir = os.path.dirname(dest_path) - mmcv.mkdir_or_exist(_dir) + mmengine.mkdir_or_exist(_dir) if hash_prefix is not None: sha256 = hashlib.sha256()