Skip to content

Commit

Permalink
Update src/torchmetrics/image/fid.py
Browse files Browse the repository at this point in the history
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
furkan-celik and SkafteNicki authored Apr 12, 2024
1 parent 43e28ee commit 5944b4e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def __init__(
if self.normalize:
dummy_image = torch.rand(1, *input_img_size, dtype=torch.float32)
else:
dummy_image = torch.randint(0, 255, input_img_size, dtype=torch.uint8)
dummy_image = torch.randint(0, 255, (1, *input_img_size), dtype=torch.uint8)
num_features = self.inception(dummy_image).shape[-1]
else:
raise TypeError("Got unknown input to argument `feature`")
Expand Down

0 comments on commit 5944b4e

Please sign in to comment.