From 89e3d7caea91faba95cc39910a8598d4e6dd6e46 Mon Sep 17 00:00:00 2001
From: Guido Petretto
Date: Wed, 15 Jan 2025 17:43:57 +0100
Subject: [PATCH] make JobDoc available at runtime

---
# NOTE(review): everything between the `---` separator and the diffstat is
# reviewer commentary, not part of the commit message.
#
# Summary of the change, as shown by the hunks below:
# * jobs/run.py: adds the `JfrState` class, decorated with monty's
#   `singleton`, holding a `job_doc` attribute and a `reset()` method that
#   sets it back to None, plus the module-level `CURRENT_JOBDOC` instance.
#   `run_remote_job` reads the optional "job_doc" entry of the loaded input
#   data; when it is truthy, its "job" entry is set to the loaded job and the
#   dict is passed to `JobDoc.model_validate`, whose result is stored on the
#   state. A `finally` clause is added that calls `JfrState().reset()`.
# * jobs/runner.py: `upload` passes `doc` as a third argument to
#   `get_remote_in_file`.
# * remote/data.py: `get_remote_in_file` accepts `job_doc=None`. When a
#   document is given, a copy without the "job" and "_id" keys is built, its
#   "lock_id" and "lock_time" are set to None, and it is serialized under the
#   "job_doc" key next to "job" and "store".
# * testing/__init__.py: adds the `current_jobdoc` job, which returns
#   `CURRENT_JOBDOC.job_doc`.
# * tests/db/jobs/test_run.py: new test that runs `current_jobdoc` on
#   "test_local_worker", compares the job output with the stored job document
#   (skipping a fixed list of keys) and checks that `JfrState()` and
#   `CURRENT_JOBDOC` share their state.
#
# NOTE(review): `job_doc: JobDoc = None` is annotated as non-optional although
# both its default and its reset value are None; `JobDoc | None` would
# describe it more accurately.
# NOTE(review): presumably monty's `singleton` replaces the class with a
# factory callable; if so, `JfrState` in `CURRENT_JOBDOC: JfrState` no longer
# names a type - TODO confirm.
# NOTE(review): line breaks, indentation and blank context lines of this copy
# were restored from the hunk line counts. In particular the two `lock_*`
# assignments are assumed to sit inside `if job_doc is not None:` and a blank
# line is assumed after the `serialized_input` assignment in runner.py -
# verify against the upstream commit before applying.
 src/jobflow_remote/jobs/run.py         | 23 +++++++++++++++++-
 src/jobflow_remote/jobs/runner.py      |  2 +-
 src/jobflow_remote/remote/data.py      | 13 +++++++++--
 src/jobflow_remote/testing/__init__.py |  7 ++++++
 tests/db/jobs/test_run.py              | 32 ++++++++++++++++++++++++++
 5 files changed, 73 insertions(+), 4 deletions(-)
 create mode 100644 tests/db/jobs/test_run.py

diff --git a/src/jobflow_remote/jobs/run.py b/src/jobflow_remote/jobs/run.py
index 1b8b1415..4bec5c82 100644
--- a/src/jobflow_remote/jobs/run.py
+++ b/src/jobflow_remote/jobs/run.py
@@ -12,12 +12,13 @@
 
 from jobflow import JobStore, initialize_logger
 from jobflow.core.flow import get_flow
+from monty.design_patterns import singleton
 from monty.os import cd
 from monty.serialization import dumpfn, loadfn
 from monty.shutil import decompress_file
 
 from jobflow_remote.jobs.batch import LocalBatchManager
-from jobflow_remote.jobs.data import IN_FILENAME, OUT_FILENAME
+from jobflow_remote.jobs.data import IN_FILENAME, OUT_FILENAME, JobDoc
 from jobflow_remote.remote.data import get_job_path, get_store_file_paths
 from jobflow_remote.utils.log import initialize_remote_run_log
 
@@ -29,6 +30,20 @@
 logger = logging.getLogger(__name__)
 
 
+@singleton
+class JfrState:
+    """State of the current job being executed."""
+
+    job_doc: JobDoc = None
+
+    def reset(self):
+        """Reset the current state."""
+        self.job_doc = None
+
+
+CURRENT_JOBDOC: JfrState = JfrState()
+
+
 def run_remote_job(run_dir: str | Path = ".") -> None:
     """Run the job."""
     initialize_remote_run_log()
@@ -42,6 +57,10 @@ def run_remote_job(run_dir: str | Path = ".") -> None:
 
             job: Job = in_data["job"]
             store = in_data["store"]
