Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt tests to reuse Orchestrator #567

Merged
merged 30 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3ff2918
Replace some creation of databases in tests with a fixture
ashao May 13, 2024
47b264f
Catch a bug where Orchestrator.is_active fails instead of returning F…
ashao Apr 23, 2024
fb03466
Static analysis fixes
MattToast Apr 25, 2024
5c2e4d2
Extract common code path to ctx manager
MattToast Apr 26, 2024
cff2e59
Use fixture in onnx tests
MattToast Apr 26, 2024
d4ac659
Remove duplicate code
MattToast Apr 26, 2024
e858364
fix test orch tests
MattToast Apr 26, 2024
7283a44
Revert changes to tests that require explicit DB ID, make test impl m…
MattToast Apr 29, 2024
c830bd7
Update test to new fixture names
MattToast Apr 29, 2024
2061f94
Reserve last 4 hosts for session scoped DB fixtures
MattToast Apr 29, 2024
7699fbd
Session long fixtures select ports in reverse order to avoid conflicts
MattToast Apr 30, 2024
46de22b
Lint
MattToast Apr 30, 2024
089a72e
Remove Orc `is_active` check in controller
MattToast Apr 30, 2024
9ac5456
Bump number of required nodes
MattToast Apr 30, 2024
d4c69fe
Revert unnecessary change to mini experiment test
MattToast Apr 30, 2024
c5b0445
Changelog
ashao May 11, 2024
dda5382
Misc QoL changes
ashao May 11, 2024
80a447b
Revert out of scope changes
MattToast May 1, 2024
dea85d7
Unify min number of required ports
MattToast May 1, 2024
e38cdcc
Ensure exp dir exists
MattToast May 1, 2024
ecba413
Reviewer feedback, more generic type, typo
ashao May 11, 2024
62c6008
Only find free port if needed
MattToast May 3, 2024
5c9e393
reviewer feedback
MattToast May 6, 2024
1207772
Initial work to restart and clean dbs
ashao May 8, 2024
bb2f82d
Register new orchestrators and check if up. Restart if necessary
ashao May 11, 2024
1c2d10a
style checks
ashao May 11, 2024
a029062
Fix sorting
ashao May 11, 2024
a8a4e3a
Fix last remaining tests
ashao May 11, 2024
e460ce4
Fix wlm tests
ashao May 13, 2024
3694737
Add typehints and changes for Dragon test
ashao May 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 178 additions & 124 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,23 @@
from __future__ import annotations

import asyncio
from collections import defaultdict
from dataclasses import dataclass
import json
import os
import pathlib
import shutil
import subprocess
import signal
import socket
import sys
import tempfile
import time
import typing as t
import uuid
import warnings
from subprocess import run
import time

import psutil
import pytest
Expand All @@ -53,7 +57,7 @@
from smartsim._core.utils.telemetry.telemetry import JobEntity
from smartsim.database import Orchestrator
from smartsim.entity import Model
from smartsim.error import SSConfigError
from smartsim.error import SSConfigError, SSInternalError
from smartsim.log import get_logger
from smartsim.settings import (
AprunSettings,
Expand All @@ -78,7 +82,7 @@
test_num_gpus = CONFIG.test_num_gpus
test_nic = CONFIG.test_interface
test_alloc_specs_path = os.getenv("SMARTSIM_TEST_ALLOC_SPEC_SHEET_PATH", None)
test_port = CONFIG.test_port
test_ports = CONFIG.test_ports
test_account = CONFIG.test_account or ""
test_batch_resources: t.Dict[t.Any, t.Any] = CONFIG.test_batch_resources
test_output_dirs = 0
Expand All @@ -89,7 +93,6 @@
test_hostlist = None
has_aprun = shutil.which("aprun") is not None


def get_account() -> str:
return test_account

Expand All @@ -109,9 +112,7 @@ def print_test_configuration() -> None:
print("TEST_ALLOC_SPEC_SHEET_PATH:", test_alloc_specs_path)
print("TEST_DIR:", test_output_root)
print("Test output will be located in TEST_DIR if there is a failure")
print(
"TEST_PORTS:", ", ".join(str(port) for port in range(test_port, test_port + 3))
)
print("TEST_PORTS:", ", ".join(str(port) for port in test_ports))
if test_batch_resources:
print("TEST_BATCH_RESOURCES: ")
print(json.dumps(test_batch_resources, indent=2))
Expand Down Expand Up @@ -297,7 +298,23 @@ def _reset():
)


