Skip to content

Commit

Permalink
Merge pull request #1094 from rjmello/compute-client-task-methods
Browse files Browse the repository at this point in the history
Add ComputeClient methods to submit & get tasks
  • Loading branch information
kurtmckee authored Nov 8, 2024
2 parents f55e41c + 11edfb8 commit 5270f0e
Show file tree
Hide file tree
Showing 15 changed files with 350 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Added
~~~~~

- Added the ``ComputeClientV3.submit()``, ``ComputeClientV2.submit()``,
``ComputeClientV2.get_task()``, ``ComputeClientV2.get_task_batch()``,
and ``ComputeClientV2.get_task_group()`` methods. (:pr:`NUMBER`)
22 changes: 22 additions & 0 deletions src/globus_sdk/_testing/data/compute/_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,27 @@
import uuid

ENDPOINT_ID = str(uuid.uuid1())

FUNCTION_ID = str(uuid.uuid1())
FUNCTION_NAME = "howdy_world"
FUNCTION_CODE = "410\n10\n04\n:gASVQAAAAAAAAACMC2hvd2R5X3dvc ..."

TASK_GROUP_ID = str(uuid.uuid1())
TASK_ID = str(uuid.uuid1())
TASK_ID_2 = str(uuid.uuid1())
TASK_DOC = {
"task_id": TASK_ID,
"status": "success",
"result": "10000",
"completion_t": "1677183605.212898",
"details": {
"os": "Linux-5.19.0-1025-aws-x86_64-with-glibc2.35",
"python_version": "3.10.4",
"dill_version": "0.3.5.1",
"globus_compute_sdk_version": "2.3.2",
"task_transitions": {
"execution-start": 1692742841.843334,
"execution-end": 1692742846.123456,
},
},
}
13 changes: 13 additions & 0 deletions src/globus_sdk/_testing/data/compute/v2/get_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from globus_sdk._testing.models import RegisteredResponse, ResponseSet

from .._common import TASK_DOC, TASK_ID

RESPONSES = ResponseSet(
metadata={"task_id": TASK_ID},
default=RegisteredResponse(
service="compute",
path=f"/v2/tasks/{TASK_ID}",
method="GET",
json=TASK_DOC,
),
)
22 changes: 22 additions & 0 deletions src/globus_sdk/_testing/data/compute/v2/get_task_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from responses.matchers import json_params_matcher

from globus_sdk._testing.models import RegisteredResponse, ResponseSet

from .._common import TASK_DOC, TASK_ID

TASK_BATCH_DOC = {
"response": "batch",
"results": {TASK_ID: TASK_DOC},
}

RESPONSES = ResponseSet(
metadata={"task_id": TASK_ID},
default=RegisteredResponse(
service="compute",
path="/v2/tasks/batch",
method="POST",
json=TASK_BATCH_DOC,
# Ensure task_ids is a list
match=[json_params_matcher({"task_ids": [TASK_ID]})],
),
)
26 changes: 26 additions & 0 deletions src/globus_sdk/_testing/data/compute/v2/get_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from globus_sdk._testing.models import RegisteredResponse, ResponseSet

from .._common import TASK_GROUP_ID, TASK_ID, TASK_ID_2

TASK_BATCH_DOC = {
"taskgroup_id": TASK_GROUP_ID,
"create_websockets_queue": True,
"tasks": [
{"id": TASK_ID, "created_at": "2021-05-05T15:00:00.000000"},
{"id": TASK_ID_2, "created_at": "2021-05-05T15:01:00.000000"},
],
}

RESPONSES = ResponseSet(
metadata={
"task_group_id": TASK_GROUP_ID,
"task_id": TASK_ID,
"task_id_2": TASK_ID_2,
},
default=RegisteredResponse(
service="compute",
path=f"/v2/taskgroup/{TASK_GROUP_ID}",
method="GET",
json=TASK_BATCH_DOC,
),
)
36 changes: 36 additions & 0 deletions src/globus_sdk/_testing/data/compute/v2/submit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from globus_sdk._testing.models import RegisteredResponse, ResponseSet

from .._common import TASK_GROUP_ID, TASK_ID, TASK_ID_2

SUBMIT_RESPONSE = {
"response": "success",
"task_group_id": TASK_GROUP_ID,
"results": [
{
"status": "success",
"task_uuid": TASK_ID,
"http_status_code": 200,
"reason": None,
},
{
"status": "success",
"task_uuid": TASK_ID_2,
"http_status_code": 200,
"reason": None,
},
],
}

