diff --git a/src/lightning/data/utilities/shuffle.py b/src/lightning/data/utilities/shuffle.py index 8be6563096ed6..4dc34a7dc2302 100644 --- a/src/lightning/data/utilities/shuffle.py +++ b/src/lightning/data/utilities/shuffle.py @@ -12,8 +12,9 @@ def _intra_node_chunk_shuffle( current_epoch: int, ) -> List[int]: chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)] + process_per_node = distributed_env.world_size // distributed_env.num_nodes for rank, chunks_per_rank in enumerate(chunks_per_ranks): - chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend( + chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // process_per_node].extend( chunks_per_rank ) diff --git a/tests/tests_data/streaming/test_shuffle.py b/tests/tests_data/utilities/test_shuffle.py similarity index 91% rename from tests/tests_data/streaming/test_shuffle.py rename to tests/tests_data/utilities/test_shuffle.py index cb451ce73a1ec..9951a21c603ab 100644 --- a/tests/tests_data/streaming/test_shuffle.py +++ b/tests/tests_data/utilities/test_shuffle.py @@ -12,6 +12,11 @@ def test_intra_node_chunk_shuffle(): assert shuffled_indexes == [3, 2, 1, 0, 7, 6, 5, 4] + chunks_per_ranks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]] + shuffled_indexes = _intra_node_chunk_shuffle(_DistributedEnv(8, 7, 2), chunks_per_ranks, 42, 2) + assert shuffled_indexes == [5, 2, 0, 7, 6, 1, 3, 4, 13, 10, 8, 15, 14, 9, 11, 12] + + def test_associate_chunks_and_internals_to_ranks(): indexes = [0, 1, 2, 3, 4, 5, 6, 7] chunk_intervals = [[0, 50], [0, 50], [0, 50], [0, 50], [0, 50], [0, 50], [0, 50], [0, 50]]