Skip to content

Commit

Permalink
Merge pull request #216 from aiondemand/bugfix/huggingface-dataset-up…
Browse files Browse the repository at this point in the history
…loader

Bugfix: Huggingface dataset uploader
  • Loading branch information
josvandervelde authored Dec 1, 2023
2 parents 73d475c + 3f0e878 commit e4d5285
Show file tree
Hide file tree
Showing 12 changed files with 316 additions and 47 deletions.
6 changes: 4 additions & 2 deletions src/connectors/huggingface/huggingface_dataset_connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import typing

import bibtexparser
import requests
from huggingface_hub import list_datasets
Expand Down Expand Up @@ -154,8 +155,9 @@ def _parse_citations(self, dataset, pydantic_class_publication) -> list:
]
return [
pydantic_class_publication(
platform=self.platform_name,
platform_resource_identifier=citation["ID"],
# The platform and platform_resource_identifier should be None: this publication
# is not stored on HuggingFace (and not identifiable within HF using,
# for instance, citation["ID"])
name=citation["title"],
same_as=citation["link"] if "link" in citation else None,
type=citation["ENTRYTYPE"],
Expand Down
27 changes: 26 additions & 1 deletion src/database/model/concept/concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from typing import Optional, Tuple

from pydantic import validator
from sqlalchemy import CheckConstraint, Index
from sqlalchemy.orm import declared_attr
from sqlalchemy.sql.functions import coalesce
Expand All @@ -13,6 +14,7 @@
from database.model.platform.platform_names import PlatformName
from database.model.relationships import OneToOne
from database.model.serializers import CastDeserializer
from database.validators import huggingface_validators, openml_validators

IS_SQLITE = os.getenv("DB") == "SQLite"
CONSTRAINT_LOWERCASE = f"{'platform' if IS_SQLITE else 'BINARY(platform)'} = LOWER(platform)"
Expand All @@ -32,11 +34,34 @@ class AIoDConceptBase(SQLModel):
platform_resource_identifier: str | None = Field(
max_length=NORMAL,
description="A unique identifier issued by the external platform that's specified in "
"'platform'. Leave empty if this item is not part of an external platform.",
"'platform'. Leave empty if this item is not part of an external platform. For example, "
"for HuggingFace, this should be the <namespace>/<dataset_name>, and for Openml, the "
"OpenML identifier.",
default=None,
schema_extra={"example": "1"},
)

@validator("platform_resource_identifier")
def platform_resource_identifier_valid(cls, platform_resource_identifier: str, values) -> str:
    """
    Validate the platform_resource_identifier against the rules of its platform.

    The already-validated `platform` value is read from `values`; this only works because
    the platform field is declared before platform_resource_identifier (pydantic validates
    fields in declaration order, see
    https://docs.pydantic.dev/1.10/usage/models/#field-ordering).

    Only HuggingFace and OpenML have platform-specific rules here; for any other platform,
    or when no platform is set, the identifier is returned without further checks.

    Raises a ValueError (via the platform-specific validator) if the identifier is invalid.
    """
    platform = values.get("platform")
    if not platform:
        # No (valid) platform available: nothing platform-specific to check.
        return platform_resource_identifier

    if platform == PlatformName.huggingface:
        huggingface_validators.throw_error_on_invalid_identifier(platform_resource_identifier)
    elif platform == PlatformName.openml:
        openml_validators.throw_error_on_invalid_identifier(platform_resource_identifier)
    return platform_resource_identifier


class AIoDConcept(AIoDConceptBase):
identifier: int = Field(default=None, primary_key=True)
Expand Down
Empty file.
36 changes: 36 additions & 0 deletions src/database/validators/huggingface_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import re

# Anything that is not alphanumeric, "-", "_", "." or the namespace separator "/".
# The dash is placed last in the character class so that it is unambiguously a literal.
REPO_ID_ILLEGAL_CHARACTERS = re.compile(r"[^0-9a-zA-Z_./-]+")
MSG_PREFIX = "The platform_resource_identifier for HuggingFace should be a valid repo_id. "

# Inclusive bounds on the length of a repo_id.
REPO_ID_MIN_LENGTH = 1
REPO_ID_MAX_LENGTH = 96


def throw_error_on_invalid_identifier(platform_resource_identifier: str):
    """
    Throw a ValueError on an invalid HuggingFace repository identifier.

    A valid repo_id:
        - is between 1 and 96 characters long (both bounds inclusive),
        - is either "repo_name" or "namespace/repo_name" (at most one forward slash),
        - only contains [a-zA-Z0-9], "-", "_" or "." (besides the optional slash),
        - does not contain "--" or "..".

    Refer to:
    https://huggingface.co/docs/huggingface_hub/package_reference/utilities#huggingface_hub.utils.validate_repo_id

    Args:
        platform_resource_identifier: the repo_id to check.

    Raises:
        ValueError: if the repo_id breaks one of the rules above. The message starts with
            MSG_PREFIX, followed by a description of the first rule that was broken.
    """
    repo_id = platform_resource_identifier
    if REPO_ID_ILLEGAL_CHARACTERS.search(repo_id):
        msg = "A repo_id should only contain [a-zA-Z0-9] or ”-”, ”_”, ”.”"
        raise ValueError(MSG_PREFIX + msg)
    # Inclusive on both sides: a single-character id and a 96-character id are valid.
    if not (REPO_ID_MIN_LENGTH <= len(repo_id) <= REPO_ID_MAX_LENGTH):
        msg = "A repo_id should be between 1 and 96 characters."
        raise ValueError(MSG_PREFIX + msg)
    if repo_id.count("/") > 1:
        msg = (
            "For new repositories, there should be a single forward slash in the repo_id ("
            "namespace/repo_name). Legacy repositories are without a namespace. This repo_id has "
            "too many forward slashes."
        )
        raise ValueError(MSG_PREFIX + msg)
    if ".." in repo_id:
        msg = "A repo_id may not contain multiple consecutive dots."
        raise ValueError(MSG_PREFIX + msg)
    if "--" in repo_id:
        msg = "A repo_id may not contain multiple consecutive dashes."
        raise ValueError(MSG_PREFIX + msg)
11 changes: 11 additions & 0 deletions src/database/validators/openml_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
MSG = "An OpenML platform_resource_identifier should be a positive integer."


def throw_error_on_invalid_identifier(platform_resource_identifier: str):
    """
    Throw a ValueError on an invalid OpenML identifier.

    A valid identifier consists solely of ASCII digits and denotes an integer greater than
    zero. Forms that a bare `int()` would silently accept (surrounding whitespace, a sign,
    underscores, non-ASCII digits) are rejected, because the identifier is stored as given.

    Args:
        platform_resource_identifier: the identifier to check.

    Raises:
        ValueError: with message MSG, if the identifier is not a positive integer.
    """
    text = str(platform_resource_identifier)
    # isdigit() alone also accepts e.g. superscript or other Unicode digits; restrict to ASCII.
    if not (text.isascii() and text.isdigit()):
        raise ValueError(MSG)
    # Only digits remain, so the value cannot be negative; zero is the only non-positive case.
    if int(text) == 0:
        raise ValueError(MSG)
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_incorrect_citation():
citation.name == "ArCOV-19: The First Arabic COVID-19 Twitter Dataset with Propagation "
"Networks"
)
assert citation.platform_resource_identifier == "haouari2020arcov19"
assert citation.platform_resource_identifier is None
assert citation.type == "article"
assert (
citation.description.plain == "By Fatima Haouari and Maram Hasanain and Reem Suwaileh "
Expand Down
34 changes: 34 additions & 0 deletions src/tests/routers/resource_routers/test_router_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
from unittest.mock import Mock

from starlette import status
from starlette.testclient import TestClient

from authentication import keycloak_openid
Expand Down Expand Up @@ -50,3 +51,36 @@ def test_happy_path(
"geo": {"latitude": 37.42242, "longitude": -122.08585, "elevation_millimeters": 2000},
}
# TODO: test delete


def test_post_invalid_huggingface_identifier(
    client: TestClient,
    mocked_privileged_token: Mock,
):
    """
    Posting a HuggingFace dataset with an out-of-bounds identifier must result in a 422.

    The empty string is used as the invalid value: it is shorter than the documented minimum
    of one character, so it is an unambiguous violation of the "between 1 and 96 characters"
    rule (a one-character repo_id is within the documented bounds).
    """
    keycloak_openid.userinfo = mocked_privileged_token

    body = {"name": "name", "platform": "huggingface", "platform_resource_identifier": ""}

    response = client.post("/datasets/v1", json=body, headers={"Authorization": "Fake token"})
    assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
    assert (
        response.json()["detail"][0]["msg"]
        == "The platform_resource_identifier for HuggingFace should be a valid repo_id. A repo_id "
        "should be between 1 and 96 characters."
    )


def test_post_invalid_openml_identifier(
    client: TestClient,
    mocked_privileged_token: Mock,
):
    """Posting an OpenML dataset with a non-numeric identifier must result in a 422."""
    keycloak_openid.userinfo = mocked_privileged_token

    expected_message = "An OpenML platform_resource_identifier should be a positive integer."
    payload = {"name": "name", "platform": "openml", "platform_resource_identifier": "a"}
    auth_headers = {"Authorization": "Fake token"}

    response = client.post("/datasets/v1", json=payload, headers=auth_headers)

    assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
    # Only the first reported validation error is checked.
    first_error = response.json()["detail"][0]
    assert first_error["msg"] == expected_message
Empty file.
132 changes: 102 additions & 30 deletions src/tests/uploader/huggingface/test_dataset_uploader.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import copy
from unittest.mock import Mock

import huggingface_hub
import pytest
import responses
from starlette import status
from starlette.testclient import TestClient

from authentication import keycloak_openid
from database.model.ai_asset.ai_asset_table import AIAssetTable
from database.model.dataset.dataset import Dataset
from database.session import DbSession
from tests.testutils.paths import path_test_resources
from uploader.hugging_face_uploader import _throw_error_on_invalid_repo_id


def test_happy_path_new_repository(
client: TestClient, mocked_privileged_token: Mock, dataset: Dataset
):
dataset = copy.deepcopy(dataset)
dataset.platform = "huggingface"
dataset.platform_resource_identifier = "Fake-username/test"

keycloak_openid.userinfo = mocked_privileged_token
with DbSession() as session:
session.add(dataset)
session.commit()

data = {
"token": "huggingface_token",
"username": "username",
}

with open(path_test_resources() / "uploaders" / "huggingface" / "example.csv", "rb") as f:
files = {"file": f.read()}

Expand All @@ -37,39 +39,27 @@ def test_happy_path_new_repository(
huggingface_hub.upload_file = Mock(return_value=None)
response = client.post(
"/upload/datasets/1/huggingface",
data=data,
params={"username": "Fake-username", "token": "Fake-token"},
headers={"Authorization": "Fake token"},
files=files,
)

assert response.status_code == 200, response.json()
id_response = response.json()
assert id_response == 1


def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock):
def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock, dataset: Dataset):
keycloak_openid.userinfo = mocked_privileged_token
dataset_id = 1

dataset = copy.deepcopy(dataset)
dataset.platform = "huggingface"
dataset.platform_resource_identifier = "Fake-username/test"

with DbSession() as session:
session.add_all(
[
AIAssetTable(type="dataset"),
Dataset(
identifier=dataset_id,
name="Parent",
platform="example",
platform_resource_identifier="1",
same_as="",
),
]
)
session.add(dataset)
session.commit()

data = {
"token": "huggingface_token",
"username": "username",
}

with open(path_test_resources() / "uploaders" / "huggingface" / "example.csv", "rb") as f:
files = {"file": f.read()}

Expand All @@ -85,15 +75,97 @@ def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock):
)
huggingface_hub.upload_file = Mock(return_value=None)
response = client.post(
f"/upload/datasets/{dataset_id}/huggingface",
data=data,
"/upload/datasets/1/huggingface",
params={"username": "Fake-username", "token": "Fake-token"},
headers={"Authorization": "Fake token"},
files=files,
)
assert response.status_code == 200, response.json()
id_response = response.json()
assert id_response == dataset_id
assert id_response == 1


