diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py
index 5560d114d9369..82007290b0f94 100644
--- a/tests/tests_fabric/test_cli.py
+++ b/tests/tests_fabric/test_cli.py
@@ -20,7 +20,7 @@
 from unittest.mock import Mock
 
 import pytest
-from lightning.fabric.cli import _get_supported_strategies, _run
+from lightning.fabric.cli import _get_supported_strategies, _run, _consolidate
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -33,7 +33,7 @@ def fake_script(tmp_path):
 
 
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_env_vars_defaults(monkeypatch, fake_script):
+def test_run_env_vars_defaults(monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script])
@@ -49,7 +49,7 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
 @pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
-def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
+def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--accelerator", accelerator])
@@ -60,7 +60,7 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
 @pytest.mark.parametrize("strategy", _get_supported_strategies())
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
-def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script):
+def test_run_env_vars_strategy(_, strategy, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--strategy", strategy])
@@ -68,7 +68,7 @@ def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script):
     assert os.environ["LT_STRATEGY"] == strategy
 
 
-def test_cli_get_supported_strategies():
+def test_run_get_supported_strategies():
     """Test to ensure that when new strategies get added, we must consider updating the list of supported ones in the
     CLI."""
     assert len(_get_supported_strategies()) == 7
@@ -76,7 +76,7 @@ def test_cli_get_supported_strategies():
 
 
 @pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp_fork", "ddp_notebook", "deepspeed_stage_3_offload"])
-def test_cli_env_vars_unsupported_strategy(strategy, fake_script):
+def test_run_env_vars_unsupported_strategy(strategy, fake_script):
     ioerr = StringIO()
     with pytest.raises(SystemExit) as e, contextlib.redirect_stderr(ioerr):
         _run.main([fake_script, "--strategy", strategy])
@@ -87,7 +87,7 @@ def test_cli_env_vars_unsupported_strategy(strategy, fake_script):
 @pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
-def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
+def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--accelerator", "cuda", "--devices", devices])
@@ -98,7 +98,7 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
 @RunIf(mps=True)
 @pytest.mark.parametrize("accelerator", ["mps", "gpu"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
+def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--accelerator", accelerator])
@@ -108,7 +108,7 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
 
 @pytest.mark.parametrize("num_nodes", ["1", "2", "3"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
+def test_run_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--num-nodes", num_nodes])
@@ -118,7 +118,7 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
 
 @pytest.mark.parametrize("precision", ["64-true", "64", "32-true", "32", "16-mixed", "bf16-mixed"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
+def test_run_env_vars_precision(precision, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--precision", precision])
@@ -127,7 +127,7 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
 
 
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_torchrun_defaults(monkeypatch, fake_script):
+def test_run_torchrun_defaults(monkeypatch, fake_script):
     torchrun_mock = Mock()
     monkeypatch.setitem(sys.modules, "torch.distributed.run", torchrun_mock)
     with pytest.raises(SystemExit) as e:
@@ -155,7 +155,7 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script):
 )
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5)
-def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch, fake_script):
+def test_run_torchrun_num_processes_launched(_, devices, expected, monkeypatch, fake_script):
     torchrun_mock = Mock()
     monkeypatch.setitem(sys.modules, "torch.distributed.run", torchrun_mock)
     with pytest.raises(SystemExit) as e:
@@ -171,7 +171,7 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
     ])
 
 
-def test_cli_through_fabric_entry_point():
+def test_run_through_fabric_entry_point():
     result = subprocess.run("fabric run --help", capture_output=True, text=True, shell=True)
 
     message = "Usage: fabric run [OPTIONS] SCRIPT [SCRIPT_ARGS]"
@@ -179,7 +179,7 @@ def test_cli_through_fabric_entry_point():
 
 
 @pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package")
-def test_cli_through_lightning_entry_point():
+def test_run_through_lightning_entry_point():
     result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True)
 
     deprecation_message = (
@@ -189,3 +189,22 @@ def test_cli_through_lightning_entry_point():
     message = "Usage: lightning run [OPTIONS] SCRIPT [SCRIPT_ARGS]"
     assert deprecation_message in result.stdout
     assert message in result.stdout or message in result.stderr
+
+
+@mock.patch("lightning.fabric.cli._process_cli_args")
+@mock.patch("lightning.fabric.cli._load_distributed_checkpoint")
+@mock.patch("lightning.fabric.cli.torch.save")
+def test_consolidate(save_mock, _, __, tmp_path):
+    ioerr = StringIO()
+    with pytest.raises(SystemExit) as e, contextlib.redirect_stderr(ioerr):
+        _consolidate.main(["not exist"])
+    assert e.value.code == 2
+    assert "Path 'not exist' does not exist" in ioerr.getvalue()
+
+    checkpoint_folder = tmp_path / "checkpoint"
+    checkpoint_folder.mkdir()
+    ioerr = StringIO()
+    with pytest.raises(SystemExit) as e, contextlib.redirect_stderr(ioerr):
+        _consolidate.main([str(checkpoint_folder)])
+    assert e.value.code == 0
+    save_mock.assert_called_once()