Skip to content

Commit

Permalink
Delete task_templates in IterableDataset when they are no longer …
Browse files Browse the repository at this point in the history
…valid (#6027)
  • Loading branch information
mariosasko authored Jul 13, 2023
1 parent f49a163 commit dbf6c10
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,6 +1993,11 @@ def rename_columns_fn(example):
for col, feature in original_features.items()
}
)
# check that it's still valid, especially with regard to task templates
try:
ds_iterable._info.copy()
except ValueError:
ds_iterable._info.task_templates = None
return ds_iterable

def remove_columns(self, column_names: Union[str, List[str]]) -> "IterableDataset":
Expand Down Expand Up @@ -2027,6 +2032,12 @@ def remove_columns(self, column_names: Union[str, List[str]]) -> "IterableDatase
for col, _ in original_features.items():
if col in column_names:
del ds_iterable._info.features[col]
# check that it's still valid, especially with regard to task templates
try:
ds_iterable._info.copy()
except ValueError:
ds_iterable._info.task_templates = None

return ds_iterable

def select_columns(self, column_names: Union[str, List[str]]) -> "IterableDataset":
Expand Down Expand Up @@ -2068,6 +2079,11 @@ def select_columns(self, column_names: Union[str, List[str]]) -> "IterableDatase
f"{list(self._info.features.keys())}."
)
info.features = Features({c: info.features[c] for c in column_names})
# check that it's still valid, especially with regard to task templates
try:
info.copy()
except ValueError:
info.task_templates = None

ex_iterable = SelectColumnsIterable(self._ex_iterable, column_names)
return IterableDataset(
Expand Down

0 comments on commit dbf6c10

Please sign in to comment.