def test_wrong_platform(client: TestClient, mocked_privileged_token: Mock, dataset: Dataset):
    """Uploading to HuggingFace for a dataset whose platform is "example" must give a 400."""
    keycloak_openid.userinfo = mocked_privileged_token

    # Work on a copy, so that the shared fixture object itself is left untouched.
    non_huggingface_dataset = copy.deepcopy(dataset)
    non_huggingface_dataset.platform = "example"
    non_huggingface_dataset.platform_resource_identifier = "Fake-username/test"
    with DbSession() as session:
        session.add(non_huggingface_dataset)
        session.commit()

    csv_path = path_test_resources() / "uploaders" / "huggingface" / "example.csv"
    with open(csv_path, "rb") as csv_file:
        files = {"file": csv_file.read()}

    expected_detail = "The dataset with identifier 1 should have platform=PlatformName.huggingface."
    response = client.post(
        "/upload/datasets/1/huggingface",
        params={"username": "Fake-username", "token": "Fake-token"},
        headers={"Authorization": "Fake token"},
        files=files,
    )

    assert response.status_code == status.HTTP_400_BAD_REQUEST, response.json()
    assert response.json()["detail"] == expected_detail


# TODO: tests some error handling?
@pytest.mark.parametrize(
    "username,dataset_name,expected_error",
    [
        # Valid: namespace equals the username, only legal characters.
        ("0-hero", "0-hero/OIG-small-chip2", None),
        ("user", "user/Foo-BAR_foo.bar123", None),
        # Illegal characters (space and question mark).
        (
            "user",
            "user/Test name with ?",
            ValueError(
                "The platform_resource_identifier for HuggingFace should be a valid repo_id. "
                "A repo_id should only contain [a-zA-Z0-9] or ”-”, ”_”, ”.”"
            ),
        ),
        # No namespace in the identifier.
        (
            "username",
            "acronym_identification",
            ValueError(
                "The username should be part of the platform_resource_identifier for HuggingFace: "
                "username/acronym_identification. Please update the dataset "
                "platform_resource_identifier."
            ),
        ),
        # More than one forward slash.
        (
            "user",
            "user/data/set",
            ValueError(
                "The platform_resource_identifier for HuggingFace should be a valid repo_id. "
                "For new repositories, there should be a single forward slash in the repo_id "
                "(namespace/repo_name). Legacy repositories are without a namespace. This "
                "repo_id has too many forward slashes."
            ),
        ),
        # Namespace differs from the username.
        (
            "user",
            "wrong-namespace/name",
            ValueError(
                "The namespace (the first part of the platform_resource_identifier) should be "
                "equal to the username, but wrong-namespace != user."
            ),
        ),
        # Far too long.
        (
            "user",
            "user/" + "a" * 200,
            ValueError(
                "The platform_resource_identifier for HuggingFace should be a valid repo_id. "
                "A repo_id should be between 1 and 96 characters."
            ),
        ),
    ],
)
def test_repo_id(username: str, dataset_name: str, expected_error: ValueError | None):
    """
    Check _throw_error_on_invalid_repo_id for each username / identifier combination.

    When expected_error is None the call must not raise; otherwise it must raise an exception
    of the same type whose first argument equals the expected message.
    """
    if expected_error is not None:
        with pytest.raises(type(expected_error)) as exception_info:
            _throw_error_on_invalid_repo_id(username, dataset_name)
        raised_message = exception_info.value.args[0]
        assert raised_message == expected_error.args[0]
        return

    # Happy path: the call itself is the assertion (it must simply not raise).
    _throw_error_on_invalid_repo_id(username, dataset_name)
Empty file.
Loading

0 comments on commit e4d5285

Please sign in to comment.