From dbf6c103f5844de40431478e7e4a64fbf2c2c067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Thu, 13 Jul 2023 15:57:35 +0200 Subject: [PATCH] Delete `task_templates` in `IterableDataset` when they are no longer valid (#6027) --- src/datasets/iterable_dataset.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 8a59ff95ded..fe9f54641e7 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1993,6 +1993,11 @@ def rename_columns_fn(example): for col, feature in original_features.items() } ) + # check that it's still valid, especially with regard to task templates + try: + ds_iterable._info.copy() + except ValueError: + ds_iterable._info.task_templates = None return ds_iterable def remove_columns(self, column_names: Union[str, List[str]]) -> "IterableDataset": @@ -2027,6 +2032,12 @@ def remove_columns(self, column_names: Union[str, List[str]]) -> "IterableDatase for col, _ in original_features.items(): if col in column_names: del ds_iterable._info.features[col] + # check that it's still valid, especially with regard to task templates + try: + ds_iterable._info.copy() + except ValueError: + ds_iterable._info.task_templates = None + return ds_iterable def select_columns(self, column_names: Union[str, List[str]]) -> "IterableDataset": @@ -2068,6 +2079,11 @@ def select_columns(self, column_names: Union[str, List[str]]) -> "IterableDatase f"{list(self._info.features.keys())}." ) info.features = Features({c: info.features[c] for c in column_names}) + # check that it's still valid, especially with regard to task templates + try: + info.copy() + except ValueError: + info.task_templates = None ex_iterable = SelectColumnsIterable(self._ex_iterable, column_names) return IterableDataset(