diff --git a/docs/source-fabric/fundamentals/launch.rst b/docs/source-fabric/fundamentals/launch.rst index 0f25ebcda6021..d30a75621300c 100644 --- a/docs/source-fabric/fundamentals/launch.rst +++ b/docs/source-fabric/fundamentals/launch.rst @@ -67,7 +67,7 @@ An alternative way to launch your Python script in multiple processes is to use .. code-block:: bash - lightning run model path/to/your/script.py + fabric run model path/to/your/script.py This is essentially the same as running ``python path/to/your/script.py``, but it also lets you configure the following settings externally without changing your code: @@ -80,9 +80,9 @@ This is essentially the same as running ``python path/to/your/script.py``, but i .. code-block:: bash - lightning run model --help + fabric run model --help - Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]... + Usage: fabric run model [OPTIONS] SCRIPT [SCRIPT_ARGS]... Run a Lightning Fabric script. @@ -128,7 +128,7 @@ Here is how you run DDP with 8 GPUs and `torch.bfloat16 `_ w .. code-block:: bash - lightning run model ./path/to/train.py \ + fabric run model ./path/to/train.py \ --strategy=deepspeed_stage_3 \ --devices=8 \ --accelerator=cuda \ @@ -148,7 +148,7 @@ Or `DeepSpeed Zero3 `_ w .. code-block:: bash - lightning run model ./path/to/train.py \ + fabric run model ./path/to/train.py \ --devices=auto \ --accelerator=auto \ --precision=16 diff --git a/src/lightning/__setup__.py b/src/lightning/__setup__.py index b7e57309eb674..b9b19143bbe13 100644 --- a/src/lightning/__setup__.py +++ b/src/lightning/__setup__.py @@ -114,6 +114,8 @@ def _setup_args() -> Dict[str, Any]: "python_requires": ">=3.8", # todo: take the lowes based on all packages "entry_points": { "console_scripts": [ + "fabric = lightning.fabric.cli:_main", + "lightning = lightning.fabric.cli:_legacy_main", "lightning_app = lightning:_cli_entry_point", ], }, diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 9e32953100bfa..ab6f0f4971f27 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- +- Rename `lightning run model` to `fabric run model` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442)) - diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 0ff777232a7fb..2805675cecf70 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -14,6 +14,8 @@ import logging import os import re +import subprocess +import sys from argparse import Namespace from typing import Any, List, Optional @@ -29,6 +31,7 @@ _log = logging.getLogger(__name__) _CLICK_AVAILABLE = RequirementCache("click") +_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk") _SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") @@ -44,7 +47,32 @@ def _get_supported_strategies() -> List[str]: if _CLICK_AVAILABLE: import click - @click.command( + def _legacy_main() -> None: + """Legacy CLI handler for fabric. + + Raises deprecation warning and runs through fabric cli if necessary, else runs the entrypoint directly + + """ + print("`lightning run model` is deprecated and will be removed in future versions." + " Please call `fabric run model` instead.") + args = sys.argv[1:] + if args and args[0] == "run" and args[1] == "model": + _main() + return + + if _LIGHTNING_SDK_AVAILABLE: + subprocess.run([sys.executable, "-m", "lightning_sdk.cli.entrypoint"] + args) + return + + @click.group() + def _main() -> None: + pass + + @_main.group() + def run() -> None: + pass + + @run.command( "model", context_settings={ "ignore_unknown_options": True, diff --git a/src/lightning_fabric/__setup__.py b/src/lightning_fabric/__setup__.py index 869ffad571f7e..8fe0bc0937ef5 100644 --- a/src/lightning_fabric/__setup__.py +++ b/src/lightning_fabric/__setup__.py @@ -78,6 +78,11 @@ def _setup_args() -> Dict[str, Any]: "install_requires": assistant.load_requirements( _PATH_REQUIREMENTS, unfreeze="none" if _FREEZE_REQUIREMENTS else "all" ), + "entry_points": { + "console_scripts": [ + "fabric = lightning_fabric.cli:_main", + ], + }, "extras_require": _prepare_extras(), "project_urls": { "Bug Tracker": "https://github.com/Lightning-AI/lightning/issues", diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index c8e9a6bbedd57..596318b4b619f 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -21,7 +21,6 @@ import pytest from lightning.fabric.cli import _get_supported_strategies, _run_model -from lightning_utilities.core.imports import ModuleAvailableCache from tests_fabric.helpers.runif import RunIf @@ -176,13 +175,20 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch, ) +def test_cli_through_fabric_entry_point(): + result = subprocess.run("fabric run model --help", capture_output=True, text=True, shell=True) + + message = "Usage: fabric run model [OPTIONS] SCRIPT [SCRIPT_ARGS]" + assert message in result.stdout or message in result.stderr + @pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package") def test_cli_through_lightning_entry_point(): result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True) - if not ModuleAvailableCache("lightning.app"): - message = "The `lightning` command requires additional dependencies" - assert message in result.stdout or message in result.stderr - assert result.returncode != 0 - else: - message = "Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]" - assert message in result.stdout or message in result.stderr + + deprecation_message = ( + "`lightning run model` is deprecated and will be removed in future versions. " + "Please call `fabric run model` instead" + ) + message = "Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]" + assert deprecation_message in result.stdout + assert message in result.stdout or message in result.stderr