From 3465abb70e3cda27aaf4394ead9bc5d5ccd437fe Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Mon, 2 Mar 2020 11:15:20 +0300 Subject: [PATCH] Fix remainder logic for subset splitting --- datumaro/datumaro/plugins/transforms.py | 11 +++++++---- datumaro/tests/test_project.py | 7 +------ datumaro/tests/test_transforms.py | 16 ++++++++++------ 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/datumaro/datumaro/plugins/transforms.py b/datumaro/datumaro/plugins/transforms.py index f373932841c..47cbfcf3209 100644 --- a/datumaro/datumaro/plugins/transforms.py +++ b/datumaro/datumaro/plugins/transforms.py @@ -325,8 +325,12 @@ def build_cmdline_parser(cls, **kwargs): def __init__(self, extractor, splits, seed=None): super().__init__(extractor) - total_ratio = sum((s[1] for s in splits), 0) - if not total_ratio == 1: + assert 0 < len(splits), "Expected at least one split" + assert all(0.0 <= r and r <= 1.0 for _, r in splits), \ + "Ratios are expected to be in the range [0; 1], but got %s" % splits + + total_ratio = sum(s[1] for s in splits) + if not abs(total_ratio - 1.0) <= 1e-7: raise Exception( "Sum of ratios is expected to be 1, got %s, which is %s" % (splits, total_ratio)) @@ -336,7 +340,6 @@ def __init__(self, extractor, splits, seed=None): random.seed(seed) random.shuffle(indices) - parts = [] s = 0 for subset, ratio in splits: @@ -350,7 +353,7 @@ def _find_split(self, index): for boundary, subset in self._parts: if index < boundary: return subset - return subset + return subset # all the possible remainder goes to the last split def __iter__(self): for i, item in enumerate(self._extractor): diff --git a/datumaro/tests/test_project.py b/datumaro/tests/test_project.py index 229c9f0f628..75baf716e80 100644 --- a/datumaro/tests/test_project.py +++ b/datumaro/tests/test_project.py @@ -534,15 +534,10 @@ def __iter__(self): class DatasetItemTest(TestCase): def test_ctor_requires_id(self): - has_error = False - try: + with self.assertRaises(Exception): # pylint: disable=no-value-for-parameter DatasetItem() # pylint: enable=no-value-for-parameter - except AssertionError: - has_error = True - - self.assertTrue(has_error) @staticmethod def test_ctors_with_image(): diff --git a/datumaro/tests/test_transforms.py b/datumaro/tests/test_transforms.py index 19f9bea2fa7..b90581e8f97 100644 --- a/datumaro/tests/test_transforms.py +++ b/datumaro/tests/test_transforms.py @@ -342,18 +342,22 @@ def __iter__(self): self.assertEqual(4, len(actual.get_subset('train'))) self.assertEqual(3, len(actual.get_subset('test'))) - def test_random_split_gives_error_on_non1_ratios(self): + def test_random_split_gives_error_on_wrong_ratios(self): class SrcExtractor(Extractor): def __iter__(self): return iter([DatasetItem(id=1)]) - has_error = False - try: + with self.assertRaises(Exception): transforms.RandomSplit(SrcExtractor(), splits=[ ('train', 0.5), ('test', 0.7), ]) - except Exception: - has_error = True - self.assertTrue(has_error) \ No newline at end of file + with self.assertRaises(Exception): + transforms.RandomSplit(SrcExtractor(), splits=[]) + + with self.assertRaises(Exception): + transforms.RandomSplit(SrcExtractor(), splits=[ + ('train', -0.5), + ('test', 1.5), + ])