From 1d6fa5b27f57af860236479cac6252e1f0936f2b Mon Sep 17 00:00:00 2001
From: Matt
Date: Mon, 15 May 2023 16:26:35 +0100
Subject: [PATCH 01/16] Use a new low-memory approach for tf dataset index
 shuffling

---
 src/datasets/utils/tf_utils.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 6cbbbaaf8b5..9d5d47c6ec2 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -195,10 +195,22 @@ def fetch_function(indices):
         )
         return {key: output[i] for i, key in enumerate(columns_to_np_types.keys())}
 
-    tf_dataset = tf.data.Dataset.from_tensor_slices(np.arange(len(dataset), dtype=np.int64))
+    tf_dataset = tf.data.Dataset.range(len(dataset))
+    tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
 
     if shuffle:
-        tf_dataset = tf_dataset.shuffle(len(dataset))
+        base_seed = tf.fill((3,), fill_value=-1, dtype=tf.int64)
+
+        def scan_random_indices(state, indices):
+            if tf.reduce_all(state == -1):
+                # This generates a new random seed once per epoch only,
+                # to ensure that we iterate over each sample exactly once per epoch
+                state = tf.random.uniform(shape=(3,), maxval=2**62, dtype=tf.int64)
+            state = tf.ensure_shape(state, (3,))
+            shuffled_indices = tf.random_index_shuffle(index=indices, seed=state, max_index=len(dataset) - 1)
+            return state, shuffled_indices
+
+        tf_dataset = tf_dataset.scan(base_seed, scan_random_indices)
 
     if batch_size is not None:
         tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
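The idea in this first patch is to stop materialising an `np.arange(len(dataset))` index tensor (and a `shuffle()` buffer covering all of it) and instead permute a lazy `tf.data.Dataset.range` one index at a time: `scan` carries a per-epoch seed, and `tf.random_index_shuffle` maps each index to a unique shuffled position for that seed. Below is a minimal standalone sketch of the same pattern, not taken from the patch itself: the dataset length is illustrative, it assumes a TensorFlow release that exposes `tf.random_index_shuffle`, and it uses the corrected `value=` keyword that the next two commits settle on.

```python
import tensorflow as tf

n = 10  # illustrative dataset length
ds = tf.data.Dataset.range(n)

# A seed of all -1 is a sentinel meaning "no seed drawn yet for this epoch".
base_seed = tf.fill((3,), value=tf.cast(-1, dtype=tf.int64))

def scan_random_index(state, index):
    if tf.reduce_all(state == -1):
        # Draw a fresh 3-element seed on the first element of each epoch.
        state = tf.random.uniform(shape=(3,), maxval=2**62, dtype=tf.int64)
    # Bijectively map `index` to another position in [0, n - 1] for this seed,
    # so a single pass still visits every index exactly once.
    shuffled = tf.random_index_shuffle(index=index, seed=state, max_index=n - 1)
    return state, shuffled

shuffled_ds = ds.scan(base_seed, scan_random_index)

for _ in range(2):
    order = [int(i) for i in shuffled_ds]
    print(order, sorted(order) == list(range(n)))  # a permutation of 0..n-1 each epoch
```

Because `scan` starts from `base_seed` again whenever a new iterator is created, each epoch draws a new seed and therefore a new permutation, while only a single index is ever held in memory at a time.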
From b05b748a71008ccfede813287371d8ef0b57b972 Mon Sep 17 00:00:00 2001
From: Matt
Date: Mon, 15 May 2023 16:33:13 +0100
Subject: [PATCH 02/16] correct fill kwarg

---
 src/datasets/utils/tf_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 9d5d47c6ec2..1cfeddce9db 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -199,7 +199,7 @@ def fetch_function(indices):
     tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
 
     if shuffle:
-        base_seed = tf.fill((3,), fill_value=-1, dtype=tf.int64)
+        base_seed = tf.fill((3,), value=-1, dtype=tf.int64)
 
         def scan_random_indices(state, indices):
             if tf.reduce_all(state == -1):

From 7f936cbe997263ca3746fd5100e1bce0685caffb Mon Sep 17 00:00:00 2001
From: Matt
Date: Mon, 15 May 2023 16:33:55 +0100
Subject: [PATCH 03/16] ...and cast the inputs too

---
 src/datasets/utils/tf_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 1cfeddce9db..962b45935d4 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -199,7 +199,7 @@ def fetch_function(indices):
     tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
 
     if shuffle:
-        base_seed = tf.fill((3,), value=-1, dtype=tf.int64)
+        base_seed = tf.fill((3,), value=tf.cast(-1, dtype=tf.int64))
 
         def scan_random_indices(state, indices):
             if tf.reduce_all(state == -1):

From 7bd73125650934b8d385704ba84b812751b1499f Mon Sep 17 00:00:00 2001
From: Matt
Date: Mon, 15 May 2023 17:01:51 +0100
Subject: [PATCH 04/16] Add warnings for older TF

---
 src/datasets/utils/tf_utils.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 962b45935d4..8dec0958ea6 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -15,6 +15,7 @@
 """TF-specific utils import."""
 
 import os
+import warnings
 from functools import partial
 from math import ceil
 from uuid import uuid4
@@ -173,6 +174,19 @@ def dataset_to_tf(
     else:
         raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")
 
+    if hasattr(tf, "random_index_shuffle"):
+        random_index_shuffle = tf.random_index_shuffle
+    elif hasattr(tf.random.experimental, "index_shuffle"):
+        random_index_shuffle = tf.random.experimental.index_shuffle
+    else:
+        if len(dataset) > 10_000_000:
+            warnings.warn(
+                "to_tf_dataset() can be memory-inefficient on versions of TensorFlow older than 2.9. "
+                "If you are iterating over a dataset with a very large number of samples, consider "
+                "upgrading to TF >= 2.9."
+            )
+        random_index_shuffle = None
+
     getter_fn = partial(
         np_get_batch,
         dataset=dataset,
@@ -196,9 +210,9 @@ def fetch_function(indices):
         return {key: output[i] for i, key in enumerate(columns_to_np_types.keys())}
 
     tf_dataset = tf.data.Dataset.range(len(dataset))
-    tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
 
-    if shuffle:
+    if shuffle and random_index_shuffle is not None:
+        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
         base_seed = tf.fill((3,), value=tf.cast(-1, dtype=tf.int64))
 
         def scan_random_indices(state, indices):
@@ -211,6 +225,11 @@ def scan_random_indices(state, indices):
             return state, shuffled_indices
 
         tf_dataset = tf_dataset.scan(base_seed, scan_random_indices)
+    elif shuffle:
+        tf_dataset = tf_dataset.shuffle(len(tf_dataset))
+        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
+    else:
+        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
 
     if batch_size is not None:
         tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
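Some rough context for the 10,000,000-sample threshold in the warning above (a back-of-envelope figure, not from the patch): materialising int64 indices for 10M samples costs on the order of 80 MB before the shuffle buffer is even counted, which is the cost the `Dataset.range` + index-shuffle path avoids. The sketch below contrasts the three constructions involved; the dataset size is illustrative.

```python
import numpy as np
import tensorflow as tf

n = 10_000_000  # the threshold above which the new warning fires

# Pre-series behaviour: indices are materialised up front and then shuffled
# with a buffer that has to cover the whole dataset.
eager = tf.data.Dataset.from_tensor_slices(np.arange(n, dtype=np.int64)).shuffle(n)

# Old-TF fallback kept by this patch: the range is lazy, but shuffle() still
# needs a buffer spanning every index.
fallback = tf.data.Dataset.range(n).shuffle(n)

# Path taken when tf.random_index_shuffle / index_shuffle is available:
# a lazy range whose elements are permuted one at a time, with no index buffer.
low_memory = tf.data.Dataset.range(n)
```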
From 18d92aaa39650bfe7cd9cb64eb933a650451c0c2 Mon Sep 17 00:00:00 2001
From: Matt
Date: Mon, 15 May 2023 17:05:56 +0100
Subject: [PATCH 05/16] Fix to use the imported random_index_shuffle

---
 src/datasets/utils/tf_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 8dec0958ea6..797d5114709 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -220,8 +220,7 @@ def scan_random_indices(state, indices):
                 # This generates a new random seed once per epoch only,
                 # to ensure that we iterate over each sample exactly once per epoch
                 state = tf.random.uniform(shape=(3,), maxval=2**62, dtype=tf.int64)
-            state = tf.ensure_shape(state, (3,))
-            shuffled_indices = tf.random_index_shuffle(index=indices, seed=state, max_index=len(dataset) - 1)
+            shuffled_indices = random_index_shuffle(index=indices, seed=state, max_index=len(dataset) - 1)
             return state, shuffled_indices
 
         tf_dataset = tf_dataset.scan(base_seed, scan_random_indices)

From 82534e383534f2ad9cbc715236d3ec5059d9ee6f Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 16 May 2023 15:46:22 +0100
Subject: [PATCH 06/16] Switch to_tf_dataset entirely over to the NumPy
 multiprocessing approach

---
 src/datasets/arrow_dataset.py  |  48 +++++---------
 src/datasets/utils/tf_utils.py | 117 ---------------------------------
 2 files changed, 15 insertions(+), 150 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 073e47aefca..ed2198dfccb 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -115,7 +115,7 @@
 from .utils.metadata import DatasetMetadata
 from .utils.py_utils import Literal, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
 from .utils.stratify import stratified_shuffle_split_generate_indices
-from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
+from .utils.tf_utils import minimal_tf_collate_fn, multiprocess_dataset_to_tf
 from .utils.typing import ListLike, PathLike
 
 
@@ -457,38 +457,20 @@ def to_tf_dataset(
             if col not in output_signature:
                 raise ValueError(f"Label column {col} not found in dataset!")
 
-        if num_workers == 0:
-            tf_dataset = dataset_to_tf(
-                dataset=dataset,
-                cols_to_retain=cols_to_retain,
-                collate_fn=collate_fn,
-                collate_fn_args=collate_fn_args,
-                columns_to_np_types=columns_to_np_types,
-                output_signature=output_signature,
-                shuffle=shuffle,
-                batch_size=batch_size,
-                drop_remainder=drop_remainder,
-            )
-        elif num_workers > 0:
-            if batch_size is None:
-                raise NotImplementedError(
-                    "`batch_size` must be specified when using multiple workers, as unbatched multiprocessing "
-                    "is not supported yet. Please provide a `batch_size` if `num_workers` is greater than 0."
-                )
-            tf_dataset = multiprocess_dataset_to_tf(
-                dataset=dataset,
-                cols_to_retain=cols_to_retain,
-                collate_fn=collate_fn,
-                collate_fn_args=collate_fn_args,
-                columns_to_np_types=columns_to_np_types,
-                output_signature=output_signature,
-                shuffle=shuffle,
-                batch_size=batch_size,
-                drop_remainder=drop_remainder,
-                num_workers=num_workers,
-            )
-        else:
-            raise ValueError("num_workers must be >= 0")
+        if num_workers <= 0:
+            num_workers = 1
+        tf_dataset = multiprocess_dataset_to_tf(
+            dataset=dataset,
+            cols_to_retain=cols_to_retain,
+            collate_fn=collate_fn,
+            collate_fn_args=collate_fn_args,
+            columns_to_np_types=columns_to_np_types,
+            output_signature=output_signature,
+            shuffle=shuffle,
+            batch_size=batch_size,
+            drop_remainder=drop_remainder,
+            num_workers=num_workers,
+        )
 
         def split_features_and_labels(input_batch):
             # TODO(Matt, QL): deprecate returning the dict content when there's only one key
diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 797d5114709..e1bc9b9848b 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -15,8 +15,6 @@
 """TF-specific utils import."""
 
 import os
-import warnings
-from functools import partial
 from math import ceil
 from uuid import uuid4
 
@@ -133,121 +131,6 @@ def np_get_batch(
     return out_batch
 
 
-def dataset_to_tf(
-    dataset,
-    cols_to_retain,
-    collate_fn,
-    collate_fn_args,
-    columns_to_np_types,
-    output_signature,
-    shuffle,
-    batch_size,
-    drop_remainder,
-):
-    """Create a tf.data.Dataset from the underlying Dataset. This is a single-process method - the multiprocess
-    equivalent is multiprocess_dataset_to_tf.
-
-    Args:
-        dataset (`Dataset`): Dataset to wrap with tf.data.Dataset.
-        cols_to_retain (`List[str]`): Dataset column(s) to load in the
-            tf.data.Dataset. It is acceptable to include column names that are created by the `collate_fn` and
-            that do not exist in the original dataset.
-        collate_fn(`Callable`): A function or callable object (such as a `DataCollator`) that will collate
-            lists of samples into a batch.
-        collate_fn_args (`Dict`): A `dict` of keyword arguments to be passed to the
-            `collate_fn`. Can be empty.
-        columns_to_np_types (`Dict[str, np.dtype]`): A `dict` mapping column names to numpy dtypes.
-        output_signature (`Dict[str, tf.TensorSpec]`): A `dict` mapping column names to
-            `tf.TensorSpec` objects.
-        shuffle(`bool`): Shuffle the dataset order when loading. Recommended True for training, False for
-            validation/evaluation.
-        batch_size (`int`, default `None`): Size of batches to load from the dataset. Defaults to `None`, which implies that
-            the dataset won't be batched, but the returned dataset can be batched later with `tf_dataset.batch(batch_size)`.
-        drop_remainder(`bool`, default `None`): Drop the last incomplete batch when loading. If not provided,
-            defaults to the same setting as shuffle.
-
-    Returns:
-        `tf.data.Dataset`
-    """
-    if config.TF_AVAILABLE:
-        import tensorflow as tf
-    else:
-        raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")
-
-    if hasattr(tf, "random_index_shuffle"):
-        random_index_shuffle = tf.random_index_shuffle
-    elif hasattr(tf.random.experimental, "index_shuffle"):
-        random_index_shuffle = tf.random.experimental.index_shuffle
-    else:
-        if len(dataset) > 10_000_000:
-            warnings.warn(
-                "to_tf_dataset() can be memory-inefficient on versions of TensorFlow older than 2.9. "
-                "If you are iterating over a dataset with a very large number of samples, consider "
-                "upgrading to TF >= 2.9."
-            )
-        random_index_shuffle = None
-
-    getter_fn = partial(
-        np_get_batch,
-        dataset=dataset,
-        cols_to_retain=cols_to_retain,
-        collate_fn=collate_fn,
-        collate_fn_args=collate_fn_args,
-        columns_to_np_types=columns_to_np_types,
-        return_dict=False,
-    )
-
-    # This works because dictionaries always output in the same order
-    tout = [tf.dtypes.as_dtype(dtype) for dtype in columns_to_np_types.values()]
-
-    @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
-    def fetch_function(indices):
-        output = tf.py_function(
-            getter_fn,
-            inp=[indices],
-            Tout=tout,
-        )
-        return {key: output[i] for i, key in enumerate(columns_to_np_types.keys())}
-
-    tf_dataset = tf.data.Dataset.range(len(dataset))
-
-    if shuffle and random_index_shuffle is not None:
-        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
-        base_seed = tf.fill((3,), value=tf.cast(-1, dtype=tf.int64))
-
-        def scan_random_indices(state, indices):
-            if tf.reduce_all(state == -1):
-                # This generates a new random seed once per epoch only,
-                # to ensure that we iterate over each sample exactly once per epoch
-                state = tf.random.uniform(shape=(3,), maxval=2**62, dtype=tf.int64)
-            shuffled_indices = random_index_shuffle(index=indices, seed=state, max_index=len(dataset) - 1)
-            return state, shuffled_indices
-
-        tf_dataset = tf_dataset.scan(base_seed, scan_random_indices)
-    elif shuffle:
-        tf_dataset = tf_dataset.shuffle(len(tf_dataset))
-        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
-    else:
-        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
-
-    if batch_size is not None:
-        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
-
-    tf_dataset = tf_dataset.map(fetch_function)
-
-    if batch_size is not None:
-
-        def ensure_shapes(input_dict):
-            return {key: tf.ensure_shape(val, output_signature[key].shape) for key, val in input_dict.items()}
-
-    else:
-        # Ensure shape but remove batch dimension of output_signature[key].shape
-        def ensure_shapes(input_dict):
-            return {key: tf.ensure_shape(val, output_signature[key].shape[1:]) for key, val in input_dict.items()}
-
-    return tf_dataset.map(ensure_shapes)
-
-
 class SharedMemoryContext:
     # This is a context manager for creating shared memory that ensures cleanup happens even if a process is interrupted
     # The process that creates shared memory is always the one responsible for unlinking it in the end

From 3011d6240cf8826e152e966dc7e9d7a98d5750e6 Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 23 May 2023 16:37:34 +0100
Subject: [PATCH 07/16] Revert "Switch to_tf_dataset entirely over to the
 NumPy multiprocessing approach"

This reverts commit 95c177e02ca20bf7bb3ed8f185d2d6f05a5e5f30.
---
 src/datasets/arrow_dataset.py  |  43 +++++++++-----
 src/datasets/utils/tf_utils.py | 104 +++++++++++++++++++++++++++++++++
 2 files changed, 132 insertions(+), 15 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index ed2198dfccb..80b073e0f8c 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -115,7 +115,7 @@
 from .utils.metadata import DatasetMetadata
 from .utils.py_utils import Literal, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
 from .utils.stratify import stratified_shuffle_split_generate_indices
-from .utils.tf_utils import minimal_tf_collate_fn, multiprocess_dataset_to_tf
+from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
 from .utils.typing import ListLike, PathLike
 
 
@@ -457,20 +457,33 @@ def to_tf_dataset(
             if col not in output_signature:
                 raise ValueError(f"Label column {col} not found in dataset!")
 
-        if num_workers <= 0:
-            num_workers = 1
-        tf_dataset = multiprocess_dataset_to_tf(
-            dataset=dataset,
-            cols_to_retain=cols_to_retain,
-            collate_fn=collate_fn,
-            collate_fn_args=collate_fn_args,
-            columns_to_np_types=columns_to_np_types,
-            output_signature=output_signature,
-            shuffle=shuffle,
-            batch_size=batch_size,
-            drop_remainder=drop_remainder,
-            num_workers=num_workers,
-        )
+        if num_workers == 0:
+            tf_dataset = dataset_to_tf(
+                dataset=dataset,
+                cols_to_retain=cols_to_retain,
+                collate_fn=collate_fn,
+                collate_fn_args=collate_fn_args,
+                columns_to_np_types=columns_to_np_types,
+                output_signature=output_signature,
+                shuffle=shuffle,
+                batch_size=batch_size,
+                drop_remainder=drop_remainder,
+            )
+        elif num_workers > 0:
+            tf_dataset = multiprocess_dataset_to_tf(
+                dataset=dataset,
+                cols_to_retain=cols_to_retain,
+                collate_fn=collate_fn,
+                collate_fn_args=collate_fn_args,
+                columns_to_np_types=columns_to_np_types,
+                output_signature=output_signature,
+                shuffle=shuffle,
+                batch_size=batch_size,
+                drop_remainder=drop_remainder,
+                num_workers=num_workers,
+            )
+        else:
+            raise ValueError("num_workers must be >= 0")
 
         def split_features_and_labels(input_batch):
             # TODO(Matt, QL): deprecate returning the dict content when there's only one key
diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index e1bc9b9848b..a7167e4aa83 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -15,6 +15,8 @@
 """TF-specific utils import."""
 
 import os
+import warnings
+from functools import partial
 from math import ceil
 from uuid import uuid4
 
@@ -131,6 +133,108 @@ def np_get_batch(
     return out_batch
 
 
+def dataset_to_tf(
+    dataset,
+    cols_to_retain,
+    collate_fn,
+    collate_fn_args,
+    columns_to_np_types,
+    output_signature,
+    shuffle,
+    batch_size,
+    drop_remainder,
+):
+    """Create a tf.data.Dataset from the underlying Dataset. This is a single-process method - the multiprocess
+    equivalent is multiprocess_dataset_to_tf.
+
+        Args:
+            dataset (`Dataset`): Dataset to wrap with tf.data.Dataset.
+            cols_to_retain (`List[str]`): Dataset column(s) to load in the
+                tf.data.Dataset. It is acceptable to include column names that are created by the `collate_fn` and
+                that do not exist in the original dataset.
+            collate_fn(`Callable`): A function or callable object (such as a `DataCollator`) that will collate
+                lists of samples into a batch.
+            collate_fn_args (`Dict`): A `dict` of keyword arguments to be passed to the
+                `collate_fn`. Can be empty.
+            columns_to_np_types (`Dict[str, np.dtype]`): A `dict` mapping column names to numpy dtypes.
+            output_signature (`Dict[str, tf.TensorSpec]`): A `dict` mapping column names to
+                `tf.TensorSpec` objects.
+            shuffle(`bool`): Shuffle the dataset order when loading. Recommended True for training, False for
+                validation/evaluation.
+            batch_size (`int`): Size of batches to load from the dataset.
+            drop_remainder(`bool`, default `None`): Drop the last incomplete batch when loading. If not provided,
+                defaults to the same setting as shuffle.
+
+        Returns:
+            `tf.data.Dataset`
+    """
+    if config.TF_AVAILABLE:
+        import tensorflow as tf
+    else:
+        raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")
+
+    if hasattr(tf, "random_index_shuffle"):
+        random_index_shuffle = tf.random_index_shuffle
+    elif hasattr(tf.random.experimental, "index_shuffle"):
+        random_index_shuffle = tf.random.experimental.index_shuffle
+    else:
+        if len(dataset) > 10_000_000:
+            warnings.warn(
+                "to_tf_dataset() can be memory-inefficient on versions of TensorFlow older than 2.9. "
+                "If you are iterating over a dataset with a very large number of samples, consider "
+                "upgrading to TF >= 2.9."
+            )
+        random_index_shuffle = None
+
+    getter_fn = partial(
+        np_get_batch,
+        dataset=dataset,
+        cols_to_retain=cols_to_retain,
+        collate_fn=collate_fn,
+        collate_fn_args=collate_fn_args,
+        columns_to_np_types=columns_to_np_types,
+        return_dict=False,  # TF expects numpy_function to return a list and will not accept a dict
+    )
+
+    @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
+    def fetch_function(indices):
+        output = tf.numpy_function(
+            getter_fn,
+            inp=[indices],
+            # This works because dictionaries always output in the same order
+            Tout=[tf.dtypes.as_dtype(dtype) for dtype in columns_to_np_types.values()],
+        )
+        return {key: output[i] for i, key in enumerate(columns_to_np_types.keys())}
+
+    tf_dataset = tf.data.Dataset.range(len(dataset))
+
+    if shuffle and random_index_shuffle is not None:
+        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
+        base_seed = tf.fill((3,), value=tf.cast(-1, dtype=tf.int64))
+
+        def scan_random_indices(state, indices):
+            if tf.reduce_all(state == -1):
+                # This generates a new random seed once per epoch only,
+                # to ensure that we iterate over each sample exactly once per epoch
+                state = tf.random.uniform(shape=(3,), maxval=2**62, dtype=tf.int64)
+            shuffled_indices = random_index_shuffle(index=indices, seed=state, max_index=len(dataset) - 1)
+            return state, shuffled_indices
+
+        tf_dataset = tf_dataset.scan(base_seed, scan_random_indices)
+    elif shuffle:
+        tf_dataset = tf_dataset.shuffle(len(tf_dataset))
+        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
+    else:
+        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
+
+    tf_dataset = tf_dataset.map(fetch_function)
+
+    def ensure_shapes(input_dict):
+        return {key: tf.ensure_shape(val, output_signature[key].shape) for key, val in input_dict.items()}
+
+    return tf_dataset.map(ensure_shapes)
+
+
 class SharedMemoryContext:
     # This is a context manager for creating shared memory that ensures cleanup happens even if a process is interrupted
     # The process that creates shared memory is always the one responsible for unlinking it in the end
From 3c54400d29a2f66a465b7190ca9c90115d928d4f Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 23 May 2023 17:00:52 +0100
Subject: [PATCH 08/16] Add explanatory comment

---
 src/datasets/utils/tf_utils.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index a7167e4aa83..27fbd39ea37 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -173,6 +173,10 @@ def dataset_to_tf(
     else:
         raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")
 
+    # Matt: There are a few version dependencies here - ideally we'd move everything to the multiprocessing path for
+    # simplicity, but that depends on SharedMemory that only arrived in Py3.8. We also have a reasonably efficient
+    # solution without Python multiprocessing, but it depends on TF >= 2.9. If we're on an older version of TF,
+    # we fall back to the slowest path. Hopefully when our minimum versions move up a bit more we can clean this all up.
     if hasattr(tf, "random_index_shuffle"):
         random_index_shuffle = tf.random_index_shuffle
     elif hasattr(tf.random.experimental, "index_shuffle"):

From 81761dbfa738354a9c50309313dfe90bea26d872 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 24 May 2023 16:42:20 +0100
Subject: [PATCH 09/16] TF 2.13 has a specific optimization for
 dataset.shuffle(dataset.cardinality()), so use that instead of
 dataset.shuffle(len(dataset))

---
 src/datasets/utils/tf_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 27fbd39ea37..de3d1f83dd7 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -226,7 +226,7 @@ def scan_random_indices(state, indices):
         tf_dataset = tf_dataset.scan(base_seed, scan_random_indices)
     elif shuffle:
-        tf_dataset = tf_dataset.shuffle(len(tf_dataset))
+        tf_dataset = tf_dataset.shuffle(tf_dataset.cardinality())
         tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
     else:
         tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
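A small illustration of the spelling change in the commit above; the dataset here is illustrative, and the claim about the TF 2.13 optimization is taken from the commit message rather than verified independently.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(1_000)

# Same effect, but the second form lets newer TF releases see that the buffer
# covers the whole dataset and apply the dedicated whole-dataset shuffle
# optimization mentioned in the commit message.
shuffled_old = ds.shuffle(1_000)             # buffer size as a plain Python int
shuffled_new = ds.shuffle(ds.cardinality())  # buffer size tied to the dataset itself
```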
From 8907bdb23f78545303eb3bb0561e33ec6787f96c Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 7 Jun 2023 15:58:21 +0100
Subject: [PATCH 10/16] Fix a couple of rebase errors

---
 src/datasets/arrow_dataset.py  |  5 +++++
 src/datasets/utils/tf_utils.py | 10 ++++++----
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 80b073e0f8c..073e47aefca 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -470,6 +470,11 @@ def to_tf_dataset(
                 drop_remainder=drop_remainder,
             )
         elif num_workers > 0:
+            if batch_size is None:
+                raise NotImplementedError(
+                    "`batch_size` must be specified when using multiple workers, as unbatched multiprocessing "
+                    "is not supported yet. Please provide a `batch_size` if `num_workers` is greater than 0."
+                )
             tf_dataset = multiprocess_dataset_to_tf(
                 dataset=dataset,
                 cols_to_retain=cols_to_retain,
diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index de3d1f83dd7..2fbff587d31 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -197,16 +197,18 @@ def dataset_to_tf(
         collate_fn=collate_fn,
         collate_fn_args=collate_fn_args,
         columns_to_np_types=columns_to_np_types,
-        return_dict=False,  # TF expects numpy_function to return a list and will not accept a dict
+        return_dict=False,
     )
 
+    # This works because dictionaries always output in the same order
+    tout = [tf.dtypes.as_dtype(dtype) for dtype in columns_to_np_types.values()]
+
     @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
     def fetch_function(indices):
-        output = tf.numpy_function(
+        output = tf.py_function(
             getter_fn,
             inp=[indices],
-            # This works because dictionaries always output in the same order
-            Tout=[tf.dtypes.as_dtype(dtype) for dtype in columns_to_np_types.values()],
+            Tout=tout,
         )
         return {key: output[i] for i, key in enumerate(columns_to_np_types.keys())}
 

From f39ba76af62c8037de3f464e87cbb095f8729062 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 7 Jun 2023 16:08:26 +0100
Subject: [PATCH 11/16] More merging with the changes in main

---
 src/datasets/utils/tf_utils.py | 16 +++++++++++------
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 2fbff587d31..cdc5e0a5a92 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -215,7 +215,6 @@ def fetch_function(indices):
     tf_dataset = tf.data.Dataset.range(len(dataset))
 
     if shuffle and random_index_shuffle is not None:
-        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
         base_seed = tf.fill((3,), value=tf.cast(-1, dtype=tf.int64))
 
         def scan_random_indices(state, indices):
@@ -229,14 +228,21 @@ def scan_random_indices(state, indices):
         tf_dataset = tf_dataset.scan(base_seed, scan_random_indices)
     elif shuffle:
         tf_dataset = tf_dataset.shuffle(tf_dataset.cardinality())
-        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
-    else:
+
+    if batch_size is not None:
         tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder)
 
     tf_dataset = tf_dataset.map(fetch_function)
 
-    def ensure_shapes(input_dict):
-        return {key: tf.ensure_shape(val, output_signature[key].shape) for key, val in input_dict.items()}
+    if batch_size is not None:
+
+        def ensure_shapes(input_dict):
+            return {key: tf.ensure_shape(val, output_signature[key].shape) for key, val in input_dict.items()}
+
+    else:
+        # Ensure shape but remove batch dimension of output_signature[key].shape
+        def ensure_shapes(input_dict):
+            return {key: tf.ensure_shape(val, output_signature[key].shape[1:]) for key, val in input_dict.items()}
 
     return tf_dataset.map(ensure_shapes)
 
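The batching rework above also splits `ensure_shapes` into a batched and an unbatched variant: when `batch_size` is `None`, the leading batch dimension of each `TensorSpec` has to be dropped before checking element shapes. A small sketch of the difference (the spec shape and column name are illustrative):

```python
import tensorflow as tf

# Suppose the output signature promises batches shaped (None, 128) for "input_ids".
spec = tf.TensorSpec(shape=(None, 128), dtype=tf.int64)

batch = tf.zeros((8, 128), dtype=tf.int64)   # what the batched pipeline yields
single = tf.zeros((128,), dtype=tf.int64)    # what the unbatched pipeline yields

tf.ensure_shape(batch, spec.shape)        # batched path: check (None, 128)
tf.ensure_shape(single, spec.shape[1:])   # unbatched path: drop the batch dim, check (128,)
```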
From 323747a5ff7d9b204ea3c4989d658af7102f7bbd Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 7 Jun 2023 16:11:14 +0100
Subject: [PATCH 12/16] Fix some indents

---
 src/datasets/utils/tf_utils.py | 40 +++++++++++++++++----------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index cdc5e0a5a92..3cd2ecdc7be 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -147,26 +147,26 @@ def dataset_to_tf(
     """Create a tf.data.Dataset from the underlying Dataset. This is a single-process method - the multiprocess
     equivalent is multiprocess_dataset_to_tf.
 
-        Args:
-            dataset (`Dataset`): Dataset to wrap with tf.data.Dataset.
-            cols_to_retain (`List[str]`): Dataset column(s) to load in the
-                tf.data.Dataset. It is acceptable to include column names that are created by the `collate_fn` and
-                that do not exist in the original dataset.
-            collate_fn(`Callable`): A function or callable object (such as a `DataCollator`) that will collate
-                lists of samples into a batch.
-            collate_fn_args (`Dict`): A `dict` of keyword arguments to be passed to the
-                `collate_fn`. Can be empty.
-            columns_to_np_types (`Dict[str, np.dtype]`): A `dict` mapping column names to numpy dtypes.
-            output_signature (`Dict[str, tf.TensorSpec]`): A `dict` mapping column names to
-                `tf.TensorSpec` objects.
-            shuffle(`bool`): Shuffle the dataset order when loading. Recommended True for training, False for
-                validation/evaluation.
-            batch_size (`int`): Size of batches to load from the dataset.
-            drop_remainder(`bool`, default `None`): Drop the last incomplete batch when loading. If not provided,
-                defaults to the same setting as shuffle.
-
-        Returns:
-            `tf.data.Dataset`
+    Args:
+        dataset (`Dataset`): Dataset to wrap with tf.data.Dataset.
+        cols_to_retain (`List[str]`): Dataset column(s) to load in the
+            tf.data.Dataset. It is acceptable to include column names that are created by the `collate_fn` and
+            that do not exist in the original dataset.
+        collate_fn(`Callable`): A function or callable object (such as a `DataCollator`) that will collate
+            lists of samples into a batch.
+        collate_fn_args (`Dict`): A `dict` of keyword arguments to be passed to the
+            `collate_fn`. Can be empty.
+        columns_to_np_types (`Dict[str, np.dtype]`): A `dict` mapping column names to numpy dtypes.
+        output_signature (`Dict[str, tf.TensorSpec]`): A `dict` mapping column names to
+            `tf.TensorSpec` objects.
+        shuffle(`bool`): Shuffle the dataset order when loading. Recommended True for training, False for
+            validation/evaluation.
+        batch_size (`int`): Size of batches to load from the dataset.
+        drop_remainder(`bool`, default `None`): Drop the last incomplete batch when loading. If not provided,
+            defaults to the same setting as shuffle.
+
+    Returns:
+        `tf.data.Dataset`
     """
     if config.TF_AVAILABLE:
         import tensorflow as tf

From e8f051a41454f8625091338e6b53119a5eb9b2a0 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 7 Jun 2023 16:11:53 +0100
Subject: [PATCH 13/16] Fix docstring merge

---
 src/datasets/utils/tf_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 3cd2ecdc7be..f2d8c6da430 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -161,7 +161,8 @@ def dataset_to_tf(
             `tf.TensorSpec` objects.
         shuffle(`bool`): Shuffle the dataset order when loading. Recommended True for training, False for
             validation/evaluation.
-        batch_size (`int`): Size of batches to load from the dataset.
+        batch_size (`int`, default `None`): Size of batches to load from the dataset. Defaults to `None`, which implies that
+            the dataset won't be batched, but the returned dataset can be batched later with `tf_dataset.batch(batch_size)`.
         drop_remainder(`bool`, default `None`): Drop the last incomplete batch when loading. If not provided,
             defaults to the same setting as shuffle.
 

From 5dfcd876c25cc0ffbd6b5b518b017419390a8ada Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 7 Jun 2023 16:13:45 +0100
Subject: [PATCH 14/16] Add clearer TODO

---
 src/datasets/utils/tf_utils.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index f2d8c6da430..4e496c197c7 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -174,10 +174,8 @@ def dataset_to_tf(
     else:
         raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")
 
-    # Matt: There are a few version dependencies here - ideally we'd move everything to the multiprocessing path for
-    # simplicity, but that depends on SharedMemory that only arrived in Py3.8. We also have a reasonably efficient
-    # solution without Python multiprocessing, but it depends on TF >= 2.9. If we're on an older version of TF,
-    # we fall back to the slowest path. Hopefully when our minimum versions move up a bit more we can clean this all up.
+    # TODO Matt: When our minimum Python version is 3.8 or higher, we can delete all of this and move everything
+    # to the NumPy multiprocessing path.
     if hasattr(tf, "random_index_shuffle"):
         random_index_shuffle = tf.random_index_shuffle
     elif hasattr(tf.random.experimental, "index_shuffle"):

From b4cc3ee6d8945052283076854eb77575d52b7432 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 7 Jun 2023 16:18:39 +0100
Subject: [PATCH 15/16] Rename indices -> index to be clearer what the
 function does now

---
 src/datasets/utils/tf_utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 4e496c197c7..b69f5c85b2c 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -216,15 +216,15 @@ def fetch_function(indices):
     if shuffle and random_index_shuffle is not None:
         base_seed = tf.fill((3,), value=tf.cast(-1, dtype=tf.int64))
 
-        def scan_random_indices(state, indices):
+        def scan_random_index(state, index):
             if tf.reduce_all(state == -1):
                 # This generates a new random seed once per epoch only,
                 # to ensure that we iterate over each sample exactly once per epoch
                 state = tf.random.uniform(shape=(3,), maxval=2**62, dtype=tf.int64)
-            shuffled_indices = random_index_shuffle(index=indices, seed=state, max_index=len(dataset) - 1)
-            return state, shuffled_indices
+            shuffled_index = random_index_shuffle(index=index, seed=state, max_index=len(dataset) - 1)
+            return state, shuffled_index
 
-        tf_dataset = tf_dataset.scan(base_seed, scan_random_indices)
+        tf_dataset = tf_dataset.scan(base_seed, scan_random_index)
     elif shuffle:
         tf_dataset = tf_dataset.shuffle(tf_dataset.cardinality())
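The property this series is after — every sample visited exactly once per epoch, in a different order each epoch — is what the final patch below asserts in the test suite. A hypothetical end-to-end check of the same thing from user code (toy column names and sizes, keyword arguments only; it mirrors the assertions added to the test):

```python
import numpy as np
from datasets import Dataset

# Toy dataset; "col_1" mirrors the column name used in the test below.
dset = Dataset.from_dict({"col_1": list(range(32)), "col_2": [0.0] * 32})
tf_ds = dset.to_tf_dataset(columns=["col_1", "col_2"], batch_size=4, shuffle=True)

epoch_1 = np.concatenate([batch["col_1"].numpy() for batch in tf_ds])
epoch_2 = np.concatenate([batch["col_1"].numpy() for batch in tf_ds])

assert len(np.unique(epoch_1)) == len(epoch_1)  # each sample seen exactly once
assert not np.array_equal(epoch_1, epoch_2)     # order is reshuffled between epochs
```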
From c14806a42a20f44a60f3663642bae1de199ab1ec Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 7 Jun 2023 16:54:08 +0100
Subject: [PATCH 16/16] Expand test to make sure shuffling is working
 correctly

---
 tests/test_arrow_dataset.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index ac28c47e323..3846d43498d 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -2755,6 +2755,8 @@ def test_tf_index_reshuffling(self, in_memory):
                 second_indices.append(batch["col_1"])
             second_indices = np.concatenate([arr.numpy() for arr in second_indices])
             self.assertFalse(np.array_equal(indices, second_indices))
+            self.assertEqual(len(indices), len(np.unique(indices)))
+            self.assertEqual(len(second_indices), len(np.unique(second_indices)))
 
             tf_dataset = dset.to_tf_dataset(batch_size=1, shuffle=False, num_workers=num_workers)
             for i, batch in enumerate(tf_dataset):