diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index e7ac1a34665..a0d885b4354 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -5613,7 +5613,9 @@ def push_to_hub( @transmit_format @fingerprint_transform(inplace=False) - def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: str): + def add_column( + self, name: str, column: Union[list, np.array], new_fingerprint: str, feature: Optional[FeatureType] = None + ): """Add column to Dataset. @@ -5623,6 +5625,8 @@ def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: Column name. column (`list` or `np.array`): Column data to be added. + feature (`FeatureType` or `None`, defaults to `None`): + Column datatype. Returns: [`Dataset`] @@ -5640,7 +5644,13 @@ def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: }) ``` """ - column_table = InMemoryTable.from_pydict({name: column}) + + if feature: + pyarrow_schema = Features({name: feature}).arrow_schema + else: + pyarrow_schema = None + + column_table = InMemoryTable.from_pydict({name: column}, schema=pyarrow_schema) _check_column_names(self._data.column_names + column_table.column_names) dataset = self.flatten_indices() if self._indices is not None else self # Concatenate tables horizontally