Skip to content

Commit

Permalink
feat: 🧑‍💻 improve codes to allow training
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed May 29, 2024
1 parent e73c719 commit 036f2cc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
4 changes: 2 additions & 2 deletions pyiqa/archs/dists_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,6 @@ def forward(self, x, y):
S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)

score = 1 - (dist1 + dist2).squeeze()
score = 1 - (dist1 + dist2)

return score
return score.squeeze(-1).squeeze(-1)
2 changes: 1 addition & 1 deletion pyiqa/archs/lpips_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def forward(self, in1, in0, retPerLayer=False, normalize=True):
if (retPerLayer):
return (val, res)
else:
return val.squeeze()
return val.squeeze(-1).squeeze(-1)


class ScalingLayer(nn.Module):
Expand Down
14 changes: 9 additions & 5 deletions pyiqa/archs/pieapp_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch.nn.functional as F

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.archs.arch_util import load_pretrained_network, random_crop
from .func_util import extract_2d_patches

default_model_urls = {
Expand Down Expand Up @@ -115,12 +115,16 @@ def forward(self, dist, ref):
assert dist.shape == ref.shape, f'Input and reference images should have the same shape, but got {dist.shape}'
f' and {ref.shape}'

if self.pretrained:
dist = self.preprocess(dist)
ref = self.preprocess(ref)
dist = self.preprocess(dist)
ref = self.preprocess(ref)

if not self.training:
image_A_patches = extract_2d_patches(dist, self.patch_size, self.stride, padding='none')
image_ref_patches = extract_2d_patches(ref, self.patch_size, self.stride, padding='none')
else:
image_A_patches, image_ref_patches = dist, ref
image_A_patches = image_A_patches.unsqueeze(1)
image_ref_patches = image_ref_patches.unsqueeze(1)

bsz, num_patches, c, psz, psz = image_A_patches.shape
image_A_patches = image_A_patches.reshape(bsz * num_patches, c, psz, psz)
Expand All @@ -138,4 +142,4 @@ def forward(self, dist, ref):
per_patch_weight = per_patch_weight.view((-1, num_patches))

score = (per_patch_weight * per_patch_score).sum(dim=-1) / per_patch_weight.sum(dim=-1)
return score.squeeze()
return score.reshape(bsz, 1)

0 comments on commit 036f2cc

Please sign in to comment.