From a90fec9fc5f89809727260f84f37e42d46751605 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Fri, 9 Feb 2024 23:21:28 +0800 Subject: [PATCH] bugfix: correct node rank (#19437) (cherry picked from commit 7b867c7d91ad8f154e9c5b4656271a6329a8b059) --- src/lightning/data/utilities/shuffle.py | 3 ++- tests/tests_data/{streaming => utilities}/test_shuffle.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) rename tests/tests_data/{streaming => utilities}/test_shuffle.py (91%) diff --git a/src/lightning/data/utilities/shuffle.py b/src/lightning/data/utilities/shuffle.py index 8be6563096ed6..4dc34a7dc2302 100644 --- a/src/lightning/data/utilities/shuffle.py +++ b/src/lightning/data/utilities/shuffle.py @@ -12,8 +12,9 @@ def _intra_node_chunk_shuffle( current_epoch: int, ) -> List[int]: chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)] + process_per_node = distributed_env.world_size // distributed_env.num_nodes for rank, chunks_per_rank in enumerate(chunks_per_ranks): - chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend( + chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // process_per_node].extend( chunks_per_rank ) diff --git a/tests/tests_data/streaming/test_shuffle.py b/tests/tests_data/utilities/test_shuffle.py similarity index 91% rename from tests/tests_data/streaming/test_shuffle.py rename to tests/tests_data/utilities/test_shuffle.py index cb451ce73a1ec..9951a21c603ab 100644 --- a/tests/tests_data/streaming/test_shuffle.py +++ b/tests/tests_data/utilities/test_shuffle.py @@ -12,6 +12,11 @@ def test_intra_node_chunk_shuffle(): assert shuffled_indexes == [3, 2, 1, 0, 7, 6, 5, 4] + chunks_per_ranks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]] + shuffled_indexes = _intra_node_chunk_shuffle(_DistributedEnv(8, 7, 2), chunks_per_ranks, 42, 2) + assert shuffled_indexes == [5, 2, 0, 7, 6, 1, 3, 4, 13, 10, 8, 15, 14, 9, 11, 12] + + def test_associate_chunks_and_internals_to_ranks(): indexes = [0, 1, 2, 3, 4, 5, 6, 7] chunk_intervals = [[0, 50], [0, 50], [0, 50], [0, 50], [0, 50], [0, 50], [0, 50], [0, 50]]