Skip to content

Commit

Permalink
Add ComputeClient methods to submit & get tasks
Browse files Browse the repository at this point in the history
- Added the `.get_task()`, `.get_task_batch()`, and `.get_task_group()`
methods to the `ComputeClientV2` class.
- Added the `.submit()` method to the `ComputeClientV3` class.
  • Loading branch information
rjmello committed Nov 7, 2024
1 parent 0b981a8 commit 1767b52
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 1 deletion.
22 changes: 22 additions & 0 deletions src/globus_sdk/_testing/data/compute/_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,27 @@
import uuid

ENDPOINT_ID = str(uuid.uuid1())

FUNCTION_ID = str(uuid.uuid1())
FUNCTION_NAME = "howdy_world"
FUNCTION_CODE = "410\n10\n04\n:gASVQAAAAAAAAACMC2hvd2R5X3dvc ..."

TASK_GROUP_ID = str(uuid.uuid1())
TASK_ID = str(uuid.uuid1())
TASK_ID_2 = str(uuid.uuid1())
TASK_DOC = {
"task_id": TASK_ID,
"status": "success",
"result": "10000",
"completion_t": "1677183605.212898",
"details": {
"os": "Linux-5.19.0-1025-aws-x86_64-with-glibc2.35",
"python_version": "3.10.4",
"dill_version": "0.3.5.1",
"globus_compute_sdk_version": "2.3.2",
"task_transitions": {
"execution-start": 1692742841.843334,
"execution-end": 1692742846.123456,
},
},
}
13 changes: 13 additions & 0 deletions src/globus_sdk/_testing/data/compute/v2/get_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from globus_sdk._testing.models import RegisteredResponse, ResponseSet

from .._common import TASK_DOC, TASK_ID

RESPONSES = ResponseSet(
metadata={"task_id": TASK_ID},
default=RegisteredResponse(
service="compute",
path=f"/v2/tasks/{TASK_ID}",
method="GET",
json=TASK_DOC,
),
)
22 changes: 22 additions & 0 deletions src/globus_sdk/_testing/data/compute/v2/get_task_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from responses.matchers import json_params_matcher

from globus_sdk._testing.models import RegisteredResponse, ResponseSet

from .._common import TASK_DOC, TASK_ID

TASK_BATCH_DOC = {
"response": "batch",
"results": {TASK_ID: TASK_DOC},
}

RESPONSES = ResponseSet(
metadata={"task_id": TASK_ID},
default=RegisteredResponse(
service="compute",
path="/v2/tasks/batch",
method="POST",
json=TASK_BATCH_DOC,
# Ensure task_ids is a list
match=[json_params_matcher({"task_ids": [TASK_ID]})],
),
)
26 changes: 26 additions & 0 deletions src/globus_sdk/_testing/data/compute/v2/get_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from globus_sdk._testing.models import RegisteredResponse, ResponseSet

from .._common import TASK_GROUP_ID, TASK_ID, TASK_ID_2

TASK_BATCH_DOC = {
"taskgroup_id": TASK_GROUP_ID,
"create_websockets_queue": True,
"tasks": [
{"id": TASK_ID, "created_at": "2021-05-05T15:00:00.000000"},
{"id": TASK_ID_2, "created_at": "2021-05-05T15:01:00.000000"},
],
}

RESPONSES = ResponseSet(
metadata={
"task_group_id": TASK_GROUP_ID,
"task_id": TASK_ID,
"task_id_2": TASK_ID_2,
},
default=RegisteredResponse(
service="compute",
path=f"/v2/taskgroup/{TASK_GROUP_ID}",
method="GET",
json=TASK_BATCH_DOC,
),
)
Empty file.
31 changes: 31 additions & 0 deletions src/globus_sdk/_testing/data/compute/v3/submit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from globus_sdk._testing.models import RegisteredResponse, ResponseSet

