From 11706180e6eaca2ed7bc069d1eb5c864a898efda Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Wed, 15 May 2024 08:19:05 -0400
Subject: [PATCH] fix

---
 src/lightning/fabric/strategies/model_parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
index 18ca2d8785e681..6200be8092c73e 100644
--- a/src/lightning/fabric/strategies/model_parallel.py
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -484,7 +484,7 @@ def _load_checkpoint(
             raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
 
         checkpoint = torch.load(path, mmap=True, map_location="cpu")
-        _load_raw_module_state(checkpoint.pop(module_key), module, world_size=1, strict=strict)
+        _load_raw_module_state(checkpoint.pop(module_key), module, world_size=self.world_size, strict=strict)
 
         requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
         _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
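
For context, the change above affects the path where a single-file, non-distributed checkpoint is loaded into a model-parallel module: instead of treating the load as single-process (world_size=1), _load_raw_module_state now receives the strategy's actual world size. The sketch below shows roughly how a user reaches this path with Fabric; it is illustrative only, and the parallelize callback, the parallelize_fn keyword, the device count, and the model/checkpoint names are assumptions not taken from this patch.

# Minimal sketch (assumptions noted above): load a full torch.save checkpoint
# into a model set up with Fabric's ModelParallelStrategy.
import torch
import torch.nn as nn

from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


def parallelize(module: nn.Module, device_mesh) -> nn.Module:
    # Apply tensor/data parallelism here (e.g. with
    # torch.distributed.tensor.parallel); kept as a no-op placeholder.
    return module


fabric = Fabric(strategy=ModelParallelStrategy(parallelize_fn=parallelize), devices=2)
fabric.launch()

model = fabric.setup(nn.Linear(32, 32))
state = {"model": model}

# "consolidated.ckpt" stands for a full (non-distributed) checkpoint saved
# earlier with plain torch.save. With the fix, the strategy passes its real
# world size to _load_raw_module_state, so the full state dict is distributed
# across all ranks instead of being loaded as if the run were single-process.
fabric.load("consolidated.ckpt", state)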