Merge branch 'develop' into bump-pytest
PGijsbers committed Nov 20, 2024
2 parents c3b1521 + 24407ee commit 4fc36af
Showing 7 changed files with 110 additions and 40 deletions.
66 changes: 66 additions & 0 deletions scripts/migrate_hf.py
@@ -0,0 +1,66 @@
"""
Updates the metadata of Hugging Face entries to use `_id` instead of `id` as platform identifier.
The `id` field (i.e., `username/datasetname`, e.g., `pgijsbers/titanic`) is subject to change when
a user changes their username or the dataset name. The `_id` field is persistent across these changes,
so can be used to avoid indexing the same dataset twice under a different platform identifier.
To be run once (around sometime Nov 2024), likely not needed after that. See also #385, 392.
"""
import logging
import string
from http import HTTPStatus

from sqlalchemy import select
from database.session import DbSession, EngineSingleton
from database.model.dataset.dataset import Dataset
from database.model.platform.platform import Platform
from database.model.platform.platform_names import PlatformName
from database.model.concept.concept import AIoDConcept

# Magic import which triggers ORM setup
import database.setup

import requests


def main():
AIoDConcept.metadata.create_all(EngineSingleton().engine, checkfirst=True)
with DbSession() as session:
datasets_query = select(Dataset).where(Dataset.platform == PlatformName.huggingface)
datasets = session.scalars(datasets_query).all()

for dataset in datasets:
if all(c in string.hexdigits for c in dataset.platform_resource_identifier):
continue # entry already updated to use new-style id

response = requests.get(
f"https://huggingface.co/api/datasets/{dataset.name}",
params={"full": "False"},
headers={},
timeout=10,
)
if response.status_code != HTTPStatus.OK:
logging.warning(f"Dataset {dataset.name} could not be retrieved.")
continue

dataset_json = response.json()
if dataset.platform_resource_identifier != dataset_json["id"]:
                logging.info(
                    f"Dataset {dataset.platform_resource_identifier} moved to {dataset_json['id']}. "
                    "Deleting the old entry. The new entry either already exists or "
                    "will be added on a later synchronization invocation."
                )
session.delete(dataset)
continue

persistent_id = dataset_json["_id"]
logging.info(
f"Setting platform id of {dataset.platform_resource_identifier} to {persistent_id}"
)
dataset.platform_resource_identifier = persistent_id
session.commit()


if __name__ == "__main__":
main()
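For reference, the two identifier styles the script migrates between can be inspected directly on the Hugging Face API. A minimal sketch, assuming network access to huggingface.co; it mirrors the request the script makes, and the id/_id pairing is taken from the updated test expectations later in this diff:

import requests

# Fetch one dataset's metadata the same way migrate_hf.py does.
response = requests.get(
    "https://huggingface.co/api/datasets/acronym_identification",
    params={"full": "False"},
    timeout=10,
)
info = response.json()
print(info["id"])   # "acronym_identification": mutable, changes on rename or username change
print(info["_id"])  # "621ffdd236468d709f181d58": persistent hex identifier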
3 changes: 2 additions & 1 deletion src/connectors/huggingface/huggingface_dataset_connector.py
Expand Up @@ -59,6 +59,7 @@ def fetch(
dataset, pydantic_class, pydantic_class_publication, pydantic_class_contact
)
except Exception as e:
+            # We use the plain id here since it is more informative and can be used to visit the dataset's page on Hugging Face
yield RecordError(identifier=dataset.id, error=e)

def fetch_dataset(
Expand Down Expand Up @@ -119,7 +120,7 @@ def fetch_dataset(
return ResourceWithRelations[pydantic_class]( # type:ignore
resource=pydantic_class(
aiod_entry=AIoDEntryCreate(status="published"),
-                platform_resource_identifier=dataset.id,
+                platform_resource_identifier=dataset._id,  # see #385, #392
platform=self.platform_name,
name=dataset.id,
same_as=f"https://huggingface.co/datasets/{dataset.id}",
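The migration script above tells the two styles apart with a simple hex check: old-style identifiers such as `pgijsbers/titanic` contain non-hex characters, while persistent `_id` values (see the updated test expectations below) are hex-only. A standalone sketch of that predicate; the helper name is hypothetical:

import string

def is_persistent_hf_id(identifier: str) -> bool:
    # Same check as migrate_hf.py: new-style ids consist solely of hex digits.
    return all(c in string.hexdigits for c in identifier)

assert is_persistent_hf_id("621ffdd236468d709f181d58")    # new-style, persistent
assert not is_persistent_hf_id("acronym_identification")  # old-style dataset name
assert not is_persistent_hf_id("pgijsbers/titanic")       # old-style user/name id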
25 changes: 13 additions & 12 deletions src/database/deletion/triggers.py
Expand Up @@ -7,17 +7,21 @@

from typing import Type

-from sqlalchemy import DDL, event
+from sqlalchemy import DDL
from sqlmodel import SQLModel

from database.model.helper_functions import get_relationships, non_abstract_subclasses


-def add_delete_triggers(parent_class: Type[SQLModel]):
+def create_delete_triggers(parent_class: Type[SQLModel]):
    classes: list[Type[SQLModel]] = non_abstract_subclasses(parent_class)
+    triggers = []
    for cls in classes:
        for name, value in get_relationships(cls).items():
-            value.create_triggers(cls, name)
+            trigger = value.create_triggers(cls, name)
+            if trigger is not None:
+                triggers.append(trigger)
+    return triggers


def create_deletion_trigger_one_to_one(
Expand Down Expand Up @@ -49,9 +53,9 @@ def create_deletion_trigger_one_to_one(
trigger_name = trigger.__tablename__
delete_name = to_delete.__tablename__

-    ddl = DDL(
+    return DDL(
        f"""
-        CREATE TRIGGER delete_{trigger_name}_{trigger_identifier_link}_{delete_name}
+        CREATE TRIGGER IF NOT EXISTS delete_{trigger_name}_{trigger_identifier_link}_{delete_name}
AFTER DELETE ON {trigger_name}
FOR EACH ROW
BEGIN
Expand All @@ -60,7 +64,6 @@ def create_deletion_trigger_one_to_one(
END;
""" # noqa: S608 # never user input
)
-    event.listen(trigger.metadata, "after_create", ddl)


def create_deletion_trigger_many_to_one(
Expand All @@ -84,9 +87,9 @@ def create_deletion_trigger_many_to_one(
trigger_name = trigger.__tablename__
delete_name = to_delete.__tablename__

-    ddl = DDL(
+    return DDL(
        f"""
-        CREATE TRIGGER delete_{trigger_name}_{delete_name}
+        CREATE TRIGGER IF NOT EXISTS delete_{trigger_name}_{delete_name}
AFTER DELETE ON {trigger_name}
FOR EACH ROW
BEGIN
Expand All @@ -99,7 +102,6 @@ def create_deletion_trigger_many_to_one(
END;
""" # noqa: S608 # never user input
)
-    event.listen(trigger.metadata, "after_create", ddl)


def create_deletion_trigger_many_to_many(
Expand Down Expand Up @@ -142,9 +144,9 @@ def create_deletion_trigger_many_to_many(
""" # noqa: S608 # never user input
for link_name in link_names
)
-    ddl = DDL(
+    return DDL(
        f"""
-        CREATE TRIGGER delete_{link_name}
+        CREATE TRIGGER IF NOT EXISTS delete_{link_name}
AFTER DELETE ON {trigger_name}
FOR EACH ROW
BEGIN
Expand All @@ -155,4 +157,3 @@ def create_deletion_trigger_many_to_many(
END;
""" # noqa: S608 # never user input
)
-    event.listen(trigger.metadata, "after_create", ddl)
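To make the new return-and-execute pattern concrete, here is a self-contained sketch with hypothetical `parent` and `child` tables (not from this repository): a DDL object is built once and executed explicitly, the way callers of `create_delete_triggers` now do, and `IF NOT EXISTS` turns a second execution into a no-op instead of an error.

from sqlalchemy import DDL, create_engine, text

engine = create_engine("sqlite://")  # throwaway in-memory database
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE parent (identifier INTEGER PRIMARY KEY, child_identifier INTEGER)"))
    conn.execute(text("CREATE TABLE child (identifier INTEGER PRIMARY KEY)"))
    ddl = DDL(
        """
        CREATE TRIGGER IF NOT EXISTS delete_parent_child
        AFTER DELETE ON parent
        FOR EACH ROW
        BEGIN
            DELETE FROM child WHERE identifier = OLD.child_identifier;
        END;
        """
    )
    conn.execute(ddl)
    conn.execute(ddl)  # idempotent: IF NOT EXISTS skips the already-existing trigger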
6 changes: 3 additions & 3 deletions src/database/model/relationships.py
Expand Up @@ -146,7 +146,7 @@ def create_triggers(self, parent_class: Type[SQLModel], field_name: str):
"The deletion trigger is configured wrongly: the field doesn't "
f"point to a SQLModel class: {parent_class} . {field_name}"
)
-        triggers.create_deletion_trigger_one_to_one(
+        return triggers.create_deletion_trigger_one_to_one(
trigger=parent_class,
trigger_identifier_link=self.on_delete_trigger_deletion_by,
to_delete=to_delete,
Expand All @@ -170,7 +170,7 @@ def create_triggers(self, parent_class: Type[SQLModel], field_name: str):
to_delete_identifier = getattr(
parent_class.RelationshipConfig, field_name
).identifier_name
-        triggers.create_deletion_trigger_many_to_one(
+        return triggers.create_deletion_trigger_many_to_one(
trigger=parent_class,
to_delete=self.on_delete_trigger_deletion_of_orphan,
trigger_identifier_link=to_delete_identifier,
Expand Down Expand Up @@ -216,6 +216,6 @@ def create_triggers(self, parent_class: Type[SQLModel], field_name: str):
)

other_links = self.on_delete_trigger_orphan_deletion()
-        triggers.create_deletion_trigger_many_to_many(
+        return triggers.create_deletion_trigger_many_to_many(
trigger=parent_class, link=link, to_delete=to_delete, other_links=other_links
)
10 changes: 4 additions & 6 deletions src/main.py
Expand Up @@ -16,7 +16,7 @@

from authentication import get_user_or_raise, User
from config import KEYCLOAK_CONFIG
-from database.deletion.triggers import add_delete_triggers
+from database.deletion.triggers import create_delete_triggers
from database.model.concept.concept import AIoDConcept
from database.model.platform.platform import Platform
from database.model.platform.platform_names import PlatformName
Expand Down Expand Up @@ -134,16 +134,14 @@ def create_app() -> FastAPI:
create_database(delete_first=drop_database)
AIoDConcept.metadata.create_all(EngineSingleton().engine, checkfirst=True)
with DbSession() as session:
+        triggers = create_delete_triggers(AIoDConcept)
+        for trigger in triggers:
+            session.execute(trigger)
existing_platforms = session.scalars(select(Platform)).all()
if not any(existing_platforms):
session.add_all([Platform(name=name) for name in PlatformName])
session.commit()

-    # this is a bit of a hack: instead of checking whether the triggers exist, we check
-    # whether platforms are already present. If platforms were not present, the db is
-    # empty, and so the triggers should still be added.
-    add_delete_triggers(AIoDConcept)

add_routes(app, url_prefix=args.url_prefix)
return app

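The removed comment explains why the old approach needed a heuristic at all: with `event.listen(..., "after_create", ddl)`, the DDL only fired when `create_all` actually created tables, so an existing database never received newly defined triggers. A minimal sketch of that pitfall, using a hypothetical table and an in-memory SQLite database:

from sqlalchemy import DDL, MetaData, Table, Column, Integer, create_engine, event

metadata = MetaData()
example = Table("example", metadata, Column("identifier", Integer, primary_key=True))
trigger_ddl = DDL(
    "CREATE TRIGGER IF NOT EXISTS example_noop "
    "AFTER DELETE ON example FOR EACH ROW BEGIN SELECT 1; END;"
)
event.listen(example, "after_create", trigger_ddl)

engine = create_engine("sqlite://")
metadata.create_all(engine)                   # table is created, so the DDL fires
metadata.create_all(engine, checkfirst=True)  # table already exists: after_create never fires

Executing the returned DDL on every startup, as create_app now does, removes that blind spot.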
Expand Up @@ -12,14 +12,22 @@


def test_fetch_all_happy_path():
-    ids_expected = {
+    names_expected = {
"0n1xus/codexglue",
"04-07-22/wep-probes",
"rotten_tomatoes",
"acronym_identification",
"air_dialogue",
"bobbydylan/top2k",
}
+    ids_expected = {
+        "621ffdd236468d709f18203a",
+        "62cd5fa83e5ba89c40f22b0d",
+        "621ffdd236468d709f181f5f",
+        "621ffdd236468d709f181d58",
+        "621ffdd236468d709f181d5f",
+        "621ffdd236468d709f182fdf",
+    }
connector = HuggingFaceDatasetConnector()
with responses.RequestsMock() as mocked_requests:
path_data_list = path_test_resources() / "connectors" / "huggingface" / "data_list.json"
Expand All @@ -31,24 +39,24 @@ def test_fetch_all_happy_path():
json=response,
status=200,
)
-        for dataset_id in ids_expected:
-            mock_parquet(mocked_requests, dataset_id)
+        for dataset_name in names_expected:
+            mock_parquet(mocked_requests, dataset_name)
resources_with_relations = list(connector.fetch())

assert len(resources_with_relations) == len(ids_expected)
assert all(type(r) == ResourceWithRelations for r in resources_with_relations)

datasets = [r.resource for r in resources_with_relations]
assert {d.platform_resource_identifier for d in datasets} == ids_expected
-    assert {d.name for d in datasets} == ids_expected
+    assert {d.name for d in datasets} == names_expected
assert all(d.date_published for d in datasets)
assert all(d.aiod_entry for d in datasets)

assert all(len(r.related_resources) in (1, 2) for r in resources_with_relations)
assert all(len(r.related_resources["citation"]) == 1 for r in resources_with_relations[:5])

dataset = datasets[0]
-    assert dataset.platform_resource_identifier == "acronym_identification"
+    assert dataset.platform_resource_identifier == "621ffdd236468d709f181d58"
assert dataset.platform == PlatformName.huggingface
assert dataset.description == Text(
plain="Acronym identification training and development "
Expand Down Expand Up @@ -132,15 +140,15 @@ def test_incorrect_citation():
)


-def mock_parquet(mocked_requests: responses.RequestsMock, dataset_id: str):
-    filename = f"parquet_{dataset_id.replace('/', '_')}.json"
+def mock_parquet(mocked_requests: responses.RequestsMock, dataset_name: str):
+    filename = f"parquet_{dataset_name.replace('/', '_')}.json"
path_split = path_test_resources() / "connectors" / "huggingface" / filename
with open(path_split, "r") as f:
response = json.load(f)
status = 200 if "error" not in response else 404
mocked_requests.add(
responses.GET,
f"{HUGGINGFACE_URL}/parquet?dataset={dataset_id}",
f"{HUGGINGFACE_URL}/parquet?dataset={dataset_name}",
json=response,
status=status,
)
16 changes: 6 additions & 10 deletions src/tests/testutils/default_sqlalchemy.py
@@ -1,6 +1,6 @@
import sqlite3
import tempfile
-from typing import Iterator, Type, Any
+from typing import Iterator, Any
from unittest.mock import Mock

import pytest
Expand All @@ -12,7 +12,7 @@
from starlette.testclient import TestClient

from authentication import keycloak_openid
-from database.deletion.triggers import add_delete_triggers
+from database.deletion.triggers import create_delete_triggers
from database.model.concept.concept import AIoDConcept
from database.model.platform.platform import Platform
from database.model.platform.platform_names import PlatformName
Expand All @@ -22,20 +22,16 @@


@pytest.fixture(scope="session")
-def deletion_triggers() -> Type[AIoDConcept]:
-    """Making sure that the deletion triggers are only created once"""
-    add_delete_triggers(AIoDConcept)
-    return AIoDConcept
-
-
-@pytest.fixture(scope="session")
-def engine(deletion_triggers) -> Iterator[Engine]:
+def engine() -> Iterator[Engine]:
"""
Create a SqlAlchemy engine for tests, backed by a temporary sqlite file.
"""
temporary_file = tempfile.NamedTemporaryFile()
engine = create_engine(f"sqlite:///{temporary_file.name}?check_same_thread=False")
AIoDConcept.metadata.create_all(engine)
+    with Session(engine) as session:
+        for trigger in create_delete_triggers(AIoDConcept):
+            session.execute(trigger)
EngineSingleton().patch(engine)

# Yielding is essential, the temporary file will be closed after the engine is used
