Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CVAT][Exchange Oracle] Use UUID type in postgres #2624

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Use UUID type for primar keys

Revision ID: 284adb30d75e
Revises: fde2b09b6b39
Create Date: 2024-10-09 17:45:27.692538

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "284adb30d75e"
down_revision = "fde2b09b6b39"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.execute("ALTER TABLE assignments ALTER COLUMN id TYPE UUID USING id::uuid")
op.execute("ALTER TABLE data_uploads ALTER COLUMN id TYPE UUID USING id::uuid")
op.execute("ALTER TABLE escrow_creations ALTER COLUMN id TYPE UUID USING id::uuid")
op.execute("ALTER TABLE escrow_validations ALTER COLUMN id TYPE UUID USING id::uuid")
op.execute("ALTER TABLE images ALTER COLUMN id TYPE UUID USING id::uuid")
op.execute("ALTER TABLE jobs ALTER COLUMN id TYPE UUID USING id::uuid")
op.execute("ALTER TABLE projects ALTER COLUMN id TYPE UUID USING id::uuid")
op.execute("ALTER TABLE tasks ALTER COLUMN id TYPE UUID USING id::uuid")
op.execute("ALTER TABLE webhooks ALTER COLUMN id TYPE UUID USING id::uuid")


def downgrade() -> None:
op.execute("ALTER TABLE webhooks ALTER COLUMN id TYPE VARCHAR USING id::text")
op.execute("ALTER TABLE tasks ALTER COLUMN id TYPE VARCHAR USING id::text")
op.execute("ALTER TABLE projects ALTER COLUMN id TYPE VARCHAR USING id::text")
op.execute("ALTER TABLE jobs ALTER COLUMN id TYPE VARCHAR USING id::text")
op.execute("ALTER TABLE images ALTER COLUMN id TYPE VARCHAR USING id::text")
op.execute("ALTER TABLE escrow_validations ALTER COLUMN id TYPE VARCHAR USING id::text")
op.execute("ALTER TABLE escrow_creations ALTER COLUMN id TYPE VARCHAR USING id::text")
op.execute("ALTER TABLE data_uploads ALTER COLUMN id TYPE VARCHAR USING id::text")
op.execute("ALTER TABLE assignments ALTER COLUMN id TYPE VARCHAR USING id::text")
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Add default for UUID primary keys

Revision ID: 4fc740e8c6ff
Revises: 284adb30d75e
Create Date: 2024-10-10 14:10:45.948593

