Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a new low-memory approach for tf dataset index shuffling #5863

Merged
merged 16 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/datasets/utils/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""TF-specific utils import."""

import os
import warnings
from functools import partial
from math import ceil
from uuid import uuid4
Expand Down Expand Up @@ -173,6 +174,21 @@ def dataset_to_tf(
else:
raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")

# TODO Matt: When our minimum Python version is 3.8 or higher, we can delete all of this and move everything
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Matt, is datasets going to drop Python 3.7 support given its upcoming EOL? That will happen by the end of the month, so we could wait and set the minimum version to 3.8, although I assume some users may still be using 3.7.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will probably depend on what transformers does.

# to the NumPy multiprocessing path.
if hasattr(tf, "random_index_shuffle"):
random_index_shuffle = tf.random_index_shuffle
elif hasattr(tf.random.experimental, "index_shuffle"):
random_index_shuffle = tf.random.experimental.index_shuffle
else:
if len(dataset) > 10_000_000:
warnings.warn(
"to_tf_dataset() can be memory-inefficient on versions of TensorFlow older than 2.9. "
"If you are iterating over a dataset with a very large number of samples, consider "
"upgrading to TF >= 2.9."
)
random_index_shuffle = None

getter_fn = partial(
np_get_batch,
dataset=dataset,
Expand All @@ -195,10 +211,22 @@ def fetch_function(indices):
)
return {key: output[i] for i, key in enumerate(columns_to_np_types.keys())}

tf_dataset = tf.data.Dataset.from_tensor_slices(np.arange(len(dataset), dtype=np.int64))
tf_dataset = tf.data.Dataset.range(len(dataset))

if shuffle and random_index_shuffle is not None:
base_seed = tf.fill((3,), value=tf.cast(-1, dtype=tf.int64))

# Step function passed to `tf_dataset.scan(base_seed, scan_random_index)`:
# `state` carries the shuffle seed from one element to the next, and `index` is the
# sequential position to be mapped to a shuffled position.
# Returns the pair (state, shuffled_index).
def scan_random_index(state, index):
# A state of all -1 matches the `base_seed` sentinel, i.e. no seed has been drawn yet.
if tf.reduce_all(state == -1):
# This generates a new random seed once per epoch only,
# to ensure that we iterate over each sample exactly once per epoch
state = tf.random.uniform(shape=(3,), maxval=2**62, dtype=tf.int64)
# max_index is len(dataset) - 1, the last valid sample index.
shuffled_index = random_index_shuffle(index=index, seed=state, max_index=len(dataset) - 1)
return state, shuffled_index

if shuffle:
tf_dataset = tf_dataset.shuffle(len(dataset))
tf_dataset = tf_dataset.scan(base_seed, scan_random_index)
elif shuffle:
tf_dataset = tf_dataset.shuffle(tf_dataset.cardinality())

if batch_size is not None:
tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2755,6 +2755,8 @@ def test_tf_index_reshuffling(self, in_memory):
second_indices.append(batch["col_1"])
second_indices = np.concatenate([arr.numpy() for arr in second_indices])
self.assertFalse(np.array_equal(indices, second_indices))
self.assertEqual(len(indices), len(np.unique(indices)))
self.assertEqual(len(second_indices), len(np.unique(second_indices)))

tf_dataset = dset.to_tf_dataset(batch_size=1, shuffle=False, num_workers=num_workers)
for i, batch in enumerate(tf_dataset):
Expand Down