Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add with_rank param to Dataset.filter #6608

Merged
merged 4 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 40 additions & 19 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3532,7 +3532,8 @@ def init_buffer_and_writer():
def filter(
self,
function: Optional[Callable] = None,
with_indices=False,
with_indices: bool = False,
with_rank: bool = False,
input_columns: Optional[Union[str, List[str]]] = None,
batched: bool = False,
batch_size: Optional[int] = 1000,
Expand All @@ -3552,14 +3553,18 @@ def filter(
Args:
function (`Callable`): Callable with one of the following signatures:

- `function(example: Dict[str, Any]) -> bool` if `with_indices=False, batched=False`
- `function(example: Dict[str, Any], indices: int) -> bool` if `with_indices=True, batched=False`
- `function(example: Dict[str, List]) -> List[bool]` if `with_indices=False, batched=True`
- `function(example: Dict[str, List], indices: List[int]) -> List[bool]` if `with_indices=True, batched=True`
- `function(example: Dict[str, Any]) -> Dict[str, Any]` if `batched=False` and `with_indices=False` and `with_rank=False`
- `function(example: Dict[str, Any], *extra_args) -> Dict[str, Any]` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each)
- `function(batch: Dict[str, List]) -> Dict[str, List]` if `batched=True` and `with_indices=False` and `with_rank=False`
- `function(batch: Dict[str, List], *extra_args) -> Dict[str, List]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each)
mariosasko marked this conversation as resolved.
Show resolved Hide resolved

If no function is provided, defaults to an always `True` function: `lambda x: True`.
with_indices (`bool`, defaults to `False`):
Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`.
Provide example indices to `function`. Note that in this case the
signature of `function` should be `def function(example, idx[, rank]): ...`.
with_rank (`bool`, defaults to `False`):
Provide process rank to `function`. Note that in this case the
signature of `function` should be `def function(example[, idx], rank): ...`.
input_columns (`str` or `List[str]`, *optional*):
The columns to be passed into `function` as
positional arguments. If `None`, a `dict` mapping to all formatted columns is passed as one argument.
Expand Down Expand Up @@ -3622,9 +3627,16 @@ def filter(

indices = self.map(
function=partial(
get_indices_from_mask_function, function, batched, with_indices, input_columns, self._indices
get_indices_from_mask_function,
function,
batched,
with_indices,
with_rank,
input_columns,
self._indices,
),
with_indices=True,
with_rank=True,
features=Features({"indices": Value("uint64")}),
batched=True,
batch_size=batch_size,
Expand Down Expand Up @@ -6193,41 +6205,50 @@ def get_indices_from_mask_function(
function: Callable,
batched: bool,
with_indices: bool,
with_rank: bool,
input_columns: Optional[Union[str, List[str]]],
indices_mapping: Optional[Table] = None,
*args,
**fn_kwargs,
):
if batched:
# we extract indices from args
*inputs, indices = args
# we extract indices and rank from args
*inputs, indices, rank = args
additional_args = ()
if with_indices:
mask = function(*inputs, indices, **fn_kwargs)
else:
mask = function(*inputs, **fn_kwargs)
additional_args += (indices,)
if with_rank:
additional_args += (rank,)
mask = function(*inputs, *additional_args, **fn_kwargs)
else:
# we get batched data (to do less look-ups) but `function` only accepts one example
# therefore we need to call `function` on each example of the batch to get the mask
*inputs, indices = args
*inputs, indices, rank = args
mask = []
if input_columns is None:
# inputs only contains a batch of examples
batch: dict = inputs[0]
num_examples = len(batch[next(iter(batch.keys()))])
for i in range(num_examples):
example = {key: batch[key][i] for key in batch}
mask.append(
function(example, indices[i], **fn_kwargs) if with_indices else function(example, **fn_kwargs)
)
additional_args = ()
if with_indices:
additional_args += (indices[i],)
if with_rank:
additional_args += (rank,)
mask.append(function(example, *additional_args, **fn_kwargs))
else:
# inputs is a list of columns
columns: List[List] = inputs
num_examples = len(columns[0])
for i in range(num_examples):
input = [column[i] for column in columns]
mask.append(
function(*input, indices[i], **fn_kwargs) if with_indices else function(*input, **fn_kwargs)
)
additional_args = ()
if with_indices:
additional_args += (indices[i],)
if with_rank:
additional_args += (rank,)
mask.append(function(*input, *additional_args, **fn_kwargs))
indices_array = [i for i, to_keep in zip(indices, mask) if to_keep]
if indices_mapping is not None:
indices_array = pa.array(indices_array, type=pa.uint64())
Expand Down
26 changes: 17 additions & 9 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,8 +891,9 @@ def map(

def filter(
self,
function,
with_indices=False,
function: Optional[Callable] = None,
with_indices: bool = False,
with_rank: bool = False,
input_columns: Optional[Union[str, List[str]]] = None,
batched: bool = False,
batch_size: Optional[int] = 1000,
Expand All @@ -909,14 +910,20 @@ def filter(
The transformation is applied to all the datasets of the dataset dictionary.

Args:
function (`callable`):
With one of the following signature:
- `function(example: Dict[str, Any]) -> bool` if `with_indices=False, batched=False`
- `function(example: Dict[str, Any], indices: int) -> bool` if `with_indices=True, batched=False`
- `function(example: Dict[str, List]) -> List[bool]` if `with_indices=False, batched=True`
- `function(example: Dict[str, List], indices: List[int]) -> List[bool]` if ``with_indices=True, batched=True`
function (`Callable`): Callable with one of the following signatures:

- `function(example: Dict[str, Any]) -> Dict[str, Any]` if `batched=False` and `with_indices=False` and `with_rank=False`
- `function(example: Dict[str, Any], *extra_args) -> Dict[str, Any]` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each)
- `function(batch: Dict[str, List]) -> Dict[str, List]` if `batched=True` and `with_indices=False` and `with_rank=False`
- `function(batch: Dict[str, List], *extra_args) -> Dict[str, List]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each)
mariosasko marked this conversation as resolved.
Show resolved Hide resolved

If no function is provided, defaults to an always `True` function: `lambda x: True`.
with_indices (`bool`, defaults to `False`):
Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`.
Provide example indices to `function`. Note that in this case the
signature of `function` should be `def function(example, idx[, rank]): ...`.
with_rank (`bool`, defaults to `False`):
Provide process rank to `function`. Note that in this case the
signature of `function` should be `def function(example[, idx], rank): ...`.
input_columns (`[Union[str, List[str]]]`, *optional*, defaults to `None`):
The columns to be passed into `function` as
positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument.
Expand Down Expand Up @@ -976,6 +983,7 @@ def filter(
k: dataset.filter(
function=function,
with_indices=with_indices,
with_rank=with_rank,
input_columns=input_columns,
batched=batched,
batch_size=batch_size,
Expand Down
16 changes: 16 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def picklable_filter_function(x):
return int(x["filename"].split("_")[-1]) < 10


def picklable_filter_function_with_rank(x, r):
return r == 0


def assert_arrow_metadata_are_synced_with_dataset_features(dataset: Dataset):
assert dataset.data.schema.metadata is not None
assert b"huggingface" in dataset.data.schema.metadata
Expand Down Expand Up @@ -1756,6 +1760,18 @@ def test_filter_multiprocessing(self, in_memory):
self.assertEqual(len(dset_filter_first_ten.cache_files), 0 if in_memory else 2)
self.assertNotEqual(dset_filter_first_ten._fingerprint, fingerprint)

with tempfile.TemporaryDirectory() as tmp_dir: # with_rank
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
fingerprint = dset._fingerprint
with dset.filter(
picklable_filter_function_with_rank, num_proc=2, with_rank=True
) as dset_filter_first_rank:
self.assertEqual(len(dset_filter_first_rank), min(len(dset) // 2, len(dset)))
self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
self.assertDictEqual(dset_filter_first_rank.features, Features({"filename": Value("string")}))
self.assertEqual(len(dset_filter_first_rank.cache_files), 0 if in_memory else 2)
self.assertNotEqual(dset_filter_first_rank._fingerprint, fingerprint)

def test_filter_caching(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir:
self._caplog.clear()
Expand Down
Loading