Skip to content

Commit

Permalink
make the environment check a warning
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetretto committed Oct 24, 2024
1 parent 06c90f8 commit 35f584e
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 24 deletions.
11 changes: 8 additions & 3 deletions src/jobflow_remote/cli/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def check(
bool,
typer.Option(
"--full",
"-f",
help="Perform a full check",
),
] = False,
Expand All @@ -190,6 +191,7 @@ def check(
workers_to_test = [worker]

tick = "[bold green]✓[/] "
tick_warn = "[bold yellow]✓[/] "
cross = "[bold red]x[/] "
errors = []
with loading_spinner(processing=False) as progress:
Expand All @@ -199,13 +201,16 @@ def check(
worker_to_test = project.workers[worker_name]
if worker_to_test.get_host().interactive_login:
with hide_progress(progress):
err = check_worker(worker_to_test, full_check=full)
err, worker_warn = check_worker(worker_to_test, full_check=full)
else:
err = check_worker(worker_to_test, full_check=full)
err, worker_warn = check_worker(worker_to_test, full_check=full)
header = tick
if err:
errors.append((f"Worker {worker_name}", err))
errors.append((f"Worker {worker_name} ", err))
header = cross
elif worker_warn:
errors.append((f"Worker {worker_name} warning ", worker_warn))
header = tick_warn
progress.print(Text.from_markup(header + f"Worker {worker_name}"))

if check_all or jobstore:
Expand Down
54 changes: 34 additions & 20 deletions src/jobflow_remote/config/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,33 +152,40 @@ def _check_workdir(worker: WorkerBase, host: BaseHost) -> str | None:
host.execute(f"rm {str(canary_file)!r}")


def check_worker(
    worker: WorkerBase, full_check: bool = False
) -> tuple[str | None, str | None]:
    """Check that a connection to the configured worker can be made.

    The host is connected and tested, the queue manager is asked for the
    list of jobs, the work directory is checked and, for non-local workers,
    the python environment on the host is compared with the local one.
    The connection to the host is always closed before returning.

    Parameters
    ----------
    worker
        The worker configuration.
    full_check
        Forwarded to the environment check: whether to check the entire
        environment and not just jobflow and jobflow-remote.

    Returns
    -------
    tuple[str | None, str | None]
        ``(error, warning)``. ``error`` is a message describing why the worker
        cannot be used, or None if every check went through. ``warning`` is
        the message produced by the environment check, or None if there is
        nothing to report.
    """
    host = worker.get_host()
    worker_warn = None
    try:
        host.connect()
        host_error = host.test()
        if host_error:
            return host_error, None

        # Imported lazily, as in the original code (presumably to avoid a
        # circular import at module load time).
        from jobflow_remote.remote.queue import QueueManager

        qm = QueueManager(scheduler_io=worker.get_scheduler_io(), host=host)
        qm.get_jobs_list()

        _check_workdir(worker=worker, host=host)
        # Skip the environment comparison for a local worker: the two
        # environments being compared would be equivalent.
        if worker.type != "local":
            worker_warn = _check_environment(
                worker=worker, host=host, full_check=full_check
            )

    except Exception:
        # Any failure in the checks above is reported as an error message
        # rather than propagated, so that the caller can keep testing the
        # other workers.
        exc = traceback.format_exc()
        return f"Error while testing worker:\n {exc}", worker_warn
    finally:
        # Best effort: a failure to close the connection must not hide the
        # outcome of the checks.
        try:
            host.close()
        except Exception:
            logger.warning(f"error while closing connection to host {host}")

    return None, worker_warn


def _check_store(store: Store) -> str | None:
Expand Down Expand Up @@ -218,26 +225,25 @@ def _check_environment(
Parameters
----------
worker: The worker configuration.
host: A connected host.
full_check: Whether to check the entire environment and not just jobflow and jobflow-remote.
Returns
-------
str | None
A message describing the environment mismatches. None if no mismatch is found.
"""
# TODO: not sure about this test here but I based this check function on the _check_worker function
# which does the same.
try:
host_error = host.test()
if host_error:
return host_error
except Exception:
exc = traceback.format_exc()
return f"Error while testing worker:\n {exc}"

installed_packages = importlib.metadata.distributions()
local_package_versions = {
package.metadata["Name"]: package.version for package in installed_packages
}
stdout, stderr, errcode = host.execute("pip list --format=json")
cmd = "pip list --format=json"
if worker.pre_run:
cmd = "; ".join(worker.pre_run.strip().splitlines()) + "; " + cmd

stdout, stderr, errcode = host.execute(cmd)
if errcode != 0:
return f"Error while checking the compatibility of the environments: {stderr}"
host_package_versions = {
package_dict["name"]: package_dict["version"]
for package_dict in json.loads(stdout)
Expand All @@ -260,6 +266,14 @@ def _check_environment(
host_package_versions[package],
)
)
if missing or mismatch:
raise ValueError("Version mismatch")
return None
msg = None
if mismatch or missing:
msg = "Note: inconsistencies may be due to the proper python environment not being correctly loaded.\n"
if missing:
missing_str = [f"{m[0]} - {m[1]}" for m in missing]
msg += f"Missing packages: {', '.join(missing_str)}. "
if mismatch:
mismatch_str = [f"{m[0]} - {m[1]} vs {m[2]}" for m in mismatch]
msg += f"Mismatching versions: {', '.join(mismatch_str)}"

return msg
73 changes: 73 additions & 0 deletions tests/db/cli/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,79 @@ def test_check(job_controller) -> None:
run_check_cli(["project", "check"], required_out=output)


def test_check_fail(job_controller, monkeypatch, tmp_dir) -> None:
    """Exercise ``project check`` on a fake remote worker.

    First the check is run against a host that cannot be reached and is
    expected to report an error. Then the host and queue interactions are
    mocked so that only the package versions differ, and the check is
    expected to report the mismatch next to a tick for the worker.
    """
    import json
    import os

    from maggma.stores.mongolike import MongoStore
    from monty.serialization import dumpfn

    from jobflow_remote import SETTINGS
    from jobflow_remote.config import helper
    from jobflow_remote.remote.host.remote import RemoteHost
    from jobflow_remote.remote.queue import QueueManager
    from jobflow_remote.testing.cli import run_check_cli

    def _noop(*args, **kwargs):
        # Stand-in for host/queue calls that should simply succeed.
        return None

    def _fake_pip_list(*args, **kwargs):
        # Mimic (stdout, stderr, errcode) of "pip list --format=json" on the host.
        packages = [
            {"name": "jobflow", "version": "0.1.0"},
            {"name": "jobflow-remote", "version": "0.1.0"},
        ]
        return json.dumps(packages), "", 0

    check_cmd = ["project", "check", "-e", "-w", "fake_remote_worker"]

    # change project directory and test options there
    with monkeypatch.context() as mp:
        mp.setattr(SETTINGS, "projects_folder", os.getcwd())
        mp.setattr(SETTINGS, "project", "testtest")
        run_check_cli(["project", "list"], required_out="No project available in")

        # create a project with a fake remote worker
        fake_worker = {
            "scheduler_type": "shell",
            "work_dir": "/fake/path",
            "type": "remote",
            "host": "fake_host",
            "timeout_execute": 1,
        }
        project_dict = {
            "name": "testtest",
            "workers": {"fake_remote_worker": fake_worker},
            "queue": {"store": MongoStore("xxx", "yyy").as_dict()},
        }
        dumpfn(project_dict, "testtest.yaml")

        # project check fails as it cannot connect
        run_check_cli(
            list(check_cmd),
            required_out=["Errors:", "x Worker fake_remote_worker"],
        )

        # mock all the functions to make the check succeed, except the
        # mismatching jobflow versions
        mp.setattr(RemoteHost, "connect", _noop)
        mp.setattr(RemoteHost, "test", _noop)
        mp.setattr(RemoteHost, "write_text_file", _noop)
        mp.setattr(RemoteHost, "execute", _fake_pip_list)
        mp.setattr(QueueManager, "get_jobs_list", _noop)
        mp.setattr(helper, "_check_workdir", _noop)
        run_check_cli(
            list(check_cmd),
            required_out=[
                "✓ Worker fake_remote_worker",
                "Errors:",
                "Mismatching versions: jobflow",
                "jobflow-remote",
            ],
        )


def test_remove(job_controller, random_project_name, monkeypatch, tmp_dir) -> None:
import os

Expand Down
5 changes: 4 additions & 1 deletion tests/integration/test_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ def test_project_check(job_controller, capsys) -> None:
"✓ Jobstore",
"✓ Queue store",
]
run_check_cli(["project", "check", "-e"], required_out=expected)
excluded = ["Errors:"]
run_check_cli(
["project", "check", "-e"], required_out=expected, excluded_out=excluded
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 35f584e

Please sign in to comment.