RESPONSES = ResponseSet(
metadata={
"task_group_id": TASK_GROUP_ID,
"task_id": TASK_ID,
"task_id_2": TASK_ID_2,
},
default=RegisteredResponse(
service="compute",
path="/v2/submit",
method="POST",
json=SUBMIT_RESPONSE,
),
)
Empty file.
31 changes: 31 additions & 0 deletions src/globus_sdk/_testing/data/compute/v3/submit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from globus_sdk._testing.models import RegisteredResponse, ResponseSet

from .._common import ENDPOINT_ID, FUNCTION_ID, TASK_GROUP_ID, TASK_ID, TASK_ID_2

REQUEST_ID = "5158de19-10b5-4deb-9d87-a86c1dec3460"

SUBMIT_RESPONSE = {
"request_id": REQUEST_ID,
"task_group_id": TASK_GROUP_ID,
"endpoint_id": ENDPOINT_ID,
"tasks": {
FUNCTION_ID: [TASK_ID, TASK_ID_2],
},
}

RESPONSES = ResponseSet(
metadata={
"endpoint_id": ENDPOINT_ID,
"function_id": FUNCTION_ID,
"task_id": TASK_ID,
"task_id_2": TASK_ID_2,
"task_group_id": TASK_GROUP_ID,
"request_id": REQUEST_ID,
},
default=RegisteredResponse(
service="compute",
path=f"/v3/endpoints/{ENDPOINT_ID}/submit",
method="POST",
json=SUBMIT_RESPONSE,
),
)
83 changes: 82 additions & 1 deletion src/globus_sdk/services/compute/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import typing as t

from globus_sdk import GlobusHTTPResponse, client
from globus_sdk import GlobusHTTPResponse, client, utils
from globus_sdk._types import UUIDLike
from globus_sdk.scopes import ComputeScopes, Scope

Expand Down Expand Up @@ -72,6 +72,69 @@ def delete_function(self, function_id: UUIDLike) -> GlobusHTTPResponse:
""" # noqa: E501
return self.delete(f"/v2/functions/{function_id}")

def get_task(self, task_id: UUIDLike) -> GlobusHTTPResponse:
"""Get information about a task.
:param task_id: The ID of the task.
.. tab-set::
.. tab-item:: API Info
.. extdoclink:: Get Task
:service: compute
:ref: Tasks/operation/get_task_status_and_result_v2_tasks__task_uuid__get
""" # noqa: E501
return self.get(f"/v2/tasks/{task_id}")

def get_task_batch(
self, task_ids: UUIDLike | t.Iterable[UUIDLike]
) -> GlobusHTTPResponse:
"""Get information about a batch of tasks.
:param task_ids: The IDs of the tasks.
.. tab-set::
.. tab-item:: API Info
.. extdoclink:: Get Task Batch
:service: compute
:ref: Root/operation/get_batch_status_v2_batch_status_post
"""
task_ids = list(utils.safe_strseq_iter(task_ids))
return self.post("/v2/tasks/batch", data={"task_ids": task_ids})

def get_task_group(self, task_group_id: UUIDLike) -> GlobusHTTPResponse:
"""Get a list of task IDs associated with a task group.
:param task_group_id: The ID of the task group.
.. tab-set::
.. tab-item:: API Info
.. extdoclink:: Get Task Group Tasks
:service: compute
:ref: TaskGroup/operation/get_task_group_tasks_v2_taskgroup__task_group_uuid__get
""" # noqa: E501
return self.get(f"/v2/taskgroup/{task_group_id}")

def submit(self, data: dict[str, t.Any]) -> GlobusHTTPResponse:
"""Submit a batch of tasks to a Globus Compute endpoint.
:param data: The task batch document.
.. tab-set::
.. tab-item:: API Info
.. extdoclink:: Submit Batch
:service: compute
:ref: Root/operation/submit_batch_v2_submit_post
""" # noqa: E501
return self.post("/v2/submit", data=data)


class ComputeClientV3(client.BaseClient):
r"""
Expand All @@ -85,6 +148,24 @@ class ComputeClientV3(client.BaseClient):
scopes = ComputeScopes
default_scope_requirements = [Scope(ComputeScopes.all)]

def submit(
self, endpoint_id: UUIDLike, data: dict[str, t.Any]
) -> GlobusHTTPResponse:
"""Submit a batch of tasks to a Globus Compute endpoint.
:param endpoint_id: The ID of the Globus Compute endpoint.
:param data: The task batch document.
.. tab-set::
.. tab-item:: API Info
.. extdoclink:: Submit Batch
:service: compute
:ref: Endpoints/operation/submit_batch_v3_endpoints__endpoint_uuid__submit_post
""" # noqa: E501
return self.post(f"/v3/endpoints/{endpoint_id}/submit", data=data)