@pytest.fixture
def _find_free_port(ports: t.Collection[int]) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
for port in ports:
try:
sock.bind(("127.0.0.1", port))
except socket.error:
continue
else:
_, port_ = sock.getsockname()
return int(port_)
raise SSInternalError(
"Could not find a free port out of a options: "
f"{', '.join(str(port) for port in sorted(ports))}"
)


@pytest.fixture(scope="session")
def wlmutils() -> t.Type[WLMUtils]:
    """Session fixture exposing the WLMUtils helper class to tests."""
    return WLMUtils

Expand All @@ -314,7 +331,9 @@ def get_test_launcher() -> str:

@staticmethod
def get_test_port() -> int:
    """Return a currently-free port drawn from the configured test ports."""
    # TODO: Ideally this should find a free port on the correct host(s),
    # but this is good enough for now
    candidate_ports = test_ports
    return _find_free_port(candidate_ports)

@staticmethod
def get_test_account() -> str:
Expand Down Expand Up @@ -420,61 +439,6 @@ def get_run_settings(

return RunSettings(exe, args)

@staticmethod
def get_orchestrator(nodes: int = 1, batch: bool = False) -> Orchestrator:
    """Create an ``Orchestrator`` configured for the current test launcher.

    :param nodes: number of database nodes to launch
    :param batch: whether the database should be launched as a batch job
    :return: an orchestrator suited to the active launcher; a local
        loopback orchestrator when the launcher is unrecognized
    """
    # kwargs shared by every launcher-specific branch; consolidating them
    # avoids repeating the same five arguments in each call below
    common: t.Dict[str, t.Any] = dict(
        db_nodes=nodes,
        port=test_port,
        batch=batch,
        interface=test_nic,
        launcher=test_launcher,
    )
    if test_launcher == "pbs":
        # On PBS, an explicit hostlist is only needed when aprun is absent
        hostlist = None if shutil.which("aprun") else get_hostlist()
        return Orchestrator(hosts=hostlist, **common)
    if test_launcher == "pals":
        return Orchestrator(hosts=get_hostlist(), **common)
    if test_launcher in ("slurm", "dragon"):
        return Orchestrator(**common)
    if test_launcher == "lsf":
        return Orchestrator(
            cpus_per_shard=4,
            gpus_per_shard=2 if test_device == "GPU" else 0,
            project=get_account(),
            **common,
        )
    # Fallback: a local orchestrator bound to the loopback interface
    return Orchestrator(port=test_port, interface="lo")

@staticmethod
def choose_host(rs: RunSettings) -> t.Optional[str]:
if isinstance(rs, (MpirunSettings, MpiexecSettings)):
Expand All @@ -485,65 +449,6 @@ def choose_host(rs: RunSettings) -> t.Optional[str]:
return None


@pytest.fixture
def local_db(
    request: t.Any, wlmutils: t.Type[WLMUtils], test_dir: str
) -> t.Generator[Orchestrator, None, None]:
    """Start a local orchestrator for a test, stopping it on teardown."""
    experiment = Experiment(
        request.function.__name__, launcher="local", exp_path=test_dir
    )
    orc = Orchestrator(port=wlmutils.get_test_port(), interface="lo")
    orc.set_path(test_dir)
    experiment.start(orc)

    yield orc
    # Teardown runs after the test completes, whether it passed or failed
    experiment.stop(orc)

@pytest.fixture
def db(
    request: t.Any, wlmutils: t.Type[WLMUtils], test_dir: str
) -> t.Generator[Orchestrator, None, None]:
    """Start an orchestrator on the configured launcher; stop it on teardown."""
    experiment = Experiment(
        request.function.__name__,
        launcher=wlmutils.get_test_launcher(),
        exp_path=test_dir,
    )
    orc = wlmutils.get_orchestrator()
    orc.set_path(test_dir)
    experiment.start(orc)

    yield orc
    # Teardown runs after the test completes, whether it passed or failed
    experiment.stop(orc)


@pytest.fixture
def db_cluster(
    test_dir: str, wlmutils: t.Type[WLMUtils], request: t.Any
) -> t.Generator[Orchestrator, None, None]:
    """Start a 3-node clustered orchestrator, stopping it on teardown.

    This should only be used in on_wlm and full_wlm tests.
    """
    experiment = Experiment(
        request.function.__name__,
        launcher=wlmutils.get_test_launcher(),
        exp_path=test_dir,
    )
    orc = wlmutils.get_orchestrator(nodes=3)
    orc.set_path(test_dir)
    experiment.start(orc)

    yield orc
    # Teardown runs after the test completes, whether it passed or failed
    experiment.stop(orc)


@pytest.fixture(scope="function", autouse=True)
def environment_cleanup(monkeypatch: pytest.MonkeyPatch) -> None:
for key in os.environ.keys():
Expand Down Expand Up @@ -750,7 +655,7 @@ def setup_test_colo(
db_args: t.Dict[str, t.Any],
colo_settings: t.Optional[RunSettings] = None,
colo_model_name: str = "colocated_model",
port: int = test_port,
port: t.Optional[int] = None,
on_wlm: bool = False,
) -> Model:
"""Setup database needed for the colo pinning tests"""
Expand All @@ -766,10 +671,11 @@ def setup_test_colo(
if on_wlm:
colo_settings.set_tasks(1)
colo_settings.set_nodes(1)

colo_model = exp.create_model(colo_model_name, colo_settings)

if db_type in ["tcp", "deprecated"]:
db_args["port"] = port
db_args["port"] = port if port is not None else _find_free_port(test_ports)
db_args["ifname"] = "lo"
if db_type == "uds" and colo_model_name is not None:
tmp_dir = tempfile.gettempdir()
Expand Down Expand Up @@ -968,3 +874,151 @@ def num_calls(self) -> int:
@property
def details(self) -> t.List[t.Tuple[t.Tuple[t.Any, ...], t.Dict[str, t.Any]]]:
    """Recorded (args, kwargs) pairs, one per call made to this object."""
    return self._details

## Reuse database across tests

# Maps a fixture name to its currently-running Orchestrator; the default
# factory yields None for databases that have not been started yet
database_registry: t.DefaultDict[str, t.Optional[Orchestrator]] = defaultdict(lambda: None)

@pytest.fixture(scope="function")
def local_experiment(test_dir: str) -> smartsim.Experiment:
    """Create a default experiment that uses the local launcher."""
    exp_name = pathlib.Path(test_dir).stem
    return smartsim.Experiment(exp_name, exp_path=test_dir, launcher="local")

@pytest.fixture(scope="function")
def wlm_experiment(test_dir: str, wlmutils: WLMUtils) -> smartsim.Experiment:
    """Create a default experiment that uses the requested launcher."""
    exp_name = pathlib.Path(test_dir).stem
    launcher = wlmutils.get_test_launcher()
    return smartsim.Experiment(exp_name, exp_path=test_dir, launcher=launcher)

def _cleanup_db(name: str) -> None:
    """Stop the session-scoped database registered under ``name``, if any.

    Cleanup is best-effort: failures to reconnect to or stop the database
    are swallowed so that fixture teardown cannot fail the test session.

    :param name: registry key of the database fixture to clean up
    """
    db = database_registry[name]
    if db and db.is_active():
        exp = Experiment("cleanup")
        try:
            db = exp.reconnect_orchestrator(db.checkpoint_file)
            exp.stop(db)
        except Exception:
            # Narrowed from a bare except: still best-effort, but no longer
            # swallows SystemExit/KeyboardInterrupt
            pass

@dataclass
class DBConfiguration:
    """Launch parameters for a session-scoped test database."""
    name: str  # unique registry key identifying this database fixture
    launcher: str  # launcher used to start the database (e.g. "local")
    num_nodes: int  # number of database nodes to launch
    interface: t.Union[str,t.List[str]]  # network interface(s) to bind
    hostlist: t.Optional[t.List[str]]  # hosts to place the database on, if any
    port: int  # port the database listens on

@dataclass
class PrepareDatabaseOutput:
    """Result of ``prepare_db``: the database handle and whether it is new."""
    orchestrator: t.Optional[Orchestrator] # The actual orchestrator object
    new_db: bool # True if a new database was created when calling prepare_db

# Reuse databases
@pytest.fixture(scope="session")
def local_db() -> t.Generator[DBConfiguration, None, None]:
    """Session-scoped configuration for a reusable local database."""
    name = "local_db_fixture"
    # Session fixtures pick ports from the reversed range to avoid clashing
    # with function-scoped databases selecting from the front of the range
    yield DBConfiguration(
        name,
        "local",
        1,
        "lo",
        None,
        _find_free_port(tuple(reversed(test_ports))),
    )
    _cleanup_db(name)

@pytest.fixture(scope="session")
def single_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]:
    """Session-scoped configuration for a reusable single-node WLM database."""
    all_hosts = wlmutils.get_test_hostlist()
    # Reserve the last host for this fixture when a hostlist is available
    hosts = None if all_hosts is None else all_hosts[-1:]
    name = "single_db_fixture"
    yield DBConfiguration(
        name,
        wlmutils.get_test_launcher(),
        1,
        wlmutils.get_test_interface(),
        hosts,
        _find_free_port(tuple(reversed(test_ports)))
    )
    _cleanup_db(name)


@pytest.fixture(scope="session")
def clustered_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]:
    """Session-scoped configuration for a reusable 3-node clustered database."""
    all_hosts = wlmutils.get_test_hostlist()
    # Reserve three hosts just before the last one (used by single_db)
    hosts = None if all_hosts is None else all_hosts[-4:-1]
    name = "clustered_db_fixture"
    yield DBConfiguration(
        name,
        wlmutils.get_test_launcher(),
        3,
        wlmutils.get_test_interface(),
        hosts,
        _find_free_port(tuple(reversed(test_ports))),
    )
    _cleanup_db(name)


