tests
carmocca committed Jan 31, 2023
1 parent 9068ff0 commit f977a2b
Showing 17 changed files with 101 additions and 60 deletions.
@@ -21,7 +21,7 @@
import torch.multiprocessing as mp
from typing_extensions import Literal

-from lightning_fabric.strategies.launchers.base import _Launcher
+from lightning_fabric.strategies.launchers.launcher import _Launcher
from lightning_fabric.utilities.apply_func import move_data_to_device
from lightning_fabric.utilities.imports import _IS_INTERACTIVE, _TORCH_GREATER_EQUAL_1_11
from lightning_fabric.utilities.seed import _collect_rng_states, _set_rng_states
@@ -19,7 +19,7 @@
from lightning_utilities.core.imports import RequirementCache

from lightning_fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning_fabric.strategies.launchers.base import _Launcher
+from lightning_fabric.strategies.launchers.launcher import _Launcher

_HYDRA_AVAILABLE = RequirementCache("hydra-core")

2 changes: 1 addition & 1 deletion src/lightning_fabric/strategies/launchers/xla.py
@@ -18,7 +18,7 @@
from torch.multiprocessing import get_context

from lightning_fabric.accelerators.tpu import _XLA_AVAILABLE
-from lightning_fabric.strategies.launchers.base import _Launcher
+from lightning_fabric.strategies.launchers.launcher import _Launcher
from lightning_fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot
from lightning_fabric.utilities.apply_func import move_data_to_device

2 changes: 1 addition & 1 deletion src/lightning_fabric/strategies/strategy.py
@@ -26,7 +26,7 @@
from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning_fabric.plugins.io.torch_io import TorchCheckpointIO
from lightning_fabric.plugins.precision import Precision
-from lightning_fabric.strategies.launchers.base import _Launcher
+from lightning_fabric.strategies.launchers.launcher import _Launcher
from lightning_fabric.utilities.apply_func import move_data_to_device
from lightning_fabric.utilities.types import _PATH, _Stateful, Optimizable, ReduceOp

4 changes: 3 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -23,7 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a `Trainer.received_sigterm` property to check whether a SIGTERM signal was received ([#16501](https://github.com/Lightning-AI/lightning/pull/16501))

-- Added support for cascading a SIGTERM signal to launched subprocesses after the launching process (rank 0) receives it ([#16525](https://github.com/Lightning-AI/lightning/pull/16525))
+- Added support for cascading a SIGTERM signal to launched processes after the launching process (rank 0) receives it ([#16525](https://github.com/Lightning-AI/lightning/pull/16525))

+- Added a `kill` method to launchers to kill all launched processes ([#16525](https://github.com/Lightning-AI/lightning/pull/16525))

### Changed

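For context, a minimal standalone sketch of the cascading behaviour described in the changelog entries above: a launching process forwards a received SIGTERM to the workers it spawned. This is only an illustration under assumed names (`TinyLauncher`, `_sigterm_handler`), not the Lightning implementation:

```python
import signal
import subprocess
import sys
from typing import List


class TinyLauncher:
    """Toy launcher that remembers the processes it spawned."""

    def __init__(self) -> None:
        self.procs: List[subprocess.Popen] = []

    def launch(self, n_workers: int) -> None:
        # Spawn workers that just sleep, so the example stays self-contained.
        cmd = [sys.executable, "-c", "import time; time.sleep(60)"]
        self.procs = [subprocess.Popen(cmd) for _ in range(n_workers)]

    def kill(self, signum: int) -> None:
        # Forward the signal to every child that is still running.
        for proc in self.procs:
            if proc.poll() is None:
                proc.send_signal(signum)


launcher = TinyLauncher()
launcher.launch(2)


def _sigterm_handler(signum, frame):
    # Only the launching process (rank 0) cascades the signal to its children.
    launcher.kill(signum)


signal.signal(signal.SIGTERM, _sigterm_handler)
```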
22 changes: 0 additions & 22 deletions src/pytorch_lightning/strategies/launchers/__init__.py
@@ -1,22 +0,0 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.launchers.xla import _XLALauncher

__all__ = [
"_MultiProcessingLauncher",
"_SubprocessScriptLauncher",
"_XLALauncher",
]
23 changes: 23 additions & 0 deletions src/pytorch_lightning/strategies/launchers/launcher.py
@@ -0,0 +1,23 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod

from lightning_fabric.strategies.launchers.launcher import _Launcher as _FabricLauncher
from pytorch_lightning.trainer.connectors.signal_connector import _SIGNUM


class _Launcher(_FabricLauncher, ABC):
    @abstractmethod
    def kill(self, signum: _SIGNUM) -> None:
        """Kill existing alive processes."""
22 changes: 19 additions & 3 deletions src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import tempfile
from collections import UserList
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Dict, NamedTuple, Optional
+from typing import Any, Callable, Dict, List, NamedTuple, Optional

import numpy as np
import torch
@@ -27,15 +28,18 @@
from typing_extensions import Literal

import pytorch_lightning as pl
-from lightning_fabric.strategies.launchers.base import _Launcher
+from lightning_fabric.strategies.launchers.launcher import _Launcher
from lightning_fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from lightning_fabric.utilities.seed import _collect_rng_states, _set_rng_states
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.trainer.connectors.signal_connector import _SIGNUM
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

log = logging.getLogger(__name__)


class _MultiProcessingLauncher(_Launcher):
r"""Launches processes that run a given function in parallel, and joins them all at the end.
@@ -70,6 +74,7 @@ def __init__(
f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
f" {', '.join(mp.get_all_start_methods())}"
)
self.procs: List[mp.Process] = []

@property
def is_interactive_compatible(self) -> bool:
@@ -110,12 +115,14 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
else:
process_args = [trainer, function, args, kwargs, return_queue]

-        mp.start_processes(
+        process_context = mp.start_processes(
             self._wrapping_function,
             args=process_args,
             nprocs=self._strategy.num_processes,
             start_method=self._start_method,
         )
+        self.procs = process_context.processes
+
worker_output = return_queue.get()
if trainer is None:
return worker_output
@@ -226,6 +233,15 @@ def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
callback_metrics: dict = queue.get()
trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))

    def kill(self, signum: _SIGNUM) -> None:
        for proc in self.procs:
            if proc.is_alive():
                log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
                try:
                    os.kill(proc.pid, signum)
                except ProcessLookupError:
                    log.info(f"process {proc.pid} already exited")


class _FakeQueue(UserList):
"""Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""
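The signal forwarding above relies on keeping the process handles returned by `torch.multiprocessing`. A self-contained sketch of that pattern (illustrative only; it passes `join=False` so that `start_processes` returns the `ProcessContext` immediately, then joins explicitly afterwards):

```python
import os

import torch.multiprocessing as mp


def _worker(rank: int) -> None:
    print(f"worker {rank} running in pid {os.getpid()}")


if __name__ == "__main__":
    context = mp.start_processes(_worker, nprocs=2, start_method="spawn", join=False)
    procs = context.processes  # list of multiprocessing.Process handles

    # A signal handler could later forward a signal to these handles, e.g.
    # ``os.kill(proc.pid, signum)`` for every ``proc`` that ``is_alive()``.

    while not context.join():  # wait for all workers to finish
        pass
```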
11 changes: 10 additions & 1 deletion src/pytorch_lightning/strategies/launchers/subprocess_script.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import subprocess
from typing import Any, Callable, List, Optional
@@ -19,9 +20,11 @@

import pytorch_lightning as pl
from lightning_fabric.plugins import ClusterEnvironment
-from lightning_fabric.strategies.launchers.base import _Launcher
+from lightning_fabric.strategies.launchers.launcher import _Launcher
from lightning_fabric.strategies.launchers.subprocess_script import _basic_subprocess_cmd, _hydra_subprocess_cmd
from pytorch_lightning.trainer.connectors.signal_connector import _SIGNUM

log = logging.getLogger(__name__)
_HYDRA_AVAILABLE = RequirementCache("hydra-core")


@@ -88,6 +91,12 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
self._call_children_scripts()
return function(*args, **kwargs)

    def kill(self, signum: _SIGNUM) -> None:
        for proc in self.procs:
            log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
            # this skips subprocesses already terminated
            proc.send_signal(signum)

def _call_children_scripts(self) -> None:
# bookkeeping of spawned processes
self._check_can_spawn_children()
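`Popen.send_signal` is used here rather than `os.kill` because it is safe to call on a child that has already finished: once the exit status has been collected it simply does nothing, so no `ProcessLookupError` handling is needed. A small sketch of that behaviour (the child command is arbitrary):

```python
import signal
import subprocess
import sys

# Spawn a child that exits immediately and reap it.
proc = subprocess.Popen([sys.executable, "-c", "pass"])
proc.wait()

# No-op: the process has already completed, so nothing is sent and no error is raised.
proc.send_signal(signal.SIGTERM)
```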
4 changes: 3 additions & 1 deletion src/pytorch_lightning/strategies/launchers/xla.py
@@ -73,12 +73,14 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
return_queue = context.SimpleQueue()
import torch_xla.distributed.xla_multiprocessing as xmp

-        xmp.spawn(
+        process_context = xmp.spawn(
             self._wrapping_function,
             args=(trainer, function, args, kwargs, return_queue),
             nprocs=self._strategy.num_processes,
             start_method=self._start_method,
         )
+        self.procs = process_context.processes
+
worker_output = return_queue.get()
if trainer is None:
return worker_output
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/strategy.py
@@ -24,7 +24,7 @@

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO
-from lightning_fabric.strategies.launchers.base import _Launcher
+from lightning_fabric.strategies.launchers.launcher import _Launcher
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.distributed import ReduceOp
from lightning_fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device
13 changes: 4 additions & 9 deletions src/pytorch_lightning/trainer/connectors/signal_connector.py
@@ -12,7 +12,6 @@
import pytorch_lightning as pl
from lightning_fabric.plugins.environments import SLURMEnvironment
from lightning_fabric.utilities.imports import _IS_WINDOWS
-from pytorch_lightning.strategies.launchers import _SubprocessScriptLauncher
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from pytorch_lightning.utilities.rank_zero import rank_zero_info

@@ -107,17 +106,13 @@ def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
log.warning("requeue failed...")

def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None:
-        local_rank = self.trainer.local_rank
-        log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", local_rank))
+        log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank))
         # subprocesses killing the parent process is not supported, only the parent (rank 0) does it
-        if not self.received_sigterm and local_rank == 0:
+        if not self.received_sigterm:
             # send the same signal to the subprocesses
             launcher = self.trainer.strategy.launcher
-            if isinstance(launcher, _SubprocessScriptLauncher):
-                for proc in launcher.procs:
-                    if proc.poll() is None:  # process hasn't terminated
-                        log.debug(f"pid {os.getpid()} killing {proc.pid} with {signum}")
-                        os.kill(proc.pid, signum)
+            if launcher is not None:
+                launcher.kill(signum)
self.received_sigterm = True

def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
2 changes: 1 addition & 1 deletion tests/tests_app/utilities/test_commands.py
@@ -97,7 +97,7 @@ def test_validate_client_command():
with pytest.raises(Exception, match="annotate your method"):
_validate_client_command(ClientCommand(run_failure_1))

with pytest.raises(Exception, match="lightning_app/utilities/commands/base.py"):
with pytest.raises(Exception, match="lightning_app/utilities/commands/launcher.py"):
_validate_client_command(ClientCommand(run_failure_2))


14 changes: 13 additions & 1 deletion tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from multiprocessing import Process
from unittest import mock
-from unittest.mock import ANY, Mock
+from unittest.mock import ANY, call, Mock, patch

import pytest
import torch
@@ -163,3 +164,14 @@ def test_non_strict_loading(tmpdir):
# <-- here would normally be the multiprocessing boundary
strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
model.load_state_dict.assert_called_once_with(ANY, strict=False)


def test_kill():
    launcher = _MultiProcessingLauncher(Mock())
    proc0 = Mock(autospec=Process)
    proc1 = Mock(autospec=Process)
    launcher.procs = [proc0, proc1]

    with patch("os.kill") as kill_patch:
        launcher.kill(15)
    assert kill_patch.mock_calls == [call(proc0.pid, 15), call(proc1.pid, 15)]
14 changes: 14 additions & 0 deletions tests/tests_pytorch/strategies/launchers/test_subprocess_script.py
@@ -1,8 +1,11 @@
import subprocess
import sys
from unittest.mock import Mock

import pytest
from lightning_utilities.core.imports import RequirementCache

from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from tests_pytorch.helpers.runif import RunIf

_HYDRA_WITH_RERUN = RequirementCache("hydra-core>=1.2")
@@ -59,3 +62,14 @@ def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
if subdir is not None:
cmd += [f"hydra.output_subdir={subdir}"]
run_process(cmd)


def test_kill():
    launcher = _SubprocessScriptLauncher(Mock(), 1, 1)
    proc0 = Mock(autospec=subprocess.Popen)
    proc1 = Mock(autospec=subprocess.Popen)
    launcher.procs = [proc0, proc1]

    launcher.kill(15)
    proc0.send_signal.assert_called_once_with(15)
    proc1.send_signal.assert_called_once_with(15)
22 changes: 6 additions & 16 deletions tests/tests_pytorch/trainer/connectors/test_signal_connector.py
@@ -14,7 +14,6 @@
import concurrent.futures
import os
import signal
-import subprocess
from unittest import mock
from unittest.mock import Mock

@@ -24,7 +23,6 @@
from lightning_fabric.utilities.imports import _IS_WINDOWS
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.strategies.launchers import _SubprocessScriptLauncher
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.utilities.exceptions import SIGTERMException
from tests_pytorch.helpers.runif import RunIf
@@ -163,22 +161,14 @@ def test_has_already_handler(handler, expected_return):

def test_sigterm_notifier_fn():
     trainer = Mock()
-    trainer.local_rank = 0
-    launcher = _SubprocessScriptLauncher(Mock(), 1, 1)
-    proc0 = Mock(autospec=subprocess.Popen)
-    proc1 = Mock(autospec=subprocess.Popen)
-    proc0.pid = 123
-    proc1.pid = 312
-    proc0.poll.return_value = None
-    launcher.procs = [proc0, proc1]
+    launcher = Mock()
     trainer.strategy.launcher = launcher
     connector = SignalConnector(trainer)

     assert not connector.received_sigterm
-    with mock.patch("os.kill") as kill_mock:
-        connector._sigterm_notifier_fn(signal.SIGTERM, Mock())
-    kill_mock.assert_called_once_with(123, 15)
+    connector._sigterm_notifier_fn(signal.SIGTERM, Mock())
+    launcher.kill.assert_called_once_with(15)
     assert connector.received_sigterm
-    with mock.patch("os.kill") as kill_mock:
-        connector._sigterm_notifier_fn(signal.SIGTERM, Mock())
-    kill_mock.assert_not_called()
+    launcher.reset_mock()
+    connector._sigterm_notifier_fn(signal.SIGTERM, Mock())
+    launcher.kill.assert_not_called()
