Bugfix: Huggingface dataset uploader #216

Merged: 3 commits, Dec 1, 2023
6 changes: 4 additions & 2 deletions src/connectors/huggingface/huggingface_dataset_connector.py
@@ -1,5 +1,6 @@
import logging
import typing

import bibtexparser
import requests
from huggingface_hub import list_datasets
@@ -154,8 +155,9 @@ def _parse_citations(self, dataset, pydantic_class_publication) -> list:
]
return [
pydantic_class_publication(
platform=self.platform_name,
platform_resource_identifier=citation["ID"],
# The platform and platform_resource_identifier should be None: this publication
# is not stored on HuggingFace (and not identifiable within HF using,
# for instance, citation["ID"])
name=citation["title"],
same_as=citation["link"] if "link" in citation else None,
type=citation["ENTRYTYPE"],
27 changes: 26 additions & 1 deletion src/database/model/concept/concept.py
@@ -3,6 +3,7 @@
import os
from typing import Optional, Tuple

from pydantic import validator
from sqlalchemy import CheckConstraint, Index
from sqlalchemy.orm import declared_attr
from sqlalchemy.sql.functions import coalesce
@@ -13,6 +14,7 @@
from database.model.platform.platform_names import PlatformName
from database.model.relationships import OneToOne
from database.model.serializers import CastDeserializer
from database.validators import huggingface_validators, openml_validators

IS_SQLITE = os.getenv("DB") == "SQLite"
CONSTRAINT_LOWERCASE = f"{'platform' if IS_SQLITE else 'BINARY(platform)'} = LOWER(platform)"
@@ -32,11 +34,34 @@ class AIoDConceptBase(SQLModel):
platform_resource_identifier: str | None = Field(
max_length=NORMAL,
description="A unique identifier issued by the external platform that's specified in "
"'platform'. Leave empty if this item is not part of an external platform.",
"'platform'. Leave empty if this item is not part of an external platform. For example, "
"for HuggingFace, this should be the <namespace>/<dataset_name>, and for Openml, the "
"OpenML identifier.",
default=None,
schema_extra={"example": "1"},
)

@validator("platform_resource_identifier")
def platform_resource_identifier_valid(cls, platform_resource_identifier: str, values) -> str:
"""
Throw a ValueError if the platform_resource_identifier is invalid for this platform.

Note that field order matters: platform is defined before platform_resource_identifier,
so that this validator can use the value of the platform. Refer to
https://docs.pydantic.dev/1.10/usage/models/#field-ordering
"""
if platform := values.get("platform", None):
match platform:
case PlatformName.huggingface:
huggingface_validators.throw_error_on_invalid_identifier(
platform_resource_identifier
)
case PlatformName.openml:
openml_validators.throw_error_on_invalid_identifier(
platform_resource_identifier
)
return platform_resource_identifier


class AIoDConcept(AIoDConceptBase):
identifier: int = Field(default=None, primary_key=True)
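For illustration, here is a standalone sketch of the validation pattern this hunk introduces, written against plain pydantic v1 (the class and the stand-in check below are illustrative, not the actual AIoD models): a ValueError raised inside the validator surfaces as a pydantic ValidationError, which FastAPI reports as a 422 response.

```python
from pydantic import BaseModel, ValidationError, validator


def throw_error_on_invalid_identifier(identifier: str) -> None:
    # Stand-in for the platform-specific checks in database/validators/.
    if "/" not in identifier:
        raise ValueError("expected <namespace>/<repo_name>")


class ConceptSketch(BaseModel):
    # platform is declared before platform_resource_identifier on purpose:
    # pydantic runs validators in field-definition order, so `values` already
    # contains the platform when the identifier is validated.
    platform: str | None = None
    platform_resource_identifier: str | None = None

    @validator("platform_resource_identifier")
    def identifier_valid(cls, v, values):
        if v is not None and values.get("platform") == "huggingface":
            throw_error_on_invalid_identifier(v)
        return v


try:
    ConceptSketch(platform="huggingface", platform_resource_identifier="no_namespace")
except ValidationError as error:
    print(error)  # reports the ValueError raised by the platform-specific check
```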
36 changes: 36 additions & 0 deletions src/database/validators/huggingface_validators.py
@@ -0,0 +1,36 @@
import re

REPO_ID_ILLEGAL_CHARACTERS = re.compile(r"[^0-9a-zA-Z-_./]+")
MSG_PREFIX = "The platform_resource_identifier for HuggingFace should be a valid repo_id. "


def throw_error_on_invalid_identifier(platform_resource_identifier: str):
    """
    Throw a ValueError on an invalid repository identifier.

    Valid repo_ids:
        Between 1 and 96 characters.
        Either “repo_name” or “namespace/repo_name”
        [a-zA-Z0-9] or ”-”, ”_”, ”.”
        ”--” and ”..” are forbidden

    Refer to:
    https://huggingface.co/docs/huggingface_hub/package_reference/utilities#huggingface_hub.utils.validate_repo_id
    """
    repo_id = platform_resource_identifier
    if REPO_ID_ILLEGAL_CHARACTERS.search(repo_id):
        msg = "A repo_id should only contain [a-zA-Z0-9] or ”-”, ”_”, ”.”"
        raise ValueError(MSG_PREFIX + msg)
    if not (1 < len(repo_id) < 96):
        msg = "A repo_id should be between 1 and 96 characters."
        raise ValueError(MSG_PREFIX + msg)
    if repo_id.count("/") > 1:
        msg = (
            "For new repositories, there should be a single forward slash in the repo_id ("
            "namespace/repo_name). Legacy repositories are without a namespace. This repo_id has "
            "too many forward slashes."
        )
        raise ValueError(MSG_PREFIX + msg)
    if ".." in repo_id:
        msg = "A repo_id may not contain multiple consecutive dots."
        raise ValueError(MSG_PREFIX + msg)
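A quick illustration of what these checks accept and reject (assuming src/ is on the Python path, so the module resolves the same way as the import in concept.py); the invalid examples mirror the parametrized test cases further down:

```python
from database.validators import huggingface_validators

# Valid repo_ids pass silently:
huggingface_validators.throw_error_on_invalid_identifier("0-hero/OIG-small-chip2")
huggingface_validators.throw_error_on_invalid_identifier("user/Foo-BAR_foo.bar123")

# Each of these violates one rule and raises a ValueError naming it:
for bad in ("user/name with spaces", "a/b/c", "user/..", "user/" + "a" * 200):
    try:
        huggingface_validators.throw_error_on_invalid_identifier(bad)
    except ValueError as error:
        print(error)  # MSG_PREFIX plus the specific reason
```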
11 changes: 11 additions & 0 deletions src/database/validators/openml_validators.py
@@ -0,0 +1,11 @@
MSG = "An OpenML platform_resource_identifier should be a positive integer."


def throw_error_on_invalid_identifier(platform_resource_identifier: str):
"""Throw a ValueError on an invalid repository identifier."""
try:
openml_identifier = int(platform_resource_identifier)
except ValueError:
raise ValueError(MSG)
if openml_identifier < 0:
raise ValueError(MSG)
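And, under the same path assumption, for OpenML:

```python
from database.validators import openml_validators

openml_validators.throw_error_on_invalid_identifier("42")  # a positive integer: accepted
for bad in ("a", "-1"):
    try:
        openml_validators.throw_error_on_invalid_identifier(bad)
    except ValueError as error:
        print(error)  # "An OpenML platform_resource_identifier should be a positive integer."
```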
@@ -124,7 +124,7 @@ def test_incorrect_citation():
citation.name == "ArCOV-19: The First Arabic COVID-19 Twitter Dataset with Propagation "
"Networks"
)
assert citation.platform_resource_identifier == "haouari2020arcov19"
assert citation.platform_resource_identifier is None
assert citation.type == "article"
assert (
citation.description.plain == "By Fatima Haouari and Maram Hasanain and Reem Suwaileh "
34 changes: 34 additions & 0 deletions src/tests/routers/resource_routers/test_router_dataset.py
@@ -1,6 +1,7 @@
import copy
from unittest.mock import Mock

from starlette import status
from starlette.testclient import TestClient

from authentication import keycloak_openid
@@ -50,3 +51,36 @@ def test_happy_path(
"geo": {"latitude": 37.42242, "longitude": -122.08585, "elevation_millimeters": 2000},
}
# TODO: test delete


def test_post_invalid_huggingface_identifier(
client: TestClient,
mocked_privileged_token: Mock,
):
keycloak_openid.userinfo = mocked_privileged_token

body = {"name": "name", "platform": "huggingface", "platform_resource_identifier": "a"}

response = client.post("/datasets/v1", json=body, headers={"Authorization": "Fake token"})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
assert (
response.json()["detail"][0]["msg"]
== "The platform_resource_identifier for HuggingFace should be a valid repo_id. A repo_id "
"should be between 1 and 96 characters."
)


def test_post_invalid_openml_identifier(
client: TestClient,
mocked_privileged_token: Mock,
):
keycloak_openid.userinfo = mocked_privileged_token

body = {"name": "name", "platform": "openml", "platform_resource_identifier": "a"}

response = client.post("/datasets/v1", json=body, headers={"Authorization": "Fake token"})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
assert (
response.json()["detail"][0]["msg"]
== "An OpenML platform_resource_identifier should be a positive integer."
)
132 changes: 102 additions & 30 deletions src/tests/uploader/huggingface/test_dataset_uploader.py
@@ -1,29 +1,31 @@
import copy
from unittest.mock import Mock

import huggingface_hub
import pytest
import responses
from starlette import status
from starlette.testclient import TestClient

from authentication import keycloak_openid
from database.model.ai_asset.ai_asset_table import AIAssetTable
from database.model.dataset.dataset import Dataset
from database.session import DbSession
from tests.testutils.paths import path_test_resources
from uploader.hugging_face_uploader import _throw_error_on_invalid_repo_id


def test_happy_path_new_repository(
client: TestClient, mocked_privileged_token: Mock, dataset: Dataset
):
dataset = copy.deepcopy(dataset)
dataset.platform = "huggingface"
dataset.platform_resource_identifier = "Fake-username/test"

keycloak_openid.userinfo = mocked_privileged_token
with DbSession() as session:
session.add(dataset)
session.commit()

data = {
"token": "huggingface_token",
"username": "username",
}

with open(path_test_resources() / "uploaders" / "huggingface" / "example.csv", "rb") as f:
files = {"file": f.read()}

@@ -37,39 +39,27 @@ def test_happy_path_new_repository(
huggingface_hub.upload_file = Mock(return_value=None)
response = client.post(
"/upload/datasets/1/huggingface",
data=data,
params={"username": "Fake-username", "token": "Fake-token"},
headers={"Authorization": "Fake token"},
files=files,
)

assert response.status_code == 200, response.json()
id_response = response.json()
assert id_response == 1


def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock):
def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock, dataset: Dataset):
keycloak_openid.userinfo = mocked_privileged_token
dataset_id = 1

dataset = copy.deepcopy(dataset)
dataset.platform = "huggingface"
dataset.platform_resource_identifier = "Fake-username/test"

with DbSession() as session:
session.add_all(
[
AIAssetTable(type="dataset"),
Dataset(
identifier=dataset_id,
name="Parent",
platform="example",
platform_resource_identifier="1",
same_as="",
),
]
)
session.add(dataset)
session.commit()

data = {
"token": "huggingface_token",
"username": "username",
}

with open(path_test_resources() / "uploaders" / "huggingface" / "example.csv", "rb") as f:
files = {"file": f.read()}

@@ -85,15 +75,97 @@ def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock):
)
huggingface_hub.upload_file = Mock(return_value=None)
response = client.post(
f"/upload/datasets/{dataset_id}/huggingface",
data=data,
"/upload/datasets/1/huggingface",
params={"username": "Fake-username", "token": "Fake-token"},
headers={"Authorization": "Fake token"},
files=files,
)
assert response.status_code == 200, response.json()
id_response = response.json()
assert id_response == dataset_id
assert id_response == 1
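As the updated tests show, the upload endpoint now receives username and token as query parameters instead of form fields. A minimal client-side sketch (hypothetical base URL and credentials; the dataset with identifier 1 must already exist with platform "huggingface" and a <username>/<repo_name> identifier):

```python
import requests

with open("example.csv", "rb") as f:
    response = requests.post(
        "http://localhost:8000/upload/datasets/1/huggingface",
        params={"username": "Fake-username", "token": "Fake-token"},
        headers={"Authorization": "Fake token"},
        files={"file": f},
    )
print(response.status_code, response.json())
```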


def test_wrong_platform(client: TestClient, mocked_privileged_token: Mock, dataset: Dataset):
keycloak_openid.userinfo = mocked_privileged_token

dataset = copy.deepcopy(dataset)
dataset.platform = "example"
dataset.platform_resource_identifier = "Fake-username/test"

with DbSession() as session:
session.add(dataset)
session.commit()

with open(path_test_resources() / "uploaders" / "huggingface" / "example.csv", "rb") as f:
files = {"file": f.read()}

response = client.post(
"/upload/datasets/1/huggingface",
params={"username": "Fake-username", "token": "Fake-token"},
headers={"Authorization": "Fake token"},
files=files,
)
assert response.status_code == status.HTTP_400_BAD_REQUEST, response.json()
assert (
response.json()["detail"]
== "The dataset with identifier 1 should have platform=PlatformName.huggingface."
)


# TODO: test some error handling?
@pytest.mark.parametrize(
"username,dataset_name,expected_error",
[
("0-hero", "0-hero/OIG-small-chip2", None),
("user", "user/Foo-BAR_foo.bar123", None),
(
"user",
"user/Test name with ?",
ValueError(
"The platform_resource_identifier for HuggingFace should be a valid repo_id. "
"A repo_id should only contain [a-zA-Z0-9] or ”-”, ”_”, ”.”"
),
),
(
"username",
"acronym_identification",
ValueError(
"The username should be part of the platform_resource_identifier for HuggingFace: "
"username/acronym_identification. Please update the dataset "
"platform_resource_identifier."
),
),
(
"user",
"user/data/set",
ValueError(
"The platform_resource_identifier for HuggingFace should be a valid repo_id. "
"For new repositories, there should be a single forward slash in the repo_id "
"(namespace/repo_name). Legacy repositories are without a namespace. This "
"repo_id has too many forward slashes."
),
),
(
"user",
"wrong-namespace/name",
ValueError(
"The namespace (the first part of the platform_resource_identifier) should be "
"equal to the username, but wrong-namespace != user."
),
),
(
"user",
"user/" + "a" * 200,
ValueError(
"The platform_resource_identifier for HuggingFace should be a valid repo_id. "
"A repo_id should be between 1 and 96 characters."
),
),
],
)
def test_repo_id(username: str, dataset_name: str, expected_error: ValueError | None):
if expected_error is None:
_throw_error_on_invalid_repo_id(username, dataset_name)
else:
with pytest.raises(type(expected_error)) as exception_info:
_throw_error_on_invalid_repo_id(username, dataset_name)
assert exception_info.value.args[0] == expected_error.args[0]
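These cases exercise _throw_error_on_invalid_repo_id from uploader.hugging_face_uploader, which is not part of this diff. Inferred purely from the error messages above (a hypothetical reconstruction, not the PR's actual implementation), the helper could combine the generic repo_id validator with a username check roughly like this:

```python
from database.validators import huggingface_validators


def _throw_error_on_invalid_repo_id(username: str, repo_id: str) -> None:
    # Hypothetical sketch; the real helper lives in src/uploader/hugging_face_uploader.py.
    huggingface_validators.throw_error_on_invalid_identifier(repo_id)
    if "/" not in repo_id:
        raise ValueError(
            "The username should be part of the platform_resource_identifier for HuggingFace: "
            f"{username}/{repo_id}. Please update the dataset platform_resource_identifier."
        )
    namespace = repo_id.split("/")[0]
    if namespace != username:
        raise ValueError(
            "The namespace (the first part of the platform_resource_identifier) should be "
            f"equal to the username, but {namespace} != {username}."
        )
```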