@pytest.fixture
def register_new_db() -> t.Callable[[DBConfiguration], Orchestrator]:
    """Provide a factory that launches and registers a session database."""

    def _register_new_db(config: DBConfiguration) -> Orchestrator:
        # Give each session database its own experiment directory
        exp_path = pathlib.Path(test_output_root, config.name)
        exp_path.mkdir(exist_ok=True)
        exp = Experiment(
            config.name,
            exp_path=str(exp_path),
            launcher=config.launcher,
        )
        orc = exp.create_database(
            port=config.port,
            batch=False,
            interface=config.interface,
            hosts=config.hostlist,
            db_nodes=config.num_nodes
        )
        exp.generate(orc, overwrite=True)
        exp.start(orc)
        # Record the running database so later tests can reuse it and the
        # session teardown can stop it
        database_registry[config.name] = orc
        return orc

    return _register_new_db


@pytest.fixture(scope="function")
def prepare_db(
    register_new_db: t.Callable[
        [DBConfiguration],
        Orchestrator
    ]
) -> t.Callable[
    [DBConfiguration],
    PrepareDatabaseOutput
]:
    """Provide a callable returning a running database, starting one if needed."""

    def _prepare_db(db_config: DBConfiguration) -> PrepareDatabaseOutput:
        db = database_registry[db_config.name]
        created = False
        # Launch a fresh database when none is registered or the
        # registered one is no longer reachable
        if db is None or not db.is_active():
            db = register_new_db(db_config)
            created = True
        return PrepareDatabaseOutput(db, created)

    return _prepare_db
Loading
Loading