From 31ae21ff806c3d1fc19a48ce41178c82d2f69368 Mon Sep 17 00:00:00 2001 From: kq-chen Date: Sat, 2 Mar 2024 03:02:33 +0800 Subject: [PATCH] Fix incorrect ex_iterable used with multiple dataloader workers (#6582) Corrects an issue where `self._ex_iterable` was erroneously used instead of `ex_iterable` when both Distributed Data Parallel (DDP) and multiple dataloader workers are used concurrently. This improper usage led to the generation of incorrect `shards_indices`, subsequently causing issues with the control flow responsible for worker creation. The fix ensures the appropriate iterable is used, thus providing a more accurate determination of whether a new worker should be instantiated or not. --- src/datasets/iterable_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index f508a0b5271..31329de9c31 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1275,7 +1275,7 @@ def _iter_pytorch(self): ) # split workload _log_prefix = f"node#{self._distributed.rank} " if self._distributed else "" - shards_indices = self._ex_iterable.split_shard_indices_by_worker(worker_info.id, worker_info.num_workers) + shards_indices = ex_iterable.split_shard_indices_by_worker(worker_info.id, worker_info.num_workers) if shards_indices: logger.debug( f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.n_shards} shards."