from .._common import ENDPOINT_ID, FUNCTION_ID, TASK_GROUP_ID, TASK_ID, TASK_ID_2

REQUEST_ID = "5158de19-10b5-4deb-9d87-a86c1dec3460"

SUBMIT_RESPONSE = {
"request_id": REQUEST_ID,
"task_group_id": TASK_GROUP_ID,
"endpoint_id": ENDPOINT_ID,
"tasks": {
FUNCTION_ID: [TASK_ID, TASK_ID_2],
},
}

RESPONSES = ResponseSet(
metadata={
"endpoint_id": ENDPOINT_ID,
"function_id": FUNCTION_ID,
"task_id": TASK_ID,
"task_id_2": TASK_ID_2,
"task_group_id": TASK_GROUP_ID,
"request_id": REQUEST_ID,
},
default=RegisteredResponse(
service="compute",
path=f"/v3/endpoints/{ENDPOINT_ID}/submit",
method="POST",
json=SUBMIT_RESPONSE,
),
)
68 changes: 67 additions & 1 deletion src/globus_sdk/services/compute/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import typing as t

from globus_sdk import GlobusHTTPResponse, client
from globus_sdk import GlobusHTTPResponse, client, utils
from globus_sdk._types import UUIDLike
from globus_sdk.scopes import ComputeScopes, Scope

Expand Down Expand Up @@ -72,6 +72,54 @@ def delete_function(self, function_id: UUIDLike) -> GlobusHTTPResponse:
""" # noqa: E501
return self.delete(f"/v2/functions/{function_id}")

def get_task(self, task_id: UUIDLike) -> GlobusHTTPResponse:
"""Get information about a task.
:param task_id: The ID of the task.
.. tab-set::
.. tab-item:: API Info
.. extdoclink:: Get Task
:service: compute
:ref: Tasks/operation/get_task_status_and_result_v2_tasks__task_uuid__get
""" # noqa: E501
return self.get(f"/v2/tasks/{task_id}")

def get_task_batch(
self, task_ids: UUIDLike | t.Iterable[UUIDLike]
) -> GlobusHTTPResponse:
"""Get information about a batch of tasks.
:param task_ids: The IDs of the tasks.
.. tab-set::
.. tab-item:: API Info
.. extdoclink:: Get Task Batch
:service: compute
:ref: operation/get_batch_status_v2_batch_status_post
"""
task_ids = list(utils.safe_strseq_iter(task_ids))
return self.post("/v2/tasks/batch", data={"task_ids": task_ids})

def get_task_group(self, task_group_id: UUIDLike) -> GlobusHTTPResponse:
"""Get a list of task IDs associated with a task group.
:param task_group_id: The ID of the task group.
.. tab-set::
.. tab-item:: API Info
.. extdoclink:: Get Task Group Tasks
:service: compute
:ref: TaskGroup/operation/get_task_group_tasks_v2_taskgroup__task_group_uuid__get
""" # noqa: E501
return self.get(f"/v2/taskgroup/{task_group_id}")


class ComputeClientV3(client.BaseClient):
r"""
Expand All @@ -85,6 +133,24 @@ class ComputeClientV3(client.BaseClient):
scopes = ComputeScopes
default_scope_requirements = [Scope(ComputeScopes.all)]

def submit(
self, endpoint_id: UUIDLike, data: dict[str, t.Any]
) -> GlobusHTTPResponse:
"""Submit a batch of tasks to a Globus Compute endpoint.
:param endpoint_id: The ID of the Globus Compute endpoint.
:param data: The task batch document.
.. tab-set::
.. tab-item:: API Info
.. extdoclink:: Submit Batch
:service: compute
:ref: Endpoints/operation/submit_batch_v3_endpoints__endpoint_uuid__submit_post
""" # noqa: E501
return self.post(f"/v3/endpoints/{endpoint_id}/submit", data=data)


