Skip to content

Commit

Permalink
make JobDoc available at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetretto committed Jan 17, 2025
1 parent 8a78800 commit 89e3d7c
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 4 deletions.
23 changes: 22 additions & 1 deletion src/jobflow_remote/jobs/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@

from jobflow import JobStore, initialize_logger
from jobflow.core.flow import get_flow
from monty.design_patterns import singleton
from monty.os import cd
from monty.serialization import dumpfn, loadfn
from monty.shutil import decompress_file

from jobflow_remote.jobs.batch import LocalBatchManager
from jobflow_remote.jobs.data import IN_FILENAME, OUT_FILENAME
from jobflow_remote.jobs.data import IN_FILENAME, OUT_FILENAME, JobDoc
from jobflow_remote.remote.data import get_job_path, get_store_file_paths
from jobflow_remote.utils.log import initialize_remote_run_log

Expand All @@ -29,6 +30,20 @@
logger = logging.getLogger(__name__)


@singleton
class JfrState:
"""State of the current job being executed."""

job_doc: JobDoc = None

def reset(self):
"""Reset the current state."""
self.job_doc = None


CURRENT_JOBDOC: JfrState = JfrState()


def run_remote_job(run_dir: str | Path = ".") -> None:
"""Run the job."""
initialize_remote_run_log()
Expand All @@ -42,6 +57,10 @@ def run_remote_job(run_dir: str | Path = ".") -> None:

job: Job = in_data["job"]
store = in_data["store"]
job_doc_dict = in_data.get("job_doc", None)
if job_doc_dict:
job_doc_dict["job"] = job
JfrState().job_doc = JobDoc.model_validate(job_doc_dict)

store.connect()

Expand Down Expand Up @@ -91,6 +110,8 @@ def run_remote_job(run_dir: str | Path = ".") -> None:
"end_time": datetime.datetime.utcnow(),
}
dumpfn(output, OUT_FILENAME)
finally:
JfrState().reset()


def run_batch_jobs(
Expand Down
2 changes: 1 addition & 1 deletion src/jobflow_remote/jobs/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def upload(self, lock: MongoLock) -> None:
logger.error(err_msg)
raise RemoteError(err_msg, no_retry=False)

serialized_input = get_remote_in_file(job_dict, remote_store)
serialized_input = get_remote_in_file(job_dict, remote_store, doc)

path_file = Path(remote_path, IN_FILENAME)
host.put(serialized_input, str(path_file))
Expand Down
13 changes: 11 additions & 2 deletions src/jobflow_remote/remote/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,18 @@ def get_local_data_path(
return get_job_path(job_id, index, local_base_dir)


def get_remote_in_file(job, remote_store):
def get_remote_in_file(job, remote_store, job_doc=None):
# remove the job from the job_doc, if present.
# Create the copy from scratch to avoid allocating the job multiple
# times if it is big
job_doc_copy = None
if job_doc is not None:
job_doc_copy = {k: v for k, v in job_doc.items() if k not in ("job", "_id")}
# the document is likely locked when getting here.
job_doc_copy["lock_id"] = None
job_doc_copy["lock_time"] = None
d = jsanitize(
{"job": job, "store": remote_store},
{"job": job, "store": remote_store, "job_doc": job_doc_copy},
strict=True,
allow_bson=True,
enum_values=True,
Expand Down
7 changes: 7 additions & 0 deletions src/jobflow_remote/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,10 @@ def ignore_input(a: int) -> int:
Allows to test flows with failed parents
"""
return 1


@job
def current_jobdoc():
from jobflow_remote.jobs.run import CURRENT_JOBDOC

return CURRENT_JOBDOC.job_doc
32 changes: 32 additions & 0 deletions tests/db/jobs/test_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
def test_current_jobdoc(job_controller, runner):
from jobflow_remote import submit_flow
from jobflow_remote.jobs.run import CURRENT_JOBDOC, JfrState
from jobflow_remote.testing import current_jobdoc

j = current_jobdoc()
submit_flow([j], worker="test_local_worker")
runner.run_one_job()

job_output = job_controller.jobstore.get_output(uuid=j.uuid)
job_doc = job_controller.get_job_doc(job_id=j.uuid).as_db_dict()
for k in job_doc:
# some keys do not match
if k not in (
"state",
"end_time",
"start_time",
"updated_on",
"remote",
"run_dir",
"created_on",
):
assert job_doc[k] == job_output[k]

# check that CURRENT_JOBDOC is a singleton and can be set
s = JfrState()
assert s.job_doc is None
assert CURRENT_JOBDOC.job_doc is None
s.job_doc = job_doc
assert CURRENT_JOBDOC.job_doc == job_doc
s.reset()
assert CURRENT_JOBDOC.job_doc is None

0 comments on commit 89e3d7c

Please sign in to comment.