From 005355bd6fdc3a45f4d54f8d8dfd035b7968ce64 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Wed, 21 Oct 2020 17:16:06 +0200 Subject: [PATCH] Added eps in the __repr__ of FrozenBN (#2852) * feat: Updated FrozenBN eps to align with BatchNorm * feat: Added eps to __repr__ of FrozenBN * test: Updated unittest of __repr__ for FrozenBN * test: Updated unittest for eps value in BN and FrozenBN * fix: Revert FrozenBN eps value * test: Revert test on eps alignment between FrozenBN and BN --- test/test_ops.py | 5 +++-- torchvision/ops/misc.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a6f161051fc..7c13de4dedc 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -607,10 +607,11 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_): class FrozenBNTester(unittest.TestCase): def test_frozenbatchnorm2d_repr(self): num_features = 32 - t = ops.misc.FrozenBatchNorm2d(num_features) + eps = 1e-5 + t = ops.misc.FrozenBatchNorm2d(num_features, eps=eps) # Check integrity of object __repr__ attribute - expected_string = f"FrozenBatchNorm2d({num_features})" + expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})" self.assertEqual(t.__repr__(), expected_string) def test_frozenbatchnorm2d_eps(self): diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 17e69c506d8..3b52c0d8c4d 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -96,4 +96,4 @@ def forward(self, x: Tensor) -> Tensor: return x * scale + bias def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.weight.shape[0]})" + return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"