class ComputeClient(ComputeClientV2):
r"""
Expand Down
8 changes: 8 additions & 0 deletions tests/functional/services/compute/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@ class CustomComputeClientV2(globus_sdk.ComputeClientV2):
transport_class = no_retry_transport

return CustomComputeClientV2()


@pytest.fixture
def compute_client_v3(no_retry_transport):
class CustomComputeClientV3(globus_sdk.ComputeClientV3):
transport_class = no_retry_transport

return CustomComputeClientV3()
28 changes: 28 additions & 0 deletions tests/functional/services/compute/v2/test_get_task_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import typing as t
import uuid

import pytest

import globus_sdk
from globus_sdk._testing import load_response


@pytest.mark.parametrize(
"transform",
(
pytest.param(lambda x: x, id="string"),
pytest.param(lambda x: [x], id="list"),
pytest.param(lambda x: {x}, id="set"),
pytest.param(lambda x: uuid.UUID(x), id="uuid"),
pytest.param(lambda x: [uuid.UUID(x)], id="uuid_list"),
),
)
def test_get_task_batch(
compute_client_v2: globus_sdk.ComputeClientV2, transform: t.Callable
):
meta = load_response(compute_client_v2.get_task_batch).metadata
task_ids = transform(meta["task_id"])
res = compute_client_v2.get_task_batch(task_ids=task_ids)

assert res.http_status == 200
assert meta["task_id"] in res.data["results"]
11 changes: 11 additions & 0 deletions tests/functional/services/compute/v2/test_get_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import globus_sdk
from globus_sdk._testing import load_response


def test_get_task_group(compute_client_v2: globus_sdk.ComputeClientV2):
meta = load_response(compute_client_v2.get_task_group).metadata
res = compute_client_v2.get_task_group(task_group_id=meta["task_group_id"])
assert res.http_status == 200
assert meta["task_group_id"] == res.data["taskgroup_id"]
assert meta["task_id"] == res.data["tasks"][0]["id"]
assert meta["task_id_2"] == res.data["tasks"][1]["id"]
9 changes: 9 additions & 0 deletions tests/functional/services/compute/v2/test_get_task_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import globus_sdk
from globus_sdk._testing import load_response


def test_get_task(compute_client_v2: globus_sdk.ComputeClientV2):
meta = load_response(compute_client_v2.get_task).metadata
res = compute_client_v2.get_task(task_id=meta["task_id"])
assert res.http_status == 200
assert res.data["task_id"] == meta["task_id"]
25 changes: 25 additions & 0 deletions tests/functional/services/compute/v2/test_submit_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import uuid

import globus_sdk
from globus_sdk._testing import load_response


def test_submit(compute_client_v2: globus_sdk.ComputeClientV2):
meta = load_response(compute_client_v2.submit).metadata
ep_id, func_id = uuid.uuid1(), uuid.uuid1()
submit_doc = {
"task_group_id": meta["task_group_id"],
"create_websocket_queue": False,
"tasks": [
[func_id, ep_id, "36\n00\ngASVDAAAAAAAAACMBlJvZG5leZSFlC4=\n12 ..."],
[func_id, ep_id, "36\n00\ngASVCwAAAAAAAACMBUJvYmJ5lIWULg==\n12 ..."],
],
}

res = compute_client_v2.submit(data=submit_doc)

assert res.http_status == 200
assert res.data["task_group_id"] == meta["task_group_id"]
results = res.data["results"]
assert results[0]["task_uuid"] == meta["task_id"]
assert results[1]["task_uuid"] == meta["task_id_2"]
31 changes: 31 additions & 0 deletions tests/functional/services/compute/v3/test_submit_v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import globus_sdk
from globus_sdk._testing import load_response


def test_submit(compute_client_v3: globus_sdk.ComputeClientV3):
meta = load_response(compute_client_v3.submit).metadata
submit_doc = {
"tasks": {
meta["function_id"]: [
"36\n00\ngASVDAAAAAAAAACMBlJvZG5leZSFlC4=\n12 ...",
"36\n00\ngASVCwAAAAAAAACMBUJvYmJ5lIWULg==\n12 ...",
],
},
"task_group_id": meta["task_group_id"],
"create_queue": True,
"user_runtime": {
"globus_compute_sdk_version": "2.29.0",
"globus_sdk_version": "3.46.0",
"python_version": "3.11.9",
},
}

res = compute_client_v3.submit(endpoint_id=meta["endpoint_id"], data=submit_doc)

assert res.http_status == 200
assert res.data["request_id"] == meta["request_id"]
assert res.data["task_group_id"] == meta["task_group_id"]
assert res.data["endpoint_id"] == meta["endpoint_id"]
assert res.data["tasks"] == {
meta["function_id"]: [meta["task_id"], meta["task_id_2"]]
}

0 comments on commit 5270f0e

Please sign in to comment.