Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
KumoLiu committed Jun 28, 2024
1 parent a292783 commit 7033892
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_pad_collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def tearDown(self) -> None:

@parameterized.expand(TESTS)
def test_pad_collation(self, t_type, collate_method, transform):
if isinstance(t_type, dict):
if t_type is dict:
dataset = CacheDataset(self.dict_data, transform, progress=False)
else:
dataset = _Dataset(self.list_data, self.list_labels, transform)
Expand All @@ -104,7 +104,7 @@ def test_pad_collation(self, t_type, collate_method, transform):
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)
# check collation in forward direction
for data in loader:
if isinstance(t_type, dict):
if t_type is dict:
shapes = []
decollated_data = decollate_batch(data)
for d in decollated_data:
Expand All @@ -113,7 +113,7 @@ def test_pad_collation(self, t_type, collate_method, transform):
self.assertTrue(len(output["image"].applied_operations), len(dataset.transform.transforms))
self.assertTrue(len(set(shapes)) > 1) # inverted shapes must be different because of random xforms

if isinstance(t_type, dict):
if t_type is dict:
batch_inverse = BatchInverseTransform(dataset.transform, loader)
for data in loader:
output = batch_inverse(data)
Expand Down

0 comments on commit 7033892

Please sign in to comment.