Skip to content

Commit

Permalink
Adapt tests to reuse Orchestrator (#567)
Browse files Browse the repository at this point in the history
Tests which needed to launch an Orchestrator were spinning up and
shutting down their own instances. This led to a number of cases where a
single test failing would cascade into failures of other tests.
Additionally, this also meant that a significant amount of time in the
tests was spent waiting for Orchestrators to launch.

This PR adds a session-scoped fixture that returns an Orchestrator. Most
tests which use an Orchestrator have been updated to use this fixture;
the remaining tests, for various reasons, still need to spin up their own (for
example, the multiple-database tests need a named Orchestrator).

[ committed by @ashao ]
[ reviewed by @ankona @AlyssaCote ]

Co-authored-by: Matt Drozt <drozt@hpe.com>
  • Loading branch information
ashao and MattToast authored May 14, 2024
1 parent 8606e8e commit 781d4b6
Show file tree
Hide file tree
Showing 33 changed files with 618 additions and 522 deletions.
302 changes: 178 additions & 124 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,23 @@
from __future__ import annotations

import asyncio
from collections import defaultdict
from dataclasses import dataclass
import json
import os
import pathlib
import shutil
import subprocess
import signal
import socket
import sys
import tempfile
import time
import typing as t
import uuid
import warnings
from subprocess import run
import time

import psutil
import pytest
Expand All @@ -53,7 +57,7 @@
from smartsim._core.utils.telemetry.telemetry import JobEntity
from smartsim.database import Orchestrator
from smartsim.entity import Model
from smartsim.error import SSConfigError
from smartsim.error import SSConfigError, SSInternalError
from smartsim.log import get_logger
from smartsim.settings import (
AprunSettings,
Expand All @@ -78,7 +82,7 @@
test_num_gpus = CONFIG.test_num_gpus
test_nic = CONFIG.test_interface
test_alloc_specs_path = os.getenv("SMARTSIM_TEST_ALLOC_SPEC_SHEET_PATH", None)
test_port = CONFIG.test_port
test_ports = CONFIG.test_ports
test_account = CONFIG.test_account or ""
test_batch_resources: t.Dict[t.Any, t.Any] = CONFIG.test_batch_resources
test_output_dirs = 0
Expand All @@ -89,7 +93,6 @@
test_hostlist = None
has_aprun = shutil.which("aprun") is not None


def get_account() -> str:
return test_account

Expand All @@ -109,9 +112,7 @@ def print_test_configuration() -> None:
print("TEST_ALLOC_SPEC_SHEET_PATH:", test_alloc_specs_path)
print("TEST_DIR:", test_output_root)
print("Test output will be located in TEST_DIR if there is a failure")
print(
"TEST_PORTS:", ", ".join(str(port) for port in range(test_port, test_port + 3))
)
print("TEST_PORTS:", ", ".join(str(port) for port in test_ports))
if test_batch_resources:
print("TEST_BATCH_RESOURCES: ")
print(json.dumps(test_batch_resources, indent=2))
Expand Down Expand Up @@ -297,7 +298,23 @@ def _reset():
)


@pytest.fixture
def _find_free_port(ports: t.Collection[int]) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
for port in ports:
try:
sock.bind(("127.0.0.1", port))
except socket.error:
continue
else:
_, port_ = sock.getsockname()
return int(port_)
raise SSInternalError(
"Could not find a free port out of a options: "
f"{', '.join(str(port) for port in sorted(ports))}"
)


@pytest.fixture(scope="session")
def wlmutils() -> t.Type[WLMUtils]:
    """Expose the ``WLMUtils`` helper class itself (not an instance) to tests."""
    return WLMUtils

Expand All @@ -314,7 +331,9 @@ def get_test_launcher() -> str:

@staticmethod
def get_test_port() -> int:
    """Return a currently-free port drawn from the configured ``test_ports``."""
    # TODO: Ideally this should find a free port on the correct host(s),
    # but this is good enough for now
    return _find_free_port(test_ports)

@staticmethod
def get_test_account() -> str:
Expand Down Expand Up @@ -420,61 +439,6 @@ def get_run_settings(

return RunSettings(exe, args)

@staticmethod
def get_orchestrator(nodes: int = 1, batch: bool = False) -> Orchestrator:
    """Build an ``Orchestrator`` configured for the active test launcher.

    :param nodes: number of database nodes to request
    :param batch: whether the database should be launched as a batch job
    :return: an unstarted ``Orchestrator`` instance
    """
    shared = {
        "db_nodes": nodes,
        "port": test_port,
        "batch": batch,
        "interface": test_nic,
        "launcher": test_launcher,
    }
    if test_launcher == "pbs":
        # when aprun is available the host list is left unspecified
        hosts = None if shutil.which("aprun") else get_hostlist()
        return Orchestrator(hosts=hosts, **shared)
    if test_launcher == "pals":
        return Orchestrator(hosts=get_hostlist(), **shared)
    if test_launcher in ("slurm", "dragon"):
        return Orchestrator(**shared)
    if test_launcher == "lsf":
        return Orchestrator(
            cpus_per_shard=4,
            gpus_per_shard=2 if test_device == "GPU" else 0,
            project=get_account(),
            **shared,
        )
    # any other launcher: fall back to a local, loopback-only database
    return Orchestrator(port=test_port, interface="lo")

@staticmethod
def choose_host(rs: RunSettings) -> t.Optional[str]:
if isinstance(rs, (MpirunSettings, MpiexecSettings)):
Expand All @@ -485,65 +449,6 @@ def choose_host(rs: RunSettings) -> t.Optional[str]:
return None


@pytest.fixture
def local_db(
    request: t.Any, wlmutils: t.Type[WLMUtils], test_dir: str
) -> t.Generator[Orchestrator, None, None]:
    """Start a loopback-only orchestrator for one test, stopping it afterwards."""
    experiment = Experiment(
        request.function.__name__, launcher="local", exp_path=test_dir
    )
    orc = Orchestrator(port=wlmutils.get_test_port(), interface="lo")
    orc.set_path(test_dir)
    experiment.start(orc)
    yield orc
    # teardown below runs whether the test passed or failed
    experiment.stop(orc)


@pytest.fixture
def db(
    request: t.Any, wlmutils: t.Type[WLMUtils], test_dir: str
) -> t.Generator[Orchestrator, None, None]:
    """Start a launcher-appropriate orchestrator for one test, then stop it."""
    experiment = Experiment(
        request.function.__name__,
        launcher=wlmutils.get_test_launcher(),
        exp_path=test_dir,
    )
    orc = wlmutils.get_orchestrator()
    orc.set_path(test_dir)
    experiment.start(orc)
    yield orc
    # teardown below runs whether the test passed or failed
    experiment.stop(orc)


@pytest.fixture
def db_cluster(
    test_dir: str, wlmutils: t.Type[WLMUtils], request: t.Any
) -> t.Generator[Orchestrator, None, None]:
    """Start a three-node clustered orchestrator for one test, then stop it.

    This should only be used in on_wlm and full_wlm tests.
    """
    experiment = Experiment(
        request.function.__name__,
        launcher=wlmutils.get_test_launcher(),
        exp_path=test_dir,
    )
    orc = wlmutils.get_orchestrator(nodes=3)
    orc.set_path(test_dir)
    experiment.start(orc)
    yield orc
    # teardown below runs whether the test passed or failed
    experiment.stop(orc)


@pytest.fixture(scope="function", autouse=True)
def environment_cleanup(monkeypatch: pytest.MonkeyPatch) -> None:
for key in os.environ.keys():
Expand Down Expand Up @@ -750,7 +655,7 @@ def setup_test_colo(
db_args: t.Dict[str, t.Any],
colo_settings: t.Optional[RunSettings] = None,
colo_model_name: str = "colocated_model",
port: int = test_port,
port: t.Optional[int] = None,
on_wlm: bool = False,
) -> Model:
"""Setup database needed for the colo pinning tests"""
Expand All @@ -766,10 +671,11 @@ def setup_test_colo(
if on_wlm:
colo_settings.set_tasks(1)
colo_settings.set_nodes(1)

colo_model = exp.create_model(colo_model_name, colo_settings)

if db_type in ["tcp", "deprecated"]:
db_args["port"] = port
db_args["port"] = port if port is not None else _find_free_port(test_ports)
db_args["ifname"] = "lo"
if db_type == "uds" and colo_model_name is not None:
tmp_dir = tempfile.gettempdir()
Expand Down Expand Up @@ -968,3 +874,151 @@ def num_calls(self) -> int:
@property
def details(self) -> t.List[t.Tuple[t.Tuple[t.Any, ...], t.Dict[str, t.Any]]]:
    """Return the accumulated ``(args, kwargs)`` records stored in ``_details``.

    # NOTE(review): presumably one entry per recorded call (cf. ``num_calls``)
    # — confirm against the enclosing class, which is not fully visible here.
    """
    return self._details

## Reuse database across tests

# Maps a fixture name to the live Orchestrator registered under that name;
# the defaultdict yields None until a database has been launched for a name.
database_registry: t.DefaultDict[str, t.Optional[Orchestrator]] = defaultdict(lambda: None)

@pytest.fixture(scope="function")
def local_experiment(test_dir: str) -> smartsim.Experiment:
    """Create a default experiment that always uses the local launcher"""
    return smartsim.Experiment(
        pathlib.Path(test_dir).stem, exp_path=test_dir, launcher="local"
    )

@pytest.fixture(scope="function")
def wlm_experiment(test_dir: str, wlmutils: WLMUtils) -> smartsim.Experiment:
    """Create a default experiment that uses the requested launcher"""
    return smartsim.Experiment(
        pathlib.Path(test_dir).stem,
        exp_path=test_dir,
        launcher=wlmutils.get_test_launcher(),
    )

def _cleanup_db(name: str) -> None:
    """Best-effort shutdown of the session database registered under ``name``.

    Failures during teardown are deliberately swallowed: cleanup must not mask
    the outcome of the test session, and the database may already be gone.

    :param name: registry key of the database fixture to stop
    """
    # read-only access; no ``global`` statement is needed
    db = database_registry[name]
    if db and db.is_active():
        exp = Experiment("cleanup")
        try:
            # reattach through the checkpoint file so ``stop`` targets the
            # actual launched database instance
            db = exp.reconnect_orchestrator(db.checkpoint_file)
            exp.stop(db)
        except Exception:
            # narrowed from a bare ``except:`` so Ctrl-C / SystemExit still
            # propagate during teardown
            pass

@dataclass
class DBConfiguration:
    """Launch parameters describing a session-scoped test database."""

    name: str
    launcher: str
    num_nodes: int
    interface: t.Union[str, t.List[str]]
    hostlist: t.Optional[t.List[str]]
    port: int

@dataclass
class PrepareDatabaseOutput:
    """Result of ``prepare_db``: the database plus whether it was just launched."""

    # The actual orchestrator object
    orchestrator: t.Optional[Orchestrator]
    # True if a new database was created when calling prepare_db
    new_db: bool

# Reuse databases
@pytest.fixture(scope="session")
def local_db() -> t.Generator[DBConfiguration, None, None]:
    """Session-scoped configuration for a reusable local (loopback) database."""
    fixture_name = "local_db_fixture"
    # scan the candidate ports from the top of the configured range
    yield DBConfiguration(
        fixture_name,
        "local",
        1,
        "lo",
        None,
        _find_free_port(tuple(reversed(test_ports))),
    )
    _cleanup_db(fixture_name)

@pytest.fixture(scope="session")
def single_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]:
    """Session-scoped configuration for a reusable single-node WLM database."""
    all_hosts = wlmutils.get_test_hostlist()
    fixture_name = "single_db_fixture"
    yield DBConfiguration(
        fixture_name,
        wlmutils.get_test_launcher(),
        1,
        wlmutils.get_test_interface(),
        # pin to the last available host when a host list is provided
        None if all_hosts is None else all_hosts[-1:],
        _find_free_port(tuple(reversed(test_ports))),
    )
    _cleanup_db(fixture_name)


@pytest.fixture(scope="session")
def clustered_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]:
    """Session-scoped configuration for a reusable three-node clustered database."""
    all_hosts = wlmutils.get_test_hostlist()
    fixture_name = "clustered_db_fixture"
    yield DBConfiguration(
        fixture_name,
        wlmutils.get_test_launcher(),
        3,
        wlmutils.get_test_interface(),
        # take three hosts, leaving the final host for the single_db fixture
        None if all_hosts is None else all_hosts[-4:-1],
        _find_free_port(tuple(reversed(test_ports))),
    )
    _cleanup_db(fixture_name)


@pytest.fixture
def register_new_db() -> t.Callable[[DBConfiguration], Orchestrator]:
    """Provide a factory that launches a database and records it for cleanup."""

    def _register_new_db(config: DBConfiguration) -> Orchestrator:
        exp_path = pathlib.Path(test_output_root, config.name)
        exp_path.mkdir(exist_ok=True)
        exp = Experiment(
            config.name,
            exp_path=str(exp_path),
            launcher=config.launcher,
        )
        orc = exp.create_database(
            port=config.port,
            batch=False,
            interface=config.interface,
            hosts=config.hostlist,
            db_nodes=config.num_nodes,
        )
        exp.generate(orc, overwrite=True)
        exp.start(orc)
        # remember the live instance so session teardown can shut it down;
        # mutating the registry dict needs no ``global`` statement
        database_registry[config.name] = orc
        return orc

    return _register_new_db


@pytest.fixture(scope="function")
def prepare_db(
    register_new_db: t.Callable[[DBConfiguration], Orchestrator]
) -> t.Callable[[DBConfiguration], PrepareDatabaseOutput]:
    """Provide a callable that yields a live database, reusing one when active."""

    def _prepare_db(db_config: DBConfiguration) -> PrepareDatabaseOutput:
        db = database_registry[db_config.name]
        # launch a fresh database when none is registered, or the registered
        # one is no longer running
        if not db or not db.is_active():
            return PrepareDatabaseOutput(register_new_db(db_config), True)
        return PrepareDatabaseOutput(db, False)

    return _prepare_db
Loading

0 comments on commit 781d4b6

Please sign in to comment.