"""

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision = "4fc740e8c6ff"
down_revision = "284adb30d75e"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.alter_column("assignments", "id", server_default=sa.text("uuid_generate_v4()"))
op.alter_column("data_uploads", "id", server_default=sa.text("uuid_generate_v4()"))
op.alter_column("escrow_creations", "id", server_default=sa.text("uuid_generate_v4()"))
op.alter_column("images", "id", server_default=sa.text("uuid_generate_v4()"))
op.alter_column("jobs", "id", server_default=sa.text("uuid_generate_v4()"))
op.alter_column("projects", "id", server_default=sa.text("uuid_generate_v4()"))
op.alter_column("tasks", "id", server_default=sa.text("uuid_generate_v4()"))
op.alter_column("webhooks", "id", server_default=sa.text("uuid_generate_v4()"))


def downgrade() -> None:
op.alter_column("webhooks", "id", server_default=None)
op.alter_column("tasks", "id", server_default=None)
op.alter_column("projects", "id", server_default=None)
op.alter_column("jobs", "id", server_default=None)
op.alter_column("images", "id", server_default=None)
op.alter_column("escrow_creations", "id", server_default=None)
op.alter_column("data_uploads", "id", server_default=None)
op.alter_column("assignments", "id", server_default=None)
22 changes: 20 additions & 2 deletions packages/examples/cvat/exchange-oracle/src/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar
from uuid import uuid4

import sqlalchemy
from psycopg2.errors import Error
from sqlalchemy import DDL, event
from sqlalchemy import DDL, UUID, event, func
from sqlalchemy.exc import SQLAlchemyError, StatementError
from sqlalchemy.orm import (
DeclarativeBase,
InstrumentedAttribute,
Mapped,
Relationship,
mapped_column,
sessionmaker,
)

Expand All @@ -29,10 +32,25 @@ class Base(DeclarativeBase):
__tablename__: ClassVar[str]


class BaseUUID(Base):
__abstract__ = True
id: Mapped[str] = mapped_column(
# Using `str` instead of python `uuid.UUID` for now
# to reduce amount of code needed to be rewritten.
# At some point it would make sense to use UUID(as_uuid=True).
UUID(as_uuid=False),
primary_key=True,
default=lambda: str(uuid4()),
server_default=func.uuid_generate_v4(),
sort_order=-1, # Make sure it's the first column.
index=True,
)


ParentT = TypeVar("ParentT", bound=type[Base])


class ChildOf(Base, Generic[ParentT]):
class ChildOf(BaseUUID, Generic[ParentT]):
__abstract__ = True

if TYPE_CHECKING:
Expand Down
25 changes: 8 additions & 17 deletions packages/examples/cvat/exchange-oracle/src/models/cvat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
TaskStatuses,
TaskTypes,
)
from src.db import Base, ChildOf
from src.db import Base, BaseUUID, ChildOf
from src.utils.time import utcnow


class Project(Base):
class Project(BaseUUID):
__tablename__ = "projects"
id = Column(String, primary_key=True, index=True)
cvat_id = Column(Integer, unique=True, index=True, nullable=False)
cvat_cloudstorage_id = Column(Integer, index=True, nullable=False)
status = Column(String, Enum(ProjectStatuses), nullable=False)
Expand Down Expand Up @@ -81,7 +80,6 @@ def __repr__(self) -> str:

class Task(ChildOf[Project]):
__tablename__ = "tasks"
id = Column(String, primary_key=True, index=True)
cvat_id = Column(Integer, unique=True, index=True, nullable=False)
cvat_project_id = Column(
Integer,
Expand All @@ -106,9 +104,8 @@ def __repr__(self) -> str:
return f"Task. id={self.id}"


class EscrowCreation(Base):
class EscrowCreation(BaseUUID):
__tablename__ = "escrow_creations"
id = Column(String, primary_key=True, index=True)

escrow_address = Column(String(42), index=True, nullable=False)
chain_id = Column(Integer, Enum(Networks), nullable=False)
Expand All @@ -135,12 +132,10 @@ def __repr__(self) -> str:
return f"EscrowCreation. id={self.id} escrow={self.escrow_address}"


class EscrowValidation(Base):
class EscrowValidation(BaseUUID):
__tablename__ = "escrow_validations"
__table_args__ = (UniqueConstraint("escrow_address", "chain_id", name="uix_escrow_chain"),)

id = Column(String, primary_key=True, index=True, server_default=func.uuid_generate_v4())

escrow_address = Column(String(42), index=True, nullable=False)
chain_id = Column(Integer, Enum(Networks), nullable=False)

Expand All @@ -160,9 +155,8 @@ class EscrowValidation(Base):
)


class DataUpload(Base):
class DataUpload(BaseUUID):
__tablename__ = "data_uploads"
id = Column(String, primary_key=True, index=True)
task_id = Column(
Integer,
ForeignKey("tasks.cvat_id", ondelete="CASCADE"),
Expand All @@ -179,7 +173,6 @@ def __repr__(self) -> str:

class Job(ChildOf[Task]):
__tablename__ = "jobs"
id = Column(String, primary_key=True, index=True)
cvat_id = Column(Integer, unique=True, index=True, nullable=False)
cvat_task_id = Column(Integer, ForeignKey("tasks.cvat_id", ondelete="CASCADE"), nullable=False)
cvat_project_id = Column(
Expand Down Expand Up @@ -209,7 +202,7 @@ def __repr__(self) -> str:
return f"Job. id={self.id}"


class User(Base):
class User(Base): # user does not have a UUID primary key
__tablename__ = "users"
wallet_address = Column(String, primary_key=True, index=True, nullable=False)
cvat_email = Column(String, unique=True, index=True, nullable=True)
Expand All @@ -223,9 +216,8 @@ def __repr__(self) -> str:
return f"User. wallet_address={self.wallet_address} cvat_id={self.cvat_id}"


class Assignment(Base):
class Assignment(BaseUUID):
__tablename__ = "assignments"
id = Column(String, primary_key=True, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
Expand Down Expand Up @@ -260,9 +252,8 @@ def __repr__(self) -> str:
return f"Assignment. id={self.id} user={self.user.cvat_id} job={self.job.cvat_id}"


class Image(Base):
class Image(BaseUUID):
__tablename__ = "images"
id = Column(String, primary_key=True, index=True)
cvat_project_id = Column(
Integer,
ForeignKey("projects.cvat_id", ondelete="CASCADE"),
Expand Down
8 changes: 4 additions & 4 deletions packages/examples/cvat/exchange-oracle/src/models/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from sqlalchemy.sql import func

from src.core.types import Networks, OracleWebhookStatuses, OracleWebhookTypes
from src.db import Base
from src.db import BaseUUID
from src.utils.time import utcnow


class Webhook(Base):
class Webhook(BaseUUID):
__tablename__ = "webhooks"
id = Column(String, primary_key=True, index=True)
signature = Column(String, unique=True, index=True, nullable=True)
escrow_address = Column(String(42), nullable=False)
chain_id = Column(Integer, Enum(Networks), nullable=False)
Expand All @@ -21,7 +21,7 @@ class Webhook(Base):
attempts = Column(Integer, server_default="0")
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
wait_until = Column(DateTime(timezone=True), server_default=func.now())
wait_until = Column(DateTime(timezone=True), server_default=func.now(), default=utcnow)
event_type = Column(String, nullable=False)
event_data = Column(JSON, nullable=True, server_default=None)
direction = Column(String, nullable=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ def test_can_list_assignments_200(client: TestClient, session: Session) -> None:
"chain_id": ((cvat_projects[0].chain_id, len(assignments)),),
"assignment_id": (
(assignments[0].id, 1),
("unknown", 0),
(uuid.uuid4(), 0),
),
"job_type": (
(cvat_projects[0].job_type, len(assignments)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,7 @@ def test_process_incoming_recording_oracle_webhooks_submission_rejected_type_inv
status=OracleWebhookStatuses.pending.value,
event_type=RecordingOracleEventTypes.submission_rejected.value,
event_data={
"assignments": [
{"assignment_id": "sample assignment id", "reason": "sample reason"}
]
"assignments": [{"assignment_id": str(uuid.uuid4()), "reason": "sample reason"}]
},
direction=OracleWebhookDirectionTags.incoming,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def test_get_project_by_id(self):
assert project.escrow_address == escrow_address
assert project.bucket_url == bucket_url

project = cvat_service.get_project_by_id(self.session, "dummy_id")
project = cvat_service.get_project_by_id(self.session, uuid.uuid4())

assert project is None

Expand Down Expand Up @@ -533,7 +533,7 @@ def test_delete_project_wrong_project_id(self):
projects = self.session.query(Project).all()
assert len(projects) == 1
with pytest.raises(UnmappedInstanceError):
cvat_service.delete_project(self.session, "project_id")
cvat_service.delete_project(self.session, uuid.uuid4())

def test_create_task(self):
cvat_id = 1
Expand Down Expand Up @@ -597,7 +597,7 @@ def test_get_task_by_id(self):
assert task.cvat_project_id == cvat_project.cvat_id
assert task.status == TaskStatuses.annotation.value

task = cvat_service.get_task_by_id(self.session, "dummy_id")
task = cvat_service.get_task_by_id(self.session, uuid.uuid4())

assert task is None

Expand Down Expand Up @@ -908,7 +908,7 @@ def test_get_job_by_id(self):
assert job.cvat_task_id == cvat_task.cvat_id
assert job.cvat_project_id == cvat_project.cvat_id

job = cvat_service.get_job_by_id(self.session, "Dummy id")
job = cvat_service.get_job_by_id(self.session, uuid.uuid4())

assert job is None

Expand Down
Loading