Skip to content

Commit

Permalink
Modify add_column() to optionally accept a FeatureType as param (#7143)
Browse files Browse the repository at this point in the history
* Modify add_column() to optionally accept a FeatureType param

* Add feature param to add_column() docstring

---------

Co-authored-by: Varad Bhatnagar <varadhbhatnagar@gmail.com>
  • Loading branch information
varadhbhatnagar and Varad Bhatnagar authored Sep 16, 2024
1 parent e4bba5e commit 43b1fe1
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5613,7 +5613,9 @@ def push_to_hub(

     @transmit_format
     @fingerprint_transform(inplace=False)
-    def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: str):
+    def add_column(
+        self, name: str, column: Union[list, np.array], new_fingerprint: str, feature: Optional[FeatureType] = None
+    ):
         """Add column to Dataset.
         <Added version="1.7"/>
Expand All @@ -5623,6 +5625,8 @@ def add_column(self, name: str, column: Union[list, np.array], new_fingerprint:
                 Column name.
             column (`list` or `np.array`):
                 Column data to be added.
+            feature (`FeatureType` or `None`, defaults to `None`):
+                Column datatype.
         Returns:
             [`Dataset`]
Expand All @@ -5640,7 +5644,13 @@ def add_column(self, name: str, column: Union[list, np.array], new_fingerprint:
})
```
"""
-        column_table = InMemoryTable.from_pydict({name: column})
+
+        if feature:
+            pyarrow_schema = Features({name: feature}).arrow_schema
+        else:
+            pyarrow_schema = None
+
+        column_table = InMemoryTable.from_pydict({name: column}, schema=pyarrow_schema)
         _check_column_names(self._data.column_names + column_table.column_names)
         dataset = self.flatten_indices() if self._indices is not None else self
         # Concatenate tables horizontally
Expand Down

0 comments on commit 43b1fe1

Please sign in to comment.