diff --git a/pyiqa/archs/dists_arch.py b/pyiqa/archs/dists_arch.py index af16759..52e27c2 100644 --- a/pyiqa/archs/dists_arch.py +++ b/pyiqa/archs/dists_arch.py @@ -143,6 +143,6 @@ def forward(self, x, y): S2 = (2 * xy_cov + c2) / (x_var + y_var + c2) dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True) - score = 1 - (dist1 + dist2).squeeze() + score = 1 - (dist1 + dist2) - return score + return score.squeeze(-1).squeeze(-1) diff --git a/pyiqa/archs/lpips_arch.py b/pyiqa/archs/lpips_arch.py index 83b0884..03b2302 100644 --- a/pyiqa/archs/lpips_arch.py +++ b/pyiqa/archs/lpips_arch.py @@ -173,7 +173,7 @@ def forward(self, in1, in0, retPerLayer=False, normalize=True): if (retPerLayer): return (val, res) else: - return val.squeeze() + return val.squeeze(-1).squeeze(-1) class ScalingLayer(nn.Module): diff --git a/pyiqa/archs/pieapp_arch.py b/pyiqa/archs/pieapp_arch.py index 5164d70..9936a15 100644 --- a/pyiqa/archs/pieapp_arch.py +++ b/pyiqa/archs/pieapp_arch.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from pyiqa.utils.registry import ARCH_REGISTRY -from pyiqa.archs.arch_util import load_pretrained_network +from pyiqa.archs.arch_util import load_pretrained_network, random_crop from .func_util import extract_2d_patches default_model_urls = { @@ -115,12 +115,16 @@ def forward(self, dist, ref): assert dist.shape == ref.shape, f'Input and reference images should have the same shape, but got {dist.shape}' f' and {ref.shape}' - if self.pretrained: - dist = self.preprocess(dist) - ref = self.preprocess(ref) + dist = self.preprocess(dist) + ref = self.preprocess(ref) + if not self.training: image_A_patches = extract_2d_patches(dist, self.patch_size, self.stride, padding='none') image_ref_patches = extract_2d_patches(ref, self.patch_size, self.stride, padding='none') + else: + image_A_patches, image_ref_patches = dist, ref + image_A_patches = image_A_patches.unsqueeze(1) + image_ref_patches = image_ref_patches.unsqueeze(1) bsz, num_patches, c, psz, psz = image_A_patches.shape image_A_patches = image_A_patches.reshape(bsz * num_patches, c, psz, psz) @@ -138,4 +142,4 @@ def forward(self, dist, ref): per_patch_weight = per_patch_weight.view((-1, num_patches)) score = (per_patch_weight * per_patch_score).sum(dim=-1) / per_patch_weight.sum(dim=-1) - return score.squeeze() + return score.reshape(bsz, 1)