Skip to content

Commit

Permalink
flake8 warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
juampatronics committed Mar 25, 2024
1 parent 1b97e69 commit f2fe14e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
4 changes: 1 addition & 3 deletions monai/transforms/regularization/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ class MixUpd(MapTransform):
for consistency, i.e. images and labels must be applied the same augmenation.
"""

def __init__(
self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
) -> None:
def __init__(self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False) -> None:
super().__init__(keys, allow_missing_keys)
self.mixup = MixUp(batch_size, alpha)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_mixup(self):
mixup = MixUp(6, 1.0)
output = mixup(sample)
self.assertEqual(output.shape, sample.shape)
self.assertTrue(any([not torch.allclose(sample, mixup(sample)) for _ in range(10)]))
self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10)))

with self.assertRaises(ValueError):
MixUp(6, -0.5)
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_cutmix(self):
cutmix = CutMix(6, 1.0)
output = cutmix(sample)
self.assertEqual(output.shape, sample.shape)
self.assertTrue(any([not torch.allclose(sample, cutmix(sample)) for _ in range(10)]))
self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10)))

def test_cutmixd(self):
for dims in [2, 3]:
Expand All @@ -83,7 +83,7 @@ def test_cutout(self):
cutout = CutOut(6, 1.0)
output = cutout(sample)
self.assertEqual(output.shape, sample.shape)
self.assertTrue(any([not torch.allclose(sample, cutout(sample)) for _ in range(10)]))
self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10)))


if __name__ == "__main__":
Expand Down

0 comments on commit f2fe14e

Please sign in to comment.