Commit

update
awaelchli committed Mar 3, 2024
1 parent 664dd48 commit 84359dd
Showing 2 changed files with 25 additions and 6 deletions.
@@ -187,7 +187,7 @@ It is possible to convert a distributed checkpoint to a regular, single-file checkpoint

.. code-block:: bash
-    python -m lightning.fabric.utilities.consolidate_checkpoint path/to/my/checkpoint
+    fabric consolidate path/to/my/checkpoint
You will need to do this, for example, if you want to load the checkpoint into a script that doesn't use FSDP, or if you need to export the checkpoint to a different format for deployment, evaluation, etc.

@@ -202,7 +202,7 @@ You will need to do this for example if you want to load the checkpoint into a script

.. code-block:: bash
-    python -m lightning.fabric.utilities.consolidate_checkpoint my-checkpoint.ckpt
+    fabric consolidate my-checkpoint.ckpt
This saves a new file ``my-checkpoint.ckpt.consolidated`` next to the sharded checkpoint, which you can load normally in PyTorch:
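Since the consolidated file is a regular ``torch.save`` checkpoint, it can be loaded without FSDP. A minimal sketch, assuming the state was originally saved under a ``"model"`` key (the model class here is a hypothetical placeholder):

.. code-block:: python

    import torch

    # Load the single-file checkpoint produced by ``fabric consolidate``.
    checkpoint = torch.load("my-checkpoint.ckpt.consolidated")

    # Assumption: the original state dict was saved as {"model": ...};
    # restore it into a plain (non-FSDP) module the usual way.
    model = MyModel()  # placeholder for your model class
    model.load_state_dict(checkpoint["model"])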

27 changes: 23 additions & 4 deletions src/lightning/fabric/cli.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from argparse import Namespace
import logging
import os
import re
@@ -21,16 +22,16 @@

import torch

-from fabric.utilities.consolidate_checkpoint import _process_cli_args
-from fabric.utilities.load import _load_distributed_checkpoint
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import get_args

from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
from lightning.fabric.strategies import STRATEGY_REGISTRY
+from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.distributed import _suggested_max_num_threads
+from lightning.fabric.utilities.load import _load_distributed_checkpoint

_log = logging.getLogger(__name__)

@@ -165,8 +166,26 @@ def _run(**kwargs: Any) -> None:
"ignore_unknown_options": True,
},
)
-def _consolidate() -> None:
-    """Consolidate a distributed checkpoint into a single file."""
+@click.argument(
+    "checkpoint_folder",
+    type=click.Path(exists=True),
+)
+@click.option(
+    "--output_file",
+    type=click.Path(exists=False),
+    default=None,
+    help=(
+        "Path to the file where the converted checkpoint should be saved. The file should not already exist."
+        " If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
+        " and a '.consolidated' suffix."
+    ),
+)
+def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None:
+    """Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.
+
+    Only supports FSDP sharded checkpoints at the moment.
+    """
+    args = Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
+    config = _process_cli_args(args)
+    checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
+    torch.save(checkpoint, config.output_file)
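For reference, the command body above can also be driven programmatically with the same private helpers (internal APIs, subject to change; the checkpoint path is a placeholder):

.. code-block:: python

    from argparse import Namespace

    import torch

    # Private Fabric helpers used by the CLI command above.
    from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
    from lightning.fabric.utilities.load import _load_distributed_checkpoint

    # Equivalent of: fabric consolidate path/to/my/checkpoint
    args = Namespace(checkpoint_folder="path/to/my/checkpoint", output_file=None)
    config = _process_cli_args(args)  # fills in the default '.consolidated' output path
    checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
    torch.save(checkpoint, config.output_file)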