+            job_doc_dict = in_data.get("job_doc", None)
+            if job_doc_dict:
+                job_doc_dict["job"] = job
+                JfrState().job_doc = JobDoc.model_validate(job_doc_dict)
 
             store.connect()
 
@@ -91,6 +110,8 @@ def run_remote_job(run_dir: str | Path = ".") -> None:
                 "end_time": datetime.datetime.utcnow(),
             }
             dumpfn(output, OUT_FILENAME)
+        finally:
+            JfrState().reset()
 
 
 def run_batch_jobs(
diff --git a/src/jobflow_remote/jobs/runner.py b/src/jobflow_remote/jobs/runner.py
index 53c3464d..c6370424 100644
--- a/src/jobflow_remote/jobs/runner.py
+++ b/src/jobflow_remote/jobs/runner.py
@@ -648,7 +648,7 @@ def upload(self, lock: MongoLock) -> None:
             logger.error(err_msg)
             raise RemoteError(err_msg, no_retry=False)
 
-        serialized_input = get_remote_in_file(job_dict, remote_store)
+        serialized_input = get_remote_in_file(job_dict, remote_store, doc)
 
         path_file = Path(remote_path, IN_FILENAME)
         host.put(serialized_input, str(path_file))
diff --git a/src/jobflow_remote/remote/data.py b/src/jobflow_remote/remote/data.py
index 76d21d50..c0cb2375 100644
--- a/src/jobflow_remote/remote/data.py
+++ b/src/jobflow_remote/remote/data.py
@@ -56,9 +56,18 @@ def get_local_data_path(
     return get_job_path(job_id, index, local_base_dir)
 
 
-def get_remote_in_file(job, remote_store):
+def get_remote_in_file(job, remote_store, job_doc=None):
+    # remove the job from the job_doc, if present.
+    # Create the copy from scratch to avoid allocating the job multiple
+    # times if it is big
+    job_doc_copy = None
+    if job_doc is not None:
+        job_doc_copy = {k: v for k, v in job_doc.items() if k not in ("job", "_id")}
+        # the document is likely locked when getting here.
+        job_doc_copy["lock_id"] = None
+        job_doc_copy["lock_time"] = None
     d = jsanitize(
-        {"job": job, "store": remote_store},
+        {"job": job, "store": remote_store, "job_doc": job_doc_copy},
         strict=True,
         allow_bson=True,
         enum_values=True,
diff --git a/src/jobflow_remote/testing/__init__.py b/src/jobflow_remote/testing/__init__.py
index fa10eed2..a546ff57 100644
--- a/src/jobflow_remote/testing/__init__.py
+++ b/src/jobflow_remote/testing/__init__.py
@@ -96,3 +96,10 @@ def ignore_input(a: int) -> int:
     Allows to test flows with failed parents
     """
     return 1
+
+
+@job
+def current_jobdoc():
+    from jobflow_remote.jobs.run import CURRENT_JOBDOC
+
+    return CURRENT_JOBDOC.job_doc
diff --git a/tests/db/jobs/test_run.py b/tests/db/jobs/test_run.py
new file mode 100644
index 00000000..da1db14b
--- /dev/null
+++ b/tests/db/jobs/test_run.py
@@ -0,0 +1,32 @@
+def test_current_jobdoc(job_controller, runner):
+    from jobflow_remote import submit_flow
+    from jobflow_remote.jobs.run import CURRENT_JOBDOC, JfrState
+    from jobflow_remote.testing import current_jobdoc
+
+    j = current_jobdoc()
+    submit_flow([j], worker="test_local_worker")
+    runner.run_one_job()
+
+    job_output = job_controller.jobstore.get_output(uuid=j.uuid)
+    job_doc = job_controller.get_job_doc(job_id=j.uuid).as_db_dict()
+    for k in job_doc:
+        # some keys do not match
+        if k not in (
+            "state",
+            "end_time",
+            "start_time",
+            "updated_on",
+            "remote",
+            "run_dir",
+            "created_on",
+        ):
+            assert job_doc[k] == job_output[k]
+
+    # check that CURRENT_JOBDOC is a singleton and can be set
+    s = JfrState()
+    assert s.job_doc is None
+    assert CURRENT_JOBDOC.job_doc is None
+    s.job_doc = job_doc
+    assert CURRENT_JOBDOC.job_doc == job_doc
+    s.reset()
+    assert CURRENT_JOBDOC.job_doc is None