diff --git a/libs/libcommon/src/libcommon/queue.py b/libs/libcommon/src/libcommon/queue.py index 10c6abb3c6..884e9474df 100644 --- a/libs/libcommon/src/libcommon/queue.py +++ b/libs/libcommon/src/libcommon/queue.py @@ -37,7 +37,7 @@ def __get__(self, instance: object, cls: Type[U]) -> QuerySet[U]: # END monkey patching ### hack ### -class Status(enum.Enum): +class Status(str, enum.Enum): WAITING = "waiting" STARTED = "started" SUCCESS = "success" @@ -46,7 +46,7 @@ class Status(enum.Enum): SKIPPED = "skipped" -class Priority(enum.Enum): +class Priority(str, enum.Enum): NORMAL = "normal" LOW = "low" @@ -64,6 +64,7 @@ class JobDict(TypedDict): created_at: datetime started_at: Optional[datetime] finished_at: Optional[datetime] + last_heartbeat: Optional[datetime] class JobInfo(TypedDict): @@ -121,6 +122,7 @@ class Job(Document): created_at (`datetime`): The creation date of the job. started_at (`datetime`, optional): When the job has started. finished_at (`datetime`, optional): When the job has finished. + last_heartbeat (`datetime`, optional): Last time the running job got a heartbeat from the worker. """ meta = { @@ -148,6 +150,7 @@ class Job(Document): created_at = DateTimeField(required=True) started_at = DateTimeField() finished_at = DateTimeField() + last_heartbeat = DateTimeField() def to_dict(self) -> JobDict: return { @@ -163,6 +166,7 @@ def to_dict(self) -> JobDict: "created_at": self.created_at, "started_at": self.started_at, "finished_at": self.finished_at, + "last_heartbeat": self.last_heartbeat, } objects = QuerySetManager["Job"]() diff --git a/services/worker/dev.Dockerfile b/services/worker/dev.Dockerfile index 4fd090fd93..2468dd10b6 100644 --- a/services/worker/dev.Dockerfile +++ b/services/worker/dev.Dockerfile @@ -30,7 +30,9 @@ COPY services/worker/poetry.lock ./services/worker/poetry.lock COPY services/worker/pyproject.toml ./services/worker/pyproject.toml COPY libs/libcommon ./libs/libcommon WORKDIR /src/services/worker/ -RUN poetry install --no-cache +RUN --mount=type=cache,target=/home/.cache/pypoetry/cache \ + --mount=type=cache,target=/home/.cache/pypoetry/artifacts \ + poetry install --no-root # FOR LOCAL DEVELOPMENT ENVIRONMENT # No need to copy the source code since we map a volume in docker-compose-base.yaml @@ -38,4 +40,4 @@ RUN poetry install --no-cache # Removed: RUN poetry install --no-cache # However we need to install the package when the container starts # Added: poetry install -ENTRYPOINT ["/bin/sh", "-c" , "poetry install && poetry run python src/worker/main.py"] +ENTRYPOINT ["/bin/sh", "-c" , "poetry install --only-root && poetry run python src/worker/main.py"] diff --git a/services/worker/poetry.lock b/services/worker/poetry.lock index 58eb2dcf88..72d57d9f4a 100644 --- a/services/worker/poetry.lock +++ b/services/worker/poetry.lock @@ -2212,6 +2212,24 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mirakuru" +version = "2.4.2" +description = "Process executor (not only) for tests." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mirakuru-2.4.2-py3-none-any.whl", hash = "sha256:fdb67d141cc9f7abd485a515d618daf3272c3e6ff48380749997ff8e8c5f2cb2"}, + {file = "mirakuru-2.4.2.tar.gz", hash = "sha256:ec84d4d81b4bca96cb0e598c6b3d198a92f036a0c1223c881482c02a98508226"}, +] + +[package.dependencies] +psutil = {version = ">=4.0.0", markers = "sys_platform != \"cygwin\""} + +[package.extras] +tests = ["pytest", "pytest-cov", "python-daemon"] + [[package]] name = "mongo-types" version = "0.15.1" @@ -2857,6 +2875,13 @@ files = [ {file = "Pillow-9.4.0-1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:b8c2f6eb0df979ee99433d8b3f6d193d9590f735cf12274c108bd954e30ca858"}, {file = "Pillow-9.4.0-1-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:b70756ec9417c34e097f987b4d8c510975216ad26ba6e57ccb53bc758f490dab"}, {file = "Pillow-9.4.0-1-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:43521ce2c4b865d385e78579a082b6ad1166ebed2b1a2293c3be1d68dd7ca3b9"}, + {file = "Pillow-9.4.0-2-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:9d9a62576b68cd90f7075876f4e8444487db5eeea0e4df3ba298ee38a8d067b0"}, + {file = "Pillow-9.4.0-2-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:87708d78a14d56a990fbf4f9cb350b7d89ee8988705e58e39bdf4d82c149210f"}, + {file = "Pillow-9.4.0-2-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:8a2b5874d17e72dfb80d917213abd55d7e1ed2479f38f001f264f7ce7bae757c"}, + {file = "Pillow-9.4.0-2-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:83125753a60cfc8c412de5896d10a0a405e0bd88d0470ad82e0869ddf0cb3848"}, + {file = "Pillow-9.4.0-2-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:9e5f94742033898bfe84c93c831a6f552bb629448d4072dd312306bab3bd96f1"}, + {file = "Pillow-9.4.0-2-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:013016af6b3a12a2f40b704677f8b51f72cb007dac785a9933d5c86a72a7fe33"}, + {file = "Pillow-9.4.0-2-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:99d92d148dd03fd19d16175b6d355cc1b01faf80dae93c6c3eb4163709edc0a9"}, {file = "Pillow-9.4.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:2968c58feca624bb6c8502f9564dd187d0e1389964898f5e9e1fbc8533169157"}, {file = "Pillow-9.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c5c1362c14aee73f50143d74389b2c158707b4abce2cb055b7ad37ce60738d47"}, {file = "Pillow-9.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd752c5ff1b4a870b7661234694f24b1d2b9076b8bf337321a814c612665f343"}, @@ -4669,11 +4694,14 @@ files = [ {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"}, {file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"}, {file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"}, + {file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"}, + {file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"}, {file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"}, + {file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"}, {file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"}, @@ -4682,6 +4710,7 @@ files = [ {file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"}, {file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"}, {file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"}, + {file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"}, @@ -4748,6 +4777,7 @@ opt-einsum = ["opt-einsum (>=3.3)"] [package.source] type = "url" url = "https://download.pytorch.org/whl/cpu/torch-1.13.1%2Bcpu-cp39-cp39-linux_x86_64.whl" + [[package]] name = "torchaudio" version = "0.13.1+cpu" @@ -4765,6 +4795,7 @@ torch = "1.13.1" [package.source] type = "url" url = "https://download.pytorch.org/whl/cpu/torchaudio-0.13.1%2Bcpu-cp39-cp39-linux_x86_64.whl" + [[package]] name = "tqdm" version = "4.64.1" @@ -5459,4 +5490,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "3.9.15" -content-hash = "4b1fde55056862f335ccae9046704cced7bf1d9599e54a00e5b35c5dde787528" +content-hash = "c20f6820064117295d9db223142b9ccdafd534b12c3304e3fa51ccf3dbcd622f" diff --git a/services/worker/pyproject.toml b/services/worker/pyproject.toml index b885443c51..7679130223 100644 --- a/services/worker/pyproject.toml +++ b/services/worker/pyproject.toml @@ -40,6 +40,7 @@ transformers = "^4.26.1" trec-car-tools = { path = "vendors/trec-car-tools/python3" } typer = "^0.4.2" wget = "^3.2" +mirakuru = "^2.4.2" [tool.poetry.group.dev.dependencies] bandit = "^1.7.4" diff --git a/services/worker/src/worker/config.py b/services/worker/src/worker/config.py index b0ae7cda73..b7ae5cc141 100644 --- a/services/worker/src/worker/config.py +++ b/services/worker/src/worker/config.py @@ -19,6 +19,7 @@ WORKER_MAX_LOAD_PCT = 70 WORKER_MAX_MEMORY_PCT = 80 WORKER_SLEEP_SECONDS = 15 +WORKER_HEARTBEAT_TIME_INTERVAL_SECONDS = 60 def get_empty_str_list() -> List[str]: @@ -34,6 +35,8 @@ class WorkerConfig: only_job_types: list[str] = field(default_factory=get_empty_str_list) sleep_seconds: int = WORKER_SLEEP_SECONDS storage_paths: List[str] = field(default_factory=get_empty_str_list) + state_path: Optional[str] = None + heartbeat_time_interval_seconds: int = WORKER_HEARTBEAT_TIME_INTERVAL_SECONDS @classmethod def from_env(cls) -> "WorkerConfig": @@ -47,6 +50,10 @@ def from_env(cls) -> "WorkerConfig": sleep_seconds=env.int(name="SLEEP_SECONDS", default=WORKER_SLEEP_SECONDS), only_job_types=env.list(name="ONLY_JOB_TYPES", default=get_empty_str_list()), storage_paths=env.list(name="STORAGE_PATHS", default=get_empty_str_list()), + state_path=env.str(name="STATE_PATH", default=None), + heartbeat_time_interval_seconds=env.int( + name="HEARTBEAT_TIME_INTERVAL_SECONDS", default=WORKER_HEARTBEAT_TIME_INTERVAL_SECONDS + ), ) diff --git a/services/worker/src/worker/loop.py b/services/worker/src/worker/loop.py index e4cffbf644..24b013f188 100644 --- a/services/worker/src/worker/loop.py +++ b/services/worker/src/worker/loop.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright 2022 The HuggingFace Authors. +import json import logging import random import time from dataclasses import dataclass, field +from typing import Optional, TypedDict -from libcommon.queue import EmptyQueueError, Queue +from filelock import FileLock +from libcommon.queue import EmptyQueueError, JobInfo, Queue from psutil import cpu_count, disk_usage, getloadavg, swap_memory, virtual_memory from worker.config import WorkerConfig @@ -17,6 +20,10 @@ class UnknownJobTypeError(Exception): pass +class WorkerState(TypedDict): + current_job_info: Optional[JobInfo] + + @dataclass class Loop: """ @@ -96,7 +103,7 @@ def sleep(self) -> None: time.sleep(duration) def run(self) -> None: - logging.info("Worker started") + logging.info("Worker loop started") try: while True: if self.has_resources() and self.process_next_job(): @@ -119,13 +126,23 @@ def process_next_job(self) -> bool: f" ${', '.join(self.worker_config.only_job_types)}). The queue should not have provided this" " job. It is in an inconsistent state. Please report this issue to the datasets team." ) + self.set_worker_state(current_job_info=job_info) logging.debug(f"job assigned: {job_info}") except EmptyQueueError: + self.set_worker_state(current_job_info=None) logging.debug("no job in the queue") return False job_runner = self.job_runner_factory.create_job_runner(job_info) finished_status = job_runner.run() self.queue.finish_job(job_id=job_runner.job_id, finished_status=finished_status) + self.set_worker_state(current_job_info=None) logging.debug(f"job finished with {finished_status.value}: {job_runner}") return True + + def set_worker_state(self, current_job_info: Optional[JobInfo]) -> None: + worker_state: WorkerState = {"current_job_info": current_job_info} + if self.worker_config.state_path: + with FileLock(self.worker_config.state_path + ".lock"): + with open(self.worker_config.state_path, "w") as worker_state_f: + json.dump(worker_state, worker_state_f) diff --git a/services/worker/src/worker/main.py b/services/worker/src/worker/main.py index a7c1816d73..439cefbe65 100644 --- a/services/worker/src/worker/main.py +++ b/services/worker/src/worker/main.py @@ -1,53 +1,90 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright 2022 The HuggingFace Authors. +import json +import logging +import os +import sys +import tempfile +import time +from typing import Optional +from filelock import FileLock from libcommon.log import init_logging -from libcommon.processing_graph import ProcessingGraph -from libcommon.resources import CacheMongoResource, QueueMongoResource -from libcommon.storage import init_assets_dir +from libcommon.queue import Job, Status, get_datetime +from libcommon.resources import QueueMongoResource +from mirakuru import OutputExecutor +from worker import start_worker_loop from worker.config import AppConfig -from worker.job_runner_factory import JobRunnerFactory -from worker.loop import Loop -from worker.resources import LibrariesResource +from worker.loop import WorkerState + +WORKER_STATE_FILE_NAME = "worker_state.json" +START_WORKER_LOOP_PATH = start_worker_loop.__file__ + + +class WorkerExecutor: + def __init__(self, app_config: AppConfig) -> None: + self.app_config = app_config + + def _create_worker_loop_executor(self) -> OutputExecutor: + banner = self.app_config.worker.state_path + if not banner: + raise ValueError("Failed to create the executor because WORKER_STATE_PATH is missing.") + start_worker_loop_command = [ + sys.executable, + START_WORKER_LOOP_PATH, + "--print-worker-state-path", + ] + return OutputExecutor(start_worker_loop_command, banner, timeout=10) + + def start(self) -> None: + worker_loop_executor = self._create_worker_loop_executor() + worker_loop_executor.start() # blocking until the banner is printed + logging.info("Starting heartbeat.") + while worker_loop_executor.running(): + self.heartbeat() + time.sleep(self.app_config.worker.heartbeat_time_interval_seconds) + worker_loop_executor.stop() + + def get_state(self) -> WorkerState: + worker_state_path = self.app_config.worker.state_path + if not worker_state_path: + raise ValueError("Failed to get worker state because WORKER_STATE_PATH is missing.") + if os.path.exists(worker_state_path): + with FileLock(worker_state_path + ".lock"): + try: + with open(worker_state_path, "r") as worker_state_f: + worker_state = json.load(worker_state_f) + return WorkerState(current_job_info=worker_state.get("current_job_info")) + except json.JSONDecodeError: + return WorkerState(current_job_info=None) + else: + return WorkerState(current_job_info=None) + + def get_current_job(self) -> Optional[Job]: + worker_state = self.get_state() + if worker_state["current_job_info"]: + job = Job.objects.with_id(worker_state["current_job_info"]["job_id"]) # type: ignore + if job and isinstance(job, Job) and job.status == Status.STARTED: + return job + return None + + def heartbeat(self) -> None: + current_job = self.get_current_job() + if current_job: + current_job.update(last_heartbeat=get_datetime()) + if __name__ == "__main__": - app_config = AppConfig.from_env() - - init_logging(log_level=app_config.common.log_level) - # ^ set first to have logs as soon as possible - assets_directory = init_assets_dir(directory=app_config.assets.storage_directory) - - processing_graph = ProcessingGraph(app_config.processing_graph.specification) - - with ( - LibrariesResource( - hf_endpoint=app_config.common.hf_endpoint, - init_hf_datasets_cache=app_config.datasets_based.hf_datasets_cache, - numba_path=app_config.numba.path, - ) as libraries_resource, - CacheMongoResource( - database=app_config.cache.mongo_database, host=app_config.cache.mongo_url - ) as cache_resource, - QueueMongoResource( + with tempfile.TemporaryDirectory() as tmp_dir: + if "WORKER_STATE_PATH" not in os.environ: + os.environ["WORKER_STATE_PATH"] = os.path.join(tmp_dir, WORKER_STATE_FILE_NAME) + + app_config = AppConfig.from_env() + init_logging(log_level=app_config.common.log_level) + + with QueueMongoResource( database=app_config.queue.mongo_database, host=app_config.queue.mongo_url - ) as queue_resource, - ): - if not cache_resource.is_available(): - raise RuntimeError("The connection to the cache database could not be established. Exiting.") - if not queue_resource.is_available(): - raise RuntimeError("The connection to the queue database could not be established. Exiting.") - - job_runner_factory = JobRunnerFactory( - app_config=app_config, - processing_graph=processing_graph, - hf_datasets_cache=libraries_resource.hf_datasets_cache, - assets_directory=assets_directory, - ) - loop = Loop( - library_cache_paths=libraries_resource.storage_paths, - job_runner_factory=job_runner_factory, - max_jobs_per_namespace=app_config.queue.max_jobs_per_namespace, - worker_config=app_config.worker, - ) - loop.run() + ) as queue_resource: + if not queue_resource.is_available(): + raise RuntimeError("The connection to the queue database could not be established. Exiting.") + worker_executor = WorkerExecutor(app_config) + worker_executor.start() diff --git a/services/worker/src/worker/start_worker_loop.py b/services/worker/src/worker/start_worker_loop.py new file mode 100644 index 0000000000..311971ee7d --- /dev/null +++ b/services/worker/src/worker/start_worker_loop.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2022 The HuggingFace Authors. + +import sys + +from libcommon.log import init_logging +from libcommon.processing_graph import ProcessingGraph +from libcommon.resources import CacheMongoResource, QueueMongoResource +from libcommon.storage import init_assets_dir + +from worker.config import AppConfig +from worker.job_runner_factory import JobRunnerFactory +from worker.loop import Loop +from worker.resources import LibrariesResource + +if __name__ == "__main__": + app_config = AppConfig.from_env() + if "--print-worker-state-path" in sys.argv: + print(app_config.worker.state_path, flush=True) + + init_logging(log_level=app_config.common.log_level) + # ^ set first to have logs as soon as possible + assets_directory = init_assets_dir(directory=app_config.assets.storage_directory) + + processing_graph = ProcessingGraph(app_config.processing_graph.specification) + + with ( + LibrariesResource( + hf_endpoint=app_config.common.hf_endpoint, + init_hf_datasets_cache=app_config.datasets_based.hf_datasets_cache, + numba_path=app_config.numba.path, + ) as libraries_resource, + CacheMongoResource( + database=app_config.cache.mongo_database, host=app_config.cache.mongo_url + ) as cache_resource, + QueueMongoResource( + database=app_config.queue.mongo_database, host=app_config.queue.mongo_url + ) as queue_resource, + ): + if not cache_resource.is_available(): + raise RuntimeError("The connection to the cache database could not be established. Exiting.") + if not queue_resource.is_available(): + raise RuntimeError("The connection to the queue database could not be established. Exiting.") + + job_runner_factory = JobRunnerFactory( + app_config=app_config, + processing_graph=processing_graph, + hf_datasets_cache=libraries_resource.hf_datasets_cache, + assets_directory=assets_directory, + ) + loop = Loop( + library_cache_paths=libraries_resource.storage_paths, + job_runner_factory=job_runner_factory, + max_jobs_per_namespace=app_config.queue.max_jobs_per_namespace, + worker_config=app_config.worker, + ) + loop.run() diff --git a/services/worker/tests/conftest.py b/services/worker/tests/conftest.py index 279ae7d6f4..eed12a2723 100644 --- a/services/worker/tests/conftest.py +++ b/services/worker/tests/conftest.py @@ -12,6 +12,7 @@ from pytest import MonkeyPatch, fixture from worker.config import AppConfig, FirstRowsConfig +from worker.main import WORKER_STATE_FILE_NAME from worker.resources import LibrariesResource from .constants import CI_APP_TOKEN, CI_HUB_ENDPOINT, CI_URL_TEMPLATE, CI_USER_TOKEN @@ -27,6 +28,11 @@ def modules_cache_directory(tmp_path: Path) -> Path: return tmp_path / "modules" +@fixture +def worker_state_path(tmp_path: Path) -> Path: + return tmp_path / WORKER_STATE_FILE_NAME + + # see https://github.com/pytest-dev/pytest/issues/363#issuecomment-406536200 @fixture(scope="session", autouse=True) def monkeypatch_session() -> Iterator[MonkeyPatch]: @@ -41,7 +47,9 @@ def monkeypatch_session() -> Iterator[MonkeyPatch]: # see https://github.com/pytest-dev/pytest/issues/363#issuecomment-406536200 @fixture -def set_env_vars(datasets_cache_directory: Path, modules_cache_directory: Path) -> Iterator[MonkeyPatch]: +def set_env_vars( + datasets_cache_directory: Path, modules_cache_directory: Path, worker_state_path: Path +) -> Iterator[MonkeyPatch]: mp = MonkeyPatch() mp.setenv("CACHE_MONGO_DATABASE", "datasets_server_cache_test") mp.setenv("QUEUE_MONGO_DATABASE", "datasets_server_queue_test") @@ -55,6 +63,8 @@ def set_env_vars(datasets_cache_directory: Path, modules_cache_directory: Path) mp.setenv("DATASETS_BASED_HF_DATASETS_CACHE", str(datasets_cache_directory)) mp.setenv("HF_MODULES_CACHE", str(modules_cache_directory)) mp.setenv("WORKER_CONTENT_MAX_BYTES", "10_000_000") + mp.setenv("WORKER_STATE_PATH", str(worker_state_path)) + mp.setenv("WORKER_HEARTBEAT_TIME_INTERVAL_SECONDS", "1") yield mp mp.undo() diff --git a/services/worker/tests/test_executor.py b/services/worker/tests/test_executor.py new file mode 100644 index 0000000000..46f9ad5d28 --- /dev/null +++ b/services/worker/tests/test_executor.py @@ -0,0 +1,180 @@ +import json +import os +import sys +import time +from datetime import timedelta +from pathlib import Path +from typing import Iterator +from unittest.mock import patch + +import pytest +import pytz +from filelock import FileLock +from libcommon.queue import Job, JobInfo, Priority, Status, get_datetime +from libcommon.resources import QueueMongoResource +from mirakuru import ProcessExitedWithError, TimeoutExpired +from pytest import fixture + +from worker.config import AppConfig +from worker.loop import WorkerState +from worker.main import WorkerExecutor + + +def get_job_info() -> JobInfo: + return JobInfo( + job_id="a" * 24, + type="bar", + dataset="user/my_dataset", + config="default", + split="train", + force=False, + priority=Priority.LOW, + ) + + +def write_worker_state(worker_state: WorkerState, worker_state_path: str) -> None: + with FileLock(worker_state_path + ".lock"): + with open(worker_state_path, "w") as worker_state_f: + json.dump(worker_state, worker_state_f) + + +def start_worker_loop() -> None: + app_config = AppConfig.from_env() + if not app_config.worker.state_path: + raise ValueError("Failed to get worker state because WORKER_STATE_PATH is missing.") + if "--print-worker-state-path" in sys.argv: + print(app_config.worker.state_path, flush=True) + current_job_info = get_job_info() + worker_state = WorkerState(current_job_info=current_job_info) + write_worker_state(worker_state, app_config.worker.state_path) + + +def start_worker_loop_that_crashes() -> None: + app_config = AppConfig.from_env() + if not app_config.worker.state_path: + raise ValueError("Failed to get worker state because WORKER_STATE_PATH is missing.") + if "--print-worker-state-path" in sys.argv: + print(app_config.worker.state_path, flush=True) + raise RuntimeError("Tried to run a bad worker loop") + + +def start_worker_loop_that_times_out() -> None: + time.sleep(20) + + +@fixture +def set_worker_state(worker_state_path: Path) -> Iterator[WorkerState]: + job_info = get_job_info() + worker_state = WorkerState(current_job_info=job_info) + write_worker_state(worker_state, str(worker_state_path)) + yield worker_state + os.remove(worker_state_path) + + +@fixture +def set_started_job_in_queue(queue_mongo_resource: QueueMongoResource) -> Iterator[Job]: + if not queue_mongo_resource.is_available(): + raise RuntimeError("Mongo resource is not available") + job_info = get_job_info() + if Job.objects.with_id(job_info["job_id"]): # type: ignore + Job.objects.with_id(job_info["job_id"]).delete() # type: ignore + job = Job( + pk=job_info["job_id"], + type=job_info["type"], + dataset=job_info["dataset"], + config=job_info["config"], + split=job_info["split"], + unicity_id="unicity_id", + namespace="user", + priority=job_info["priority"], + status=Status.STARTED, + created_at=get_datetime(), + ) + job.save() + yield job + job.delete() + + +def test_executor_get_state(app_config: AppConfig, set_worker_state: WorkerState) -> None: + executor = WorkerExecutor(app_config) + assert executor.get_state() == set_worker_state + + +def test_executor_get_empty_state(app_config: AppConfig) -> None: + executor = WorkerExecutor(app_config) + assert executor.get_state() == WorkerState(current_job_info=None) + + +def test_executor_get_current_job( + app_config: AppConfig, set_started_job_in_queue: Job, set_worker_state: WorkerState +) -> None: + executor = WorkerExecutor(app_config) + assert executor.get_current_job() == set_started_job_in_queue + + +def test_executor_get_nonexisting_current_job(app_config: AppConfig) -> None: + executor = WorkerExecutor(app_config) + assert executor.get_current_job() is None + + +def test_executor_heartbeat( + app_config: AppConfig, + set_started_job_in_queue: Job, + set_worker_state: WorkerState, + queue_mongo_resource: QueueMongoResource, +) -> None: + if not queue_mongo_resource.is_available(): + raise RuntimeError("Mongo resource is not available") + executor = WorkerExecutor(app_config) + current_job = executor.get_current_job() + assert current_job is not None + assert current_job.last_heartbeat is None + executor.heartbeat() + current_job = executor.get_current_job() + assert current_job is not None + assert current_job.last_heartbeat is not None + last_heartbeat_datetime = pytz.UTC.localize(current_job.last_heartbeat) + assert last_heartbeat_datetime >= get_datetime() - timedelta(seconds=1) + + +def test_executor_start( + app_config: AppConfig, queue_mongo_resource: QueueMongoResource, set_started_job_in_queue: Job +) -> None: + if not queue_mongo_resource.is_available(): + raise RuntimeError("Mongo resource is not available") + executor = WorkerExecutor(app_config) + with patch.object(executor, "heartbeat", wraps=executor.heartbeat) as heartbeat_mock: + with patch("worker.main.START_WORKER_LOOP_PATH", __file__): + executor.start() + current_job = executor.get_current_job() + assert current_job is not None + assert str(current_job.pk) == get_job_info()["job_id"] + assert heartbeat_mock.call_count > 0 + + +@pytest.mark.parametrize( + "bad_worker_loop_type", ["start_worker_loop_that_crashes", "start_worker_loop_that_times_out"] +) +def test_executor_raises_on_bad_worker( + app_config: AppConfig, queue_mongo_resource: QueueMongoResource, tmp_path: Path, bad_worker_loop_type: str +) -> None: + if not queue_mongo_resource.is_available(): + raise RuntimeError("Mongo resource is not available") + bad_start_worker_loop_path = tmp_path / "bad_start_worker_loop.py" + with bad_start_worker_loop_path.open("w") as bad_start_worker_loop_f: + bad_start_worker_loop_f.write("raise RuntimeError('Tried to start a bad worker loop.')") + executor = WorkerExecutor(app_config) + with patch.dict(os.environ, {"WORKER_LOOP_TYPE": bad_worker_loop_type}): + with patch("worker.main.START_WORKER_LOOP_PATH", __file__): + with pytest.raises((ProcessExitedWithError, TimeoutExpired)): + executor.start() + + +if __name__ == "__main__": + worker_loop_type = os.environ.get("WORKER_LOOP_TYPE", "start_worker_loop") + if worker_loop_type == "start_worker_loop_that_crashes": + start_worker_loop_that_crashes() + elif worker_loop_type == "start_worker_loop_that_times_out": + start_worker_loop_that_times_out() + else: + start_worker_loop()