Skip to content

Commit

Permalink
Rename Lightning Fabric CLI (#19442)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 12, 2024
1 parent 47c8f4c commit 2ed7282
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 16 deletions.
12 changes: 6 additions & 6 deletions docs/source-fabric/fundamentals/launch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ An alternative way to launch your Python script in multiple processes is to use

.. code-block:: bash
lightning run model path/to/your/script.py
fabric run model path/to/your/script.py
This is essentially the same as running ``python path/to/your/script.py``, but it also lets you configure the following settings externally without changing your code:

Expand All @@ -80,9 +80,9 @@ This is essentially the same as running ``python path/to/your/script.py``, but i

.. code-block:: bash
lightning run model --help
fabric run model --help
Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]...
Usage: fabric run model [OPTIONS] SCRIPT [SCRIPT_ARGS]...
Run a Lightning Fabric script.
Expand Down Expand Up @@ -128,7 +128,7 @@ Here is how you run DDP with 8 GPUs and `torch.bfloat16 <https://pytorch.org/doc

.. code-block:: bash
lightning run model ./path/to/train.py \
fabric run model ./path/to/train.py \
--strategy=ddp \
--devices=8 \
--accelerator=cuda \
Expand All @@ -138,7 +138,7 @@ Or `DeepSpeed Zero3 <https://www.deepspeed.ai/2021/03/07/zero3-offload.html>`_ w

.. code-block:: bash
lightning run model ./path/to/train.py \
fabric run model ./path/to/train.py \
--strategy=deepspeed_stage_3 \
--devices=8 \
--accelerator=cuda \
Expand All @@ -148,7 +148,7 @@ Or `DeepSpeed Zero3 <https://www.deepspeed.ai/2021/03/07/zero3-offload.html>`_ w

.. code-block:: bash
lightning run model ./path/to/train.py \
fabric run model ./path/to/train.py \
--devices=auto \
--accelerator=auto \
--precision=16
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/__setup__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def _setup_args() -> Dict[str, Any]:
"python_requires": ">=3.8", # todo: take the lowes based on all packages
"entry_points": {
"console_scripts": [
"fabric = lightning.fabric.cli:_main",
"lightning = lightning.fabric.cli:_legacy_main",
"lightning_app = lightning:_cli_entry_point",
],
},
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- Rename `lightning run model` to `fabric run model` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442))

-

Expand Down
30 changes: 29 additions & 1 deletion src/lightning/fabric/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import logging
import os
import re
import subprocess
import sys
from argparse import Namespace
from typing import Any, List, Optional

Expand All @@ -29,6 +31,7 @@
_log = logging.getLogger(__name__)

_CLICK_AVAILABLE = RequirementCache("click")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")

Expand All @@ -44,7 +47,32 @@ def _get_supported_strategies() -> List[str]:
if _CLICK_AVAILABLE:
import click

@click.command(
def _legacy_main() -> None:
"""Legacy CLI handler for fabric.
Raises deprecation warning and runs through fabric cli if necessary, else runs the entrypoint directly
"""
print("`lightning run model` is deprecated and will be removed in future versions."
" Please call `fabric run model` instead.")
args = sys.argv[1:]
if args and args[0] == "run" and args[1] == "model":
_main()
return

if _LIGHTNING_SDK_AVAILABLE:
subprocess.run([sys.executable, "-m", "lightning_sdk.cli.entrypoint"] + args)
return

@click.group()
def _main() -> None:
pass

@_main.group()
def run() -> None:
pass

@run.command(
"model",
context_settings={
"ignore_unknown_options": True,
Expand Down
5 changes: 5 additions & 0 deletions src/lightning_fabric/__setup__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def _setup_args() -> Dict[str, Any]:
"install_requires": assistant.load_requirements(
_PATH_REQUIREMENTS, unfreeze="none" if _FREEZE_REQUIREMENTS else "all"
),
"entry_points": {
"console_scripts": [
"fabric = lightning_fabric.cli:_main",
],
},
"extras_require": _prepare_extras(),
"project_urls": {
"Bug Tracker": "https://github.com/Lightning-AI/lightning/issues",
Expand Down
22 changes: 14 additions & 8 deletions tests/tests_fabric/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import pytest
from lightning.fabric.cli import _get_supported_strategies, _run_model
from lightning_utilities.core.imports import ModuleAvailableCache

from tests_fabric.helpers.runif import RunIf

Expand Down Expand Up @@ -176,13 +175,20 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
)


def test_cli_through_fabric_entry_point():
result = subprocess.run("fabric run model --help", capture_output=True, text=True, shell=True)

message = "Usage: fabric run model [OPTIONS] SCRIPT [SCRIPT_ARGS]"
assert message in result.stdout or message in result.stderr

@pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package")
def test_cli_through_lightning_entry_point():
result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True)
if not ModuleAvailableCache("lightning.app"):
message = "The `lightning` command requires additional dependencies"
assert message in result.stdout or message in result.stderr
assert result.returncode != 0
else:
message = "Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]"
assert message in result.stdout or message in result.stderr

deprecation_message = (
"`lightning run model` is deprecated and will be removed in future versions. "
"Please call `fabric run model` instead"
)
message = "Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]"
assert deprecation_message in result.stdout
assert message in result.stdout or message in result.stderr

0 comments on commit 2ed7282

Please sign in to comment.