From 3d0e852f650c0f347a31e8fc8af7cfd3015900ba Mon Sep 17 00:00:00 2001 From: Jos van der Velde Date: Fri, 1 Dec 2023 11:50:46 +0100 Subject: [PATCH 1/3] Validation of platform_resource_identifiers, primarily used to make sure that the HuggingFace datasets have a correct platform_resource_identifier --- .../huggingface_dataset_connector.py | 3 +- src/database/model/concept/concept.py | 20 ++- src/database/validators/__init__.py | 0 .../validators/huggingface_validators.py | 36 +++++ src/database/validators/openml_validators.py | 11 ++ .../test_huggingface_dataset_connector.py | 2 +- .../resource_routers/test_router_dataset.py | 34 +++++ src/tests/uploader/huggingface/__init__.py | 0 .../huggingface/test_dataset_uploader.py | 131 ++++++++++++++---- src/tests/validators/__init__.py | 0 .../validators/test_huggingface_validators.py | 43 ++++++ src/uploader/hugging_face_uploader.py | 65 +++++++-- 12 files changed, 298 insertions(+), 47 deletions(-) create mode 100644 src/database/validators/__init__.py create mode 100644 src/database/validators/huggingface_validators.py create mode 100644 src/database/validators/openml_validators.py create mode 100644 src/tests/uploader/huggingface/__init__.py create mode 100644 src/tests/validators/__init__.py create mode 100644 src/tests/validators/test_huggingface_validators.py diff --git a/src/connectors/huggingface/huggingface_dataset_connector.py b/src/connectors/huggingface/huggingface_dataset_connector.py index 5071fa45..fc3cb28f 100644 --- a/src/connectors/huggingface/huggingface_dataset_connector.py +++ b/src/connectors/huggingface/huggingface_dataset_connector.py @@ -1,5 +1,6 @@ import logging import typing + import bibtexparser import requests from huggingface_hub import list_datasets @@ -154,8 +155,6 @@ def _parse_citations(self, dataset, pydantic_class_publication) -> list: ] return [ pydantic_class_publication( - platform=self.platform_name, - platform_resource_identifier=citation["ID"], name=citation["title"], same_as=citation["link"] if "link" in citation else None, type=citation["ENTRYTYPE"], diff --git a/src/database/model/concept/concept.py b/src/database/model/concept/concept.py index d31f2529..40912e62 100644 --- a/src/database/model/concept/concept.py +++ b/src/database/model/concept/concept.py @@ -3,6 +3,7 @@ import os from typing import Optional, Tuple +from pydantic import validator from sqlalchemy import CheckConstraint, Index from sqlalchemy.orm import declared_attr from sqlalchemy.sql.functions import coalesce @@ -13,6 +14,7 @@ from database.model.platform.platform_names import PlatformName from database.model.relationships import OneToOne from database.model.serializers import CastDeserializer +from database.validators import huggingface_validators, openml_validators IS_SQLITE = os.getenv("DB") == "SQLite" CONSTRAINT_LOWERCASE = f"{'platform' if IS_SQLITE else 'BINARY(platform)'} = LOWER(platform)" @@ -32,11 +34,27 @@ class AIoDConceptBase(SQLModel): platform_resource_identifier: str | None = Field( max_length=NORMAL, description="A unique identifier issued by the external platform that's specified in " - "'platform'. Leave empty if this item is not part of an external platform.", + "'platform'. Leave empty if this item is not part of an external platform. For example, " + "for HuggingFace, this should be the /, and for Openml, the " + "OpenML identifier.", default=None, schema_extra={"example": "1"}, ) + @validator("platform_resource_identifier") + def platform_resource_identifier_valid(cls, platform_resource_identifier: str, values) -> str: + if platform := values.get("platform", None): + match platform: + case PlatformName.huggingface: + huggingface_validators.throw_error_on_invalid_identifier( + platform_resource_identifier + ) + case PlatformName.openml: + openml_validators.throw_error_on_invalid_identifier( + platform_resource_identifier + ) + return platform_resource_identifier + class AIoDConcept(AIoDConceptBase): identifier: int = Field(default=None, primary_key=True) diff --git a/src/database/validators/__init__.py b/src/database/validators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/database/validators/huggingface_validators.py b/src/database/validators/huggingface_validators.py new file mode 100644 index 00000000..33203cfa --- /dev/null +++ b/src/database/validators/huggingface_validators.py @@ -0,0 +1,36 @@ +import re + +REPO_ID_ILLEGAL_CHARACTERS = re.compile(r"[^0-9a-zA-Z-_./]+") +MSG_PREFIX = "The platform_resource_identifier for HuggingFace should be a valid repo_id. " + + +def throw_error_on_invalid_identifier(platform_resource_identifier: str): + """ + Throw a ValueError on an invalid repository identifier. + + Valid repo_ids: + Between 1 and 96 characters. + Either “repo_name” or “namespace/repo_name” + [a-zA-Z0-9] or ”-”, ”_”, ”.” + ”—” and ”..” are forbidden + + Refer to: + https://huggingface.co/docs/huggingface_hub/package_reference/utilities#huggingface_hub.utils.validate_repo_id + """ + repo_id = platform_resource_identifier + if REPO_ID_ILLEGAL_CHARACTERS.search(repo_id): + msg = "A repo_id should only contain [a-zA-Z0-9] or ”-”, ”_”, ”.”" + raise ValueError(MSG_PREFIX + msg) + if not (1 < len(repo_id) < 96): + msg = "A repo_id should be between 1 and 96 characters." + raise ValueError(MSG_PREFIX + msg) + if repo_id.count("/") > 1: + msg = ( + "For new repositories, there should be a single forward slash in the repo_id (" + "namespace/repo_name). Legacy repositories are without a namespace. This repo_id has " + "too many forward slashes." + ) + raise ValueError(MSG_PREFIX + msg) + if ".." in repo_id: + msg = "A repo_id may not contain multiple consecutive dots." + raise ValueError(MSG_PREFIX + msg) diff --git a/src/database/validators/openml_validators.py b/src/database/validators/openml_validators.py new file mode 100644 index 00000000..39b962d1 --- /dev/null +++ b/src/database/validators/openml_validators.py @@ -0,0 +1,11 @@ +MSG = "An OpenML platform_resource_identifier should be a positive integer." + + +def throw_error_on_invalid_identifier(platform_resource_identifier: str): + """Throw a ValueError on an invalid repository identifier.""" + try: + openml_identifier = int(platform_resource_identifier) + except ValueError: + raise ValueError(MSG) + if openml_identifier < 0: + raise ValueError(MSG) diff --git a/src/tests/connectors/huggingface/test_huggingface_dataset_connector.py b/src/tests/connectors/huggingface/test_huggingface_dataset_connector.py index 613194e8..7911d59d 100644 --- a/src/tests/connectors/huggingface/test_huggingface_dataset_connector.py +++ b/src/tests/connectors/huggingface/test_huggingface_dataset_connector.py @@ -124,7 +124,7 @@ def test_incorrect_citation(): citation.name == "ArCOV-19: The First Arabic COVID-19 Twitter Dataset with Propagation " "Networks" ) - assert citation.platform_resource_identifier == "haouari2020arcov19" + assert citation.platform_resource_identifier is None assert citation.type == "article" assert ( citation.description.plain == "By Fatima Haouari and Maram Hasanain and Reem Suwaileh " diff --git a/src/tests/routers/resource_routers/test_router_dataset.py b/src/tests/routers/resource_routers/test_router_dataset.py index 4b21b712..00cd9f35 100644 --- a/src/tests/routers/resource_routers/test_router_dataset.py +++ b/src/tests/routers/resource_routers/test_router_dataset.py @@ -1,6 +1,7 @@ import copy from unittest.mock import Mock +from starlette import status from starlette.testclient import TestClient from authentication import keycloak_openid @@ -50,3 +51,36 @@ def test_happy_path( "geo": {"latitude": 37.42242, "longitude": -122.08585, "elevation_millimeters": 2000}, } # TODO: test delete + + +def test_post_invalid_huggingface_identifier( + client: TestClient, + mocked_privileged_token: Mock, +): + keycloak_openid.userinfo = mocked_privileged_token + + body = {"name": "name", "platform": "huggingface", "platform_resource_identifier": "a"} + + response = client.post("/datasets/v1", json=body, headers={"Authorization": "Fake token"}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json() + assert ( + response.json()["detail"][0]["msg"] + == "The platform_resource_identifier for HuggingFace should be a valid repo_id. A repo_id " + "should be between 1 and 96 characters." + ) + + +def test_post_invalid_openml_identifier( + client: TestClient, + mocked_privileged_token: Mock, +): + keycloak_openid.userinfo = mocked_privileged_token + + body = {"name": "name", "platform": "openml", "platform_resource_identifier": "a"} + + response = client.post("/datasets/v1", json=body, headers={"Authorization": "Fake token"}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json() + assert ( + response.json()["detail"][0]["msg"] + == "An OpenML platform_resource_identifier should be a positive integer." + ) diff --git a/src/tests/uploader/huggingface/__init__.py b/src/tests/uploader/huggingface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tests/uploader/huggingface/test_dataset_uploader.py b/src/tests/uploader/huggingface/test_dataset_uploader.py index 8b8e09e7..b9a75f6c 100644 --- a/src/tests/uploader/huggingface/test_dataset_uploader.py +++ b/src/tests/uploader/huggingface/test_dataset_uploader.py @@ -1,29 +1,31 @@ +import copy from unittest.mock import Mock import huggingface_hub +import pytest import responses +from starlette import status from starlette.testclient import TestClient from authentication import keycloak_openid -from database.model.ai_asset.ai_asset_table import AIAssetTable from database.model.dataset.dataset import Dataset from database.session import DbSession from tests.testutils.paths import path_test_resources +from uploader.hugging_face_uploader import _throw_error_on_invalid_repo_id def test_happy_path_new_repository( client: TestClient, mocked_privileged_token: Mock, dataset: Dataset ): + dataset = copy.deepcopy(dataset) + dataset.platform = "huggingface" + dataset.platform_resource_identifier = "Fake-username/test" + keycloak_openid.userinfo = mocked_privileged_token with DbSession() as session: session.add(dataset) session.commit() - data = { - "token": "huggingface_token", - "username": "username", - } - with open(path_test_resources() / "uploaders" / "huggingface" / "example.csv", "rb") as f: files = {"file": f.read()} @@ -37,39 +39,27 @@ def test_happy_path_new_repository( huggingface_hub.upload_file = Mock(return_value=None) response = client.post( "/upload/datasets/1/huggingface", - data=data, params={"username": "Fake-username", "token": "Fake-token"}, headers={"Authorization": "Fake token"}, files=files, ) + assert response.status_code == 200, response.json() id_response = response.json() assert id_response == 1 -def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock): +def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock, dataset: Dataset): keycloak_openid.userinfo = mocked_privileged_token - dataset_id = 1 + + dataset = copy.deepcopy(dataset) + dataset.platform = "huggingface" + dataset.platform_resource_identifier = "Fake-username/test" + with DbSession() as session: - session.add_all( - [ - AIAssetTable(type="dataset"), - Dataset( - identifier=dataset_id, - name="Parent", - platform="example", - platform_resource_identifier="1", - same_as="", - ), - ] - ) + session.add(dataset) session.commit() - data = { - "token": "huggingface_token", - "username": "username", - } - with open(path_test_resources() / "uploaders" / "huggingface" / "example.csv", "rb") as f: files = {"file": f.read()} @@ -85,15 +75,96 @@ def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock): ) huggingface_hub.upload_file = Mock(return_value=None) response = client.post( - f"/upload/datasets/{dataset_id}/huggingface", - data=data, + "/upload/datasets/1/huggingface", params={"username": "Fake-username", "token": "Fake-token"}, headers={"Authorization": "Fake token"}, files=files, ) assert response.status_code == 200, response.json() id_response = response.json() - assert id_response == dataset_id + assert id_response == 1 + + +def test_wrong_platform(client: TestClient, mocked_privileged_token: Mock, dataset: Dataset): + keycloak_openid.userinfo = mocked_privileged_token + + dataset = copy.deepcopy(dataset) + dataset.platform = "example" + dataset.platform_resource_identifier = "Fake-username/test" + + with DbSession() as session: + session.add(dataset) + session.commit() + + with open(path_test_resources() / "uploaders" / "huggingface" / "example.csv", "rb") as f: + files = {"file": f.read()} + + response = client.post( + "/upload/datasets/1/huggingface", + params={"username": "Fake-username", "token": "Fake-token"}, + headers={"Authorization": "Fake token"}, + files=files, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST, response.json() + assert ( + response.json()["detail"] + == "The dataset with identifier 1 should have platform=PlatformName.huggingface." + ) -# TODO: tests some error handling? +@pytest.mark.parametrize( + "username,dataset_name,expected_error", + [ + ("0-hero", "0-hero/OIG-small-chip2", None), + ("user", "user/Foo-BAR_foo.bar123", None), + ( + "user", + "user/Test name with ?", + ValueError( + "The platform_resource_identifier for HuggingFace should be a valid repo_id. " + "A repo_id should only contain [a-zA-Z0-9] or ”-”, ”_”, ”.”" + ), + ), + ( + "username", + "acronym_identification", + ValueError( + "The username should be part of the platform_resource_identifier for HuggingFace: " + "username/acronym_identification. Please update the dataset " + "platform_resource_identifier." + ), + ), + ( + "user", + "user/data/set", + ValueError( + "The platform_resource_identifier for HuggingFace should be a valid repo_id. " + "For new repositories, there should be a single forward slash in the repo_id " + "(namespace/repo_name). Legacy repositories are without a namespace. This " + "repo_id has too many forward slashes." + ), + ), + ( + "user", + "wrong-namespace/name", + ValueError( + "The namespace should be equal to the username, but wrong-namespace != user." + ), + ), + ( + "user", + "user/" + "a" * 200, + ValueError( + "The platform_resource_identifier for HuggingFace should be a valid repo_id. " + "A repo_id should be between 1 and 96 characters." + ), + ), + ], +) +def test_repo_id(username: str, dataset_name: str, expected_error: ValueError | None): + if expected_error is None: + _throw_error_on_invalid_repo_id(username, dataset_name) + else: + with pytest.raises(type(expected_error)) as exception_info: + _throw_error_on_invalid_repo_id(username, dataset_name) + assert exception_info.value.args[0] == expected_error.args[0] diff --git a/src/tests/validators/__init__.py b/src/tests/validators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tests/validators/test_huggingface_validators.py b/src/tests/validators/test_huggingface_validators.py new file mode 100644 index 00000000..95b83ebd --- /dev/null +++ b/src/tests/validators/test_huggingface_validators.py @@ -0,0 +1,43 @@ +import pytest + +from database.validators import huggingface_validators + + +@pytest.mark.parametrize( + "identifier,expected_error", + [ + ("0-hero/OIG-small-chip2", None), + ("user/Foo-BAR_foo.bar123", None), + ("acronym_identification", None), + ( + "user/data/set", + ValueError( + "The platform_resource_identifier for HuggingFace should be a valid repo_id. For " + "new repositories, there should be a single forward slash in the repo_id " + "(namespace/repo_name). Legacy repositories are without a namespace. This repo_id " + "has too many forward slashes." + ), + ), + ( + "a", + ValueError( + "The platform_resource_identifier for HuggingFace should be a valid repo_id. A " + "repo_id should be between 1 and 96 characters." + ), + ), + ( + "user/" + "a" * 200, + ValueError( + "The platform_resource_identifier for HuggingFace should be a valid repo_id. A " + "repo_id should be between 1 and 96 characters." + ), + ), + ], +) +def test_identifier(identifier: str, expected_error: ValueError | None): + if expected_error is None: + huggingface_validators.throw_error_on_invalid_identifier(identifier) + else: + with pytest.raises(type(expected_error)) as exception_info: + huggingface_validators.throw_error_on_invalid_identifier(identifier) + assert exception_info.value.args[0] == expected_error.args[0] diff --git a/src/uploader/hugging_face_uploader.py b/src/uploader/hugging_face_uploader.py index fc7f18a2..dec6f50d 100644 --- a/src/uploader/hugging_face_uploader.py +++ b/src/uploader/hugging_face_uploader.py @@ -6,20 +6,27 @@ from sqlmodel import Session from database.model.dataset.dataset import Dataset +from database.model.platform.platform_names import PlatformName from database.session import DbSession +from database.validators import huggingface_validators from .utils import huggingface_license_identifiers -def handle_upload( - identifier: int, - file: UploadFile, - token: str, - username: str, -): +def handle_upload(identifier: int, file: UploadFile, token: str, username: str): with DbSession() as session: - dataset = _get_resource(session=session, identifier=identifier) - dataset_name_cleaned = "".join(c if c.isalnum() else "_" for c in dataset.name) - repo_id = f"{username}/{dataset_name_cleaned}" + dataset: Dataset = _get_resource(session=session, identifier=identifier) + repo_id = dataset.platform_resource_identifier + if dataset.platform != PlatformName.huggingface or not repo_id: + msg = ( + f"The dataset with identifier {dataset.identifier} should have platform=" + f"{PlatformName.huggingface}." + ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=msg) + + try: + _throw_error_on_invalid_repo_id(username, repo_id) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.args[0]) url = _create_or_get_repo_url(repo_id, token) metadata_file = _generate_metadata_file(dataset) @@ -61,8 +68,7 @@ def handle_upload( msg = f"Error uploading the file, unexpected error: {e.with_traceback}" raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) - if not any(data.name == repo_id for data in dataset.distribution): - _store_resource_updated(session, dataset, url, repo_id) + _store_resource_updated(session, dataset, url, repo_id) return dataset.identifier @@ -105,7 +111,7 @@ def _create_or_get_repo_url(repo_id, token): if "You already created this dataset repo" in e.args[0]: return f"https://huggingface.co/datasets/{repo_id}" else: - msg = "Error uploading the file, unexpected error" + msg = f"Unexpected error while creating the repository: {e}" raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) from e @@ -117,7 +123,7 @@ def _generate_metadata_file(dataset: Dataset) -> bytes: if tags: content += "tags:\n" content += "\n".join(tags) + "\n" - # TODO the license must be in the hugginface format: + # TODO the license must be in the huggingface format: # https://huggingface.co/docs/hub/repositories-licenses if dataset.license in huggingface_license_identifiers: @@ -127,3 +133,36 @@ def _generate_metadata_file(dataset: Dataset) -> bytes: content += f"# {dataset.name}\n" content += "Created from AIOD platform" # TODO add url return content.encode("utf-8") + + +def _throw_error_on_invalid_repo_id(username: str, platform_resource_identifier: str): + """ + Return a valid repository identifier, including namespace, for Huggingface, or raise an error, + + Valid repo_ids: + Between 1 and 96 characters. + Either “repo_name” or “namespace/repo_name” + [a-zA-Z0-9] or ”-”, ”_”, ”.” + ”—” and ”..” are forbidden + + Refer to: + https://huggingface.co/docs/huggingface_hub/package_reference/utilities#huggingface_hub.utils.validate_repo_id + """ + huggingface_validators.throw_error_on_invalid_identifier(platform_resource_identifier) + if "/" not in platform_resource_identifier: + msg = ( + f"The username should be part of the platform_resource_identifier for HuggingFace: " + f"{username}/{platform_resource_identifier}. Please update the dataset " + f"platform_resource_identifier." + ) + # In general, it's allowed in HuggingFace to have a dataset name without namespace. This + # is legacy: "The legacy GitHub datasets were added originally on our GitHub repository + # and therefore don’t have a namespace on the Hub". + # Any new dataset will therefore have a namespace. Since we're uploading a new dataset, + # we should not accept a legacy name. + raise ValueError(msg) + + namespace = platform_resource_identifier.split("/")[0] + if username != namespace: + msg = f"The namespace should be equal to the username, but {namespace} != {username}." + raise ValueError(msg) From e7cd283ada3388b72a8679d31bfa898131296774 Mon Sep 17 00:00:00 2001 From: Jos van der Velde Date: Fri, 1 Dec 2023 11:57:53 +0100 Subject: [PATCH 2/3] Additional comments --- .../huggingface/huggingface_dataset_connector.py | 3 +++ src/database/model/concept/concept.py | 7 +++++++ src/tests/uploader/huggingface/test_dataset_uploader.py | 3 ++- src/uploader/hugging_face_uploader.py | 5 ++++- 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/connectors/huggingface/huggingface_dataset_connector.py b/src/connectors/huggingface/huggingface_dataset_connector.py index fc3cb28f..8d17f7a9 100644 --- a/src/connectors/huggingface/huggingface_dataset_connector.py +++ b/src/connectors/huggingface/huggingface_dataset_connector.py @@ -155,6 +155,9 @@ def _parse_citations(self, dataset, pydantic_class_publication) -> list: ] return [ pydantic_class_publication( + # The platform and platform_resource_identifier should be None: this publication + # is not stored on HuggingFace (and not identifiable within HF using, + # for instance, citation["ID"]) name=citation["title"], same_as=citation["link"] if "link" in citation else None, type=citation["ENTRYTYPE"], diff --git a/src/database/model/concept/concept.py b/src/database/model/concept/concept.py index 40912e62..b3c74f5c 100644 --- a/src/database/model/concept/concept.py +++ b/src/database/model/concept/concept.py @@ -43,6 +43,13 @@ class AIoDConceptBase(SQLModel): @validator("platform_resource_identifier") def platform_resource_identifier_valid(cls, platform_resource_identifier: str, values) -> str: + """ + Throw a ValueError if the platform_resource_identifier is invalid for this platform. + + Note that field order matters: platform is defined before platform_resource_identifier, + so that this validator can use the value of the platform. Refer to + https://docs.pydantic.dev/1.10/usage/models/#field-ordering + """ if platform := values.get("platform", None): match platform: case PlatformName.huggingface: diff --git a/src/tests/uploader/huggingface/test_dataset_uploader.py b/src/tests/uploader/huggingface/test_dataset_uploader.py index b9a75f6c..07f2c609 100644 --- a/src/tests/uploader/huggingface/test_dataset_uploader.py +++ b/src/tests/uploader/huggingface/test_dataset_uploader.py @@ -148,7 +148,8 @@ def test_wrong_platform(client: TestClient, mocked_privileged_token: Mock, datas "user", "wrong-namespace/name", ValueError( - "The namespace should be equal to the username, but wrong-namespace != user." + "The namespace (the first part of the platform_resource_identifier) should be " + "equal to the username, but wrong-namespace != user." ), ), ( diff --git a/src/uploader/hugging_face_uploader.py b/src/uploader/hugging_face_uploader.py index dec6f50d..885aec51 100644 --- a/src/uploader/hugging_face_uploader.py +++ b/src/uploader/hugging_face_uploader.py @@ -164,5 +164,8 @@ def _throw_error_on_invalid_repo_id(username: str, platform_resource_identifier: namespace = platform_resource_identifier.split("/")[0] if username != namespace: - msg = f"The namespace should be equal to the username, but {namespace} != {username}." + msg = ( + f"The namespace (the first part of the platform_resource_identifier) should be " + f"equal to the username, but {namespace} != {username}." + ) raise ValueError(msg) From 3f0e878380c0ef00a03a5dd05b6045329d1bb0cf Mon Sep 17 00:00:00 2001 From: Jos van der Velde Date: Fri, 1 Dec 2023 14:49:27 +0100 Subject: [PATCH 3/3] Separate error if repo_id is None --- src/uploader/hugging_face_uploader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/uploader/hugging_face_uploader.py b/src/uploader/hugging_face_uploader.py index 885aec51..d62195b3 100644 --- a/src/uploader/hugging_face_uploader.py +++ b/src/uploader/hugging_face_uploader.py @@ -16,13 +16,17 @@ def handle_upload(identifier: int, file: UploadFile, token: str, username: str): with DbSession() as session: dataset: Dataset = _get_resource(session=session, identifier=identifier) repo_id = dataset.platform_resource_identifier - if dataset.platform != PlatformName.huggingface or not repo_id: + if dataset.platform != PlatformName.huggingface: msg = ( f"The dataset with identifier {dataset.identifier} should have platform=" f"{PlatformName.huggingface}." ) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=msg) - + if not repo_id: + # this if-statement is purely to make MyPy understand that the repo_id cannot be None. + # This is enforced by a CheckConstraint in the db, so this error will never be thrown. + msg = "Every dataset with a platform should also have a platform_resource_identifier" + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=msg) try: _throw_error_on_invalid_repo_id(username, repo_id) except ValueError as e: