Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYNPY-1513] Validate input submission ID in getSubmission(...) #1135

Merged
merged 16 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions synapseclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
is_integer,
is_json,
require_param,
validate_submission_id,
)
from synapseclient.core.version_check import version_check

Expand Down Expand Up @@ -4807,12 +4808,15 @@ def _POST_paginated(self, uri: str, body, **kwargs):
if next_page_token is None:
break

def getSubmission(self, id, **kwargs):
def getSubmission(
self, id: typing.Union[str, int, collections.abc.Mapping], **kwargs
) -> Submission:
"""
Gets a [synapseclient.evaluation.Submission][] object by its id.
Gets a [synapseclient.evaluation.Submission][] object based on a given ID
or previous [synapseclient.evaluation.Submission][] object.

Arguments:
id: The id of the submission to retrieve
id: The ID of the submission to retrieve or a [synapseclient.evaluation.Submission][] object

Returns:
A [synapseclient.evaluation.Submission][] object
Expand All @@ -4823,7 +4827,7 @@ def getSubmission(self, id, **kwargs):
on the *downloadFile*, *downloadLocation*, and *ifcollision* parameters
"""

submission_id = id_of(id)
submission_id = validate_submission_id(id)
jaymedina marked this conversation as resolved.
Show resolved Hide resolved
uri = Submission.getURI(submission_id)
submission = Submission(**self.restGET(uri))

Expand Down Expand Up @@ -4852,18 +4856,20 @@ def getSubmission(self, id, **kwargs):

return submission

def getSubmissionStatus(self, submission):
def getSubmissionStatus(
self, submission: typing.Union[str, int, collections.abc.Mapping]
) -> SubmissionStatus:
"""
Downloads the status of a Submission.
Downloads the status of a Submission given its ID or previous [synapseclient.evaluation.Submission][] object.

Arguments:
submission: The submission to lookup
submission: The submission to lookup (ID or [synapseclient.evaluation.Submission][] object)

Returns:
A [synapseclient.evaluation.SubmissionStatus][] object
"""

submission_id = id_of(submission)
submission_id = validate_submission_id(submission)
uri = SubmissionStatus.getURI(submission_id)
val = self.restGET(uri)
return SubmissionStatus(**val)
Expand Down
45 changes: 45 additions & 0 deletions synapseclient/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import hashlib
import importlib
import inspect
import logging
import numbers
import os
import platform
Expand All @@ -30,6 +31,8 @@
import requests
from opentelemetry import trace

from synapseclient.core.logging_setup import DEFAULT_LOGGER_NAME

if TYPE_CHECKING:
from synapseclient.models import File, Folder, Project

Expand All @@ -47,6 +50,11 @@

SLASH_PREFIX_REGEX = re.compile(r"\/[A-Za-z]:")

# Set up logging
LOGGER_NAME = DEFAULT_LOGGER_NAME
jaymedina marked this conversation as resolved.
Show resolved Hide resolved
LOGGER = logging.getLogger(LOGGER_NAME)
logging.getLogger("py.warnings").handlers = LOGGER.handlers