class ComputeClient(ComputeClientV2):
r"""
Expand Down
9 changes: 9 additions & 0 deletions tests/functional/services/compute/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,12 @@ class CustomComputeClientV2(globus_sdk.ComputeClientV2):
transport_class = no_retry_transport

return CustomComputeClientV2()


@pytest.fixture
def compute_client_v3(no_retry_transport):
class CustomComputeClientV3(globus_sdk.ComputeClientV3):
transport_class = no_retry_transport

return CustomComputeClientV3()

33 changes: 33 additions & 0 deletions tests/functional/services/compute/v2/test_get_task_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import uuid

import pytest

import globus_sdk
from globus_sdk._testing import load_response


@pytest.mark.parametrize(
"task_id_style", ("string", "list", "set", "uuid", "uuid_list")
)
def test_get_task_batch(
compute_client_v2: globus_sdk.ComputeClientV2, task_id_style: str
):
meta = load_response(compute_client_v2.get_task_batch).metadata

if task_id_style == "string":
task_ids = meta["task_id"]
elif task_id_style == "list":
task_ids = [meta["task_id"]]
elif task_id_style == "set":
task_ids = {meta["task_id"]}
elif task_id_style == "uuid":
task_ids = uuid.UUID(meta["task_id"])
elif task_id_style == "uuid_list":
task_ids = [uuid.UUID(meta["task_id"])]
else:
raise NotImplementedError(f"Unknown task_id_style {task_id_style}")

res = compute_client_v2.get_task_batch(task_ids=task_ids)

assert res.http_status == 200
assert meta["task_id"] in res.data["results"]
11 changes: 11 additions & 0 deletions tests/functional/services/compute/v2/test_get_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import globus_sdk
from globus_sdk._testing import load_response


def test_get_task_group(compute_client_v2: globus_sdk.ComputeClientV2):
meta = load_response(compute_client_v2.get_task_group).metadata
res = compute_client_v2.get_task_group(task_group_id=meta["task_group_id"])
assert res.http_status == 200
assert meta["task_group_id"] == res.data["taskgroup_id"]
assert meta["task_id"] == res.data["tasks"][0]["id"]
assert meta["task_id_2"] == res.data["tasks"][1]["id"]
9 changes: 9 additions & 0 deletions tests/functional/services/compute/v2/test_get_task_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import globus_sdk
from globus_sdk._testing import load_response


def test_get_task(compute_client_v2: globus_sdk.ComputeClientV2):
meta = load_response(compute_client_v2.get_task).metadata
res = compute_client_v2.get_task(task_id=meta["task_id"])
assert res.http_status == 200
assert res.data["task_id"] == meta["task_id"]
31 changes: 31 additions & 0 deletions tests/functional/services/compute/v3/test_submit_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import globus_sdk
from globus_sdk._testing import load_response


def test_submit(compute_client_v3: globus_sdk.ComputeClientV3):
meta = load_response(compute_client_v3.submit).metadata
submit_doc = {
"tasks": {
meta["function_id"]: [
"36\n00\ngASVDAAAAAAAAACMBlJvZG5leZSFlC4=\n12 ...",
"36\n00\ngASVCwAAAAAAAACMBUJvYmJ5lIWULg==\n12 ...",
],
},
"task_group_id": meta["task_group_id"],
"create_queue": True,
"user_runtime": {
"globus_compute_sdk_version": "2.29.0",
"globus_sdk_version": "3.46.0",
"python_version": "3.11.9",
},
}

res = compute_client_v3.submit(endpoint_id=meta["endpoint_id"], data=submit_doc)

assert res.http_status == 200
assert res.data["request_id"] == meta["request_id"]
assert res.data["task_group_id"] == meta["task_group_id"]
assert res.data["endpoint_id"] == meta["endpoint_id"]
assert res.data["tasks"] == {
meta["function_id"]: [meta["task_id"], meta["task_id_2"]]
}

0 comments on commit 1767b52

Please sign in to comment.