def md5_for_file(
filename: str, block_size: int = 2 * MB, callback: typing.Callable = None
Expand Down Expand Up @@ -242,6 +250,43 @@ def id_of(obj: typing.Union[str, collections.abc.Mapping, numbers.Number]) -> st
raise ValueError("Invalid parameters: couldn't find id of " + str(obj))


def validate_submission_id(
    submission_id: typing.Union[str, int, collections.abc.Mapping]
) -> str:
    """
    Normalizes a submission ID to its string form.

    Submission IDs must be whole numbers. Version notation is not supported for
    submission IDs, therefore decimals are not allowed: a value carrying a decimal
    part (e.g. `123.0` or `"999.222"`) is truncated to its integer part and a
    warning is logged.

    Arguments:
        submission_id: The submission ID to validate. May be an integer, a string,
                       or a mapping (such as a previously retrieved submission)
                       from which the `id` is looked up.

    Returns:
        The submission ID as a string

    Raises:
        ValueError: If the value cannot be interpreted as a whole number, or if a
                    mapping was given and no `id` could be found in it.
    """
    invalid_id_message = (
        f"Submission ID '{submission_id}' is not a valid submission ID. "
        "Please use digits only."
    )

    # bool is a subclass of int; without this guard True would become "True".
    if isinstance(submission_id, bool):
        raise ValueError(invalid_id_message)

    if isinstance(submission_id, int):
        return str(submission_id)

    if isinstance(submission_id, str) and submission_id.isdigit():
        return submission_id

    if isinstance(submission_id, collections.abc.Mapping):
        syn_id = _get_from_members_items_or_properties(submission_id, "id")
        if syn_id is None:
            # Previously this fell through and silently returned None, which
            # callers would then interpolate into a request URI.
            raise ValueError(
                f"Invalid parameters: couldn't find id of {submission_id}"
            )
        return validate_submission_id(syn_id)

    # Last resort: accept anything float() understands and drop the decimals.
    # float() raises TypeError for unsupported types (e.g. None) and int()
    # raises OverflowError/ValueError for inf/nan; all are reported uniformly.
    try:
        int_submission_id = int(float(submission_id))
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(invalid_id_message) from err

    LOGGER.warning(
        f"Submission ID '{submission_id}' contains decimals which are not supported. "
        f"Submission ID will be converted to '{int_submission_id}'."
    )
    return str(int_submission_id)


def concrete_type_of(obj: collections.abc.Mapping):
"""
Return the concrete type of an object representing a Synapse entity.
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/synapseclient/core/unit_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# unit tests for utils.py

import base64
import logging
import os
import re
import tempfile
Expand Down Expand Up @@ -100,6 +101,49 @@ def __init__(self, id_attr_name: str, id: str) -> None:
assert utils.id_of(foo) == "123"


@pytest.mark.parametrize(
    "input_value, expected_output, expected_warning",
    [
        # Inputs that are already valid: returned as strings, nothing logged
        ("123", "123", None),
        (123, "123", None),
        ({"id": "222"}, "222", None),
        # Inputs with a decimal part: truncated, and a warning is logged
        (
            "123.0",
            "123",
            "Submission ID '123.0' contains decimals which are not supported",
        ),
        (
            123.0,
            "123",
            "Submission ID '123.0' contains decimals which are not supported",
        ),
        (
            {"id": "999.222"},
            "999",
            "Submission ID '999.222' contains decimals which are not supported",
        ),
    ],
)
def test_validate_submission_id(input_value, expected_output, expected_warning, caplog):
    """Check the normalized ID and the (absence of a) warning for each input."""
    with caplog.at_level(logging.WARNING):
        actual_output = utils.validate_submission_id(input_value)

    assert actual_output == expected_output

    captured_log = caplog.text
    if not expected_warning:
        # Valid input must not produce any log output at all.
        assert not captured_log
    else:
        assert expected_warning in captured_log


def test_validate_submission_id_letters_input() -> None:
    """An ID containing letters is rejected with a descriptive ValueError."""
    bad_id = "syn123"

    with pytest.raises(ValueError) as exc_info:
        utils.validate_submission_id(bad_id)

    actual_message = str(exc_info.value)
    assert actual_message == (
        f"Submission ID '{bad_id}' is not a valid submission ID. Please use digits only."
    )


# TODO: Add a test for is_synapse_id_str(...)
# https://sagebionetworks.jira.com/browse/SYNPY-1425

Expand Down
152 changes: 152 additions & 0 deletions tests/unit/synapseclient/unit_test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import os
import tempfile
import typing
import urllib.request as urllib_request
import uuid
from pathlib import Path
Expand Down Expand Up @@ -59,6 +60,7 @@
)
from synapseclient.core.models.dict_object import DictObject
from synapseclient.core.upload import upload_functions
from synapseclient.evaluation import Submission, SubmissionStatus

GET_FILE_HANDLE_FOR_DOWNLOAD = (
"synapseclient.core.download.download_functions.get_file_handle_for_download_async"
Expand Down Expand Up @@ -2995,6 +2997,156 @@ def test_get_submission_with_annotations(syn: Synapse) -> None:
assert evaluation_id == response["evaluationId"]


def run_get_submission_test(
    syn: Synapse,
    submission_id: typing.Union[str, int, float, dict],
    expected_id: str,
    should_warn: bool = False,
    caplog=None,
) -> None:
    """
    Common code for test_get_submission_valid_id and test_get_submission_invalid_id.
    Generates a dummy submission dictionary for regression testing, mocks the API calls,
    and validates the expected output for getSubmission. For invalid submission IDs, this
    will check that a warning was logged for the user before transforming their input.

    Arguments:
        syn: Synapse object
        submission_id: Submission ID to test (ID value or a dict carrying an `id`)
        expected_id: Submission ID that should be used in the request URI
        should_warn: Whether or not a warning should be logged
        caplog: pytest caplog fixture; required when `should_warn` is True

    Returns:
        None

    """
    # A plain integer, not a one-element tuple.
    evaluation_id = 98765
    submission = {
        "evaluationId": evaluation_id,
        "entityId": submission_id,
        "versionNumber": 1,
        "entityBundleJSON": json.dumps({}),
    }

    with patch.object(syn, "restGET") as restGET, patch.object(
        syn, "_getWithEntityBundle"
    ) as get_entity:
        restGET.return_value = submission

        if should_warn:
            # Fail with a clear message instead of an AttributeError on None.
            assert (
                caplog is not None
            ), "the caplog fixture must be passed when should_warn=True"
            with caplog.at_level(logging.WARNING):
                syn.getSubmission(submission_id)
            assert "contains decimals which are not supported" in caplog.text
        else:
            syn.getSubmission(submission_id)

        # The request must be made with the normalized ID, while the entity
        # lookup receives the entityId exactly as returned by the mocked API.
        restGET.assert_called_once_with(f"/evaluation/submission/{expected_id}")
        get_entity.assert_called_once_with(
            entityBundle={},
            entity=submission_id,
            submission=str(expected_id),
        )


@pytest.mark.parametrize(
    "submission_id, expected_id",
    [
        ("123", "123"),
        (123, "123"),
        ({"id": 123}, "123"),
        ({"id": "123"}, "123"),
    ],
)
def test_get_submission_valid_id(syn: Synapse, submission_id, expected_id) -> None:
    """getSubmission accepts whole-number IDs given as str, int, or a dict with an id."""
    run_get_submission_test(
        syn=syn,
        submission_id=submission_id,
        expected_id=expected_id,
    )


@pytest.mark.parametrize(
    "submission_id, expected_id",
    [
        ("123.0", "123"),
        (123.0, "123"),
        ({"id": 123.0}, "123"),
        ({"id": "123.0"}, "123"),
    ],
)
def test_get_submission_invalid_id(
    syn: Synapse, submission_id, expected_id, caplog
) -> None:
    """getSubmission truncates IDs that carry decimals and warns the user about it."""
    helper_kwargs = {"should_warn": True, "caplog": caplog}
    run_get_submission_test(syn, submission_id, expected_id, **helper_kwargs)


def test_get_submission_and_submission_status_interchangeability(
    syn: Synapse, caplog
) -> None:
    """Test interchangeability of getSubmission and getSubmissionStatus.

    The object returned by getSubmission (called here with a float ID) is passed
    straight into getSubmissionStatus; both calls must resolve to the same
    normalized submission ID.
    """

    # Establish some dummy variables to work with
    evaluation_id = 98765
    submission_id = 9745366.0
    expected_submission_id = "9745366"

    # Establish an expected return for `getSubmissionStatus`
    submission_status_return = {
        "id": expected_submission_id,
        "etag": "000",
        "status": "RECEIVED",
    }

    # Establish an expected return for `getSubmission`
    submission_return = {
        "id": expected_submission_id,
        "evaluationId": evaluation_id,
        "entityId": expected_submission_id,
        "versionNumber": 1,
        "entityBundleJSON": json.dumps({}),
    }

    # Let's mock all the API calls made within these two methods
    with patch.object(syn, "restGET") as restGET, patch.object(
        Submission, "getURI"
    ) as get_submission_uri, patch.object(
        SubmissionStatus, "getURI"
    ) as get_status_uri, patch.object(
        syn, "_getWithEntityBundle"
    ):
        get_submission_uri.return_value = (
            f"/evaluation/submission/{expected_submission_id}"
        )
        get_status_uri.return_value = (
            f"/evaluation/submission/{expected_submission_id}/status"
        )

        # Queue the responses for the two restGET calls made in this test.
        # `side_effect` takes precedence over `return_value`, so this list is
        # the single source of truth for what each call returns.
        restGET.side_effect = [
            # Step 1 call to `getSubmission`
            submission_return,
            # Step 2 call to `getSubmissionStatus`
            submission_status_return,
        ]

        # Step 1: Call `getSubmission` with float ID; the decimal part is
        # dropped and the user is warned about it.
        with caplog.at_level(logging.WARNING):
            submission_result = syn.getSubmission(submission_id)
        assert "contains decimals which are not supported" in caplog.text

        # Step 2: Call `getSubmissionStatus` with the `Submission` object from above
        submission_status_result = syn.getSubmissionStatus(submission_result)

        # Validate that getSubmission and getSubmissionStatus are called with correct URIs
        # in `getURI` calls
        get_submission_uri.assert_called_once_with(expected_submission_id)
        get_status_uri.assert_called_once_with(expected_submission_id)

        # Validate final output is as expected
        assert (
            submission_result["id"]
            == submission_status_result["id"]
            == expected_submission_id
        )


class TestTableSnapshot:
def test__create_table_snapshot(self, syn: Synapse) -> None:
"""Testing creating table snapshots"""
Expand Down
Loading