Skip to content

Commit

Permalink
refactor: implement composite key generation for MediaItem (#838)
Browse files Browse the repository at this point in the history
* refactor: implement composite key generation for MediaItem

- Changed MediaItem ID type from int to str to support composite keys.
- Added a method to generate composite keys using item type and trakt_id.
- Updated related functions and methods to handle the new ID type.
- Modified database queries and relationships to accommodate string IDs.
- Adjusted API endpoints and event management to work with composite keys.

* refactor: update ID handling to use strings instead of integers across modules
  • Loading branch information
iPromKnight authored Nov 2, 2024
1 parent 4e57c42 commit 2a0291c
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 45 deletions.
14 changes: 7 additions & 7 deletions src/program/db/db_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
if TYPE_CHECKING:
from program.media.item import MediaItem

def get_item_by_id(id: str, item_types = None, session = None):
if not id:
def get_item_by_id(item_id: str, item_types = None, session = None):
if not item_id:
return None

from program.media.item import MediaItem, Season, Show
_session = session if session else db.Session()

with _session:
query = (select(MediaItem)
.where(MediaItem.id == id)
.where(MediaItem.id == item_id)
.options(
selectinload(Show.seasons)
.selectinload(Season.episodes)
Expand Down Expand Up @@ -79,7 +79,7 @@ def delete_media_item(item: "MediaItem"):
session.delete(item)
session.commit()

def delete_media_item_by_id(media_item_id: int, batch_size: int = 30):
def delete_media_item_by_id(media_item_id: str, batch_size: int = 30):
"""Delete a Movie or Show by _id. If it's a Show, delete its Seasons and Episodes in batches, committing after each batch."""
from sqlalchemy.exc import IntegrityError

Expand Down Expand Up @@ -132,7 +132,7 @@ def delete_media_item_by_id(media_item_id: int, batch_size: int = 30):
session.rollback()
return False

def delete_seasons_and_episodes(session, season_ids: list[int], batch_size: int = 30):
def delete_seasons_and_episodes(session, season_ids: list[str], batch_size: int = 30):
"""Delete seasons and episodes of a show in batches, committing after each batch."""
from program.media.item import Episode, Season
from program.media.stream import StreamBlacklistRelation, StreamRelation
Expand Down Expand Up @@ -187,7 +187,7 @@ def clear_streams(item: "MediaItem"):
"""Clear all streams for a media item."""
reset_streams(item)

def clear_streams_by_id(media_item_id: int):
def clear_streams_by_id(media_item_id: str):
"""Clear all streams for a media item by the MediaItem id."""
with db.Session() as session:
session.execute(
Expand Down Expand Up @@ -267,7 +267,7 @@ def unblacklist_stream(item: "MediaItem", stream: Stream, session: Session = Non
if close_session:
session.close()

def get_item_ids(session, item_id: int) -> tuple[int, list[int]]:
def get_item_ids(session, item_id: str) -> tuple[str, list[str]]:
"""Get the item ID and all related item IDs for a given MediaItem."""
from program.media.item import Episode, MediaItem, Season

Expand Down
17 changes: 9 additions & 8 deletions src/program/managers/event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ def remove_event_from_queue(self, event: Event):

def remove_event_from_running(self, event: Event):
with self.mutex:
self._running_events.remove(event)
logger.debug(f"Removed {event.log_message} from running events.")
if event in self._running_events:
self._running_events.remove(event)
logger.debug(f"Removed {event.log_message} from running events.")

def remove_id_from_queue(self, item_id: int):
def remove_id_from_queue(self, item_id: str):
"""
Removes an item from the queue.
Expand All @@ -133,7 +134,7 @@ def add_event_to_running(self, event: Event):
self._running_events.append(event)
logger.debug(f"Added {event.log_message} to running events.")

def remove_id_from_running(self, item_id: int):
def remove_id_from_running(self, item_id: str):
"""
Removes an item from the running events.
Expand All @@ -144,12 +145,12 @@ def remove_id_from_running(self, item_id: int):
if event.item_id == item_id:
self.remove_event_from_running(event)

def remove_id_from_queues(self, item_id: int):
def remove_id_from_queues(self, item_id: str):
"""
Removes an item from both the queue and the running events.
Args:
item (MediaItem): The event item to remove from both the queue and the running events.
item_id: The event item to remove from both the queue and the running events.
"""
self.remove_id_from_queue(item_id)
self.remove_id_from_running(item_id)
Expand Down Expand Up @@ -180,7 +181,7 @@ def submit_job(self, service, program, event=None):
sse_manager.publish_event("event_update", self.get_event_updates())
future.add_done_callback(lambda f:self._process_future(f, service))

def cancel_job(self, item_id: int, suppress_logs=False):
def cancel_job(self, item_id: str, suppress_logs=False):
"""
Cancels a job associated with the given item.
Expand Down Expand Up @@ -308,7 +309,7 @@ def add_item(self, item, service="Manual"):
logger.debug(f"Added item with IMDB ID {item.imdb_id} to the queue.")


def get_event_updates(self) -> Dict[str, List[int]]:
def get_event_updates(self) -> Dict[str, List[str]]:
events = [future.event for future in self._futures if hasattr(future, "event")]
event_types = ["Scraping", "Downloader", "Symlinker", "Updater", "PostProcessing"]

Expand Down
25 changes: 17 additions & 8 deletions src/program/media/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class MediaItem(db.Model):
"""MediaItem class"""
__tablename__ = "MediaItem"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True)
id: Mapped[str] = mapped_column(sqlalchemy.String, primary_key=True)
imdb_id: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
tvdb_id: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
tmdb_id: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
Expand Down Expand Up @@ -85,7 +85,7 @@ class MediaItem(db.Model):
def __init__(self, item: dict | None) -> None:
if item is None:
return
self.id = item.get("trakt_id")
self.id = self.__generate_composite_key(item)
self.requested_at = item.get("requested_at", datetime.now())
self.requested_by = item.get("requested_by")
self.requested_id = item.get("requested_id")
Expand Down Expand Up @@ -132,6 +132,15 @@ def __init__(self, item: dict | None) -> None:
#Post processing
self.subtitles = item.get("subtitles", [])

@staticmethod
def __generate_composite_key(item: dict) -> str | None:
"""Generate a composite key for the item."""
trakt_id = item.get("trakt_id", None)
if not trakt_id:
return None
item_type = item.get("type", "unknown")
return f"{item_type}_{trakt_id}"

def store_state(self, given_state=None) -> None:
new_state = given_state if given_state else self._determine_state()
if self.last_state and self.last_state != new_state:
Expand Down Expand Up @@ -386,7 +395,7 @@ def collection(self):
class Movie(MediaItem):
"""Movie class"""
__tablename__ = "Movie"
id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("MediaItem.id"), primary_key=True)
id: Mapped[str] = mapped_column(sqlalchemy.ForeignKey("MediaItem.id"), primary_key=True)
__mapper_args__ = {
"polymorphic_identity": "movie",
"polymorphic_load": "inline",
Expand All @@ -410,7 +419,7 @@ def __hash__(self):
class Show(MediaItem):
"""Show class"""
__tablename__ = "Show"
id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("MediaItem.id"), primary_key=True)
id: Mapped[str] = mapped_column(sqlalchemy.ForeignKey("MediaItem.id"), primary_key=True)
seasons: Mapped[List["Season"]] = relationship(back_populates="parent", foreign_keys="Season.parent_id", lazy="joined", cascade="all, delete-orphan", order_by="Season.number")

__mapper_args__ = {
Expand Down Expand Up @@ -519,8 +528,8 @@ def propagate(target, source):
class Season(MediaItem):
"""Season class"""
__tablename__ = "Season"
id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("MediaItem.id"), primary_key=True)
parent_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("Show.id"), use_existing_column=True)
id: Mapped[str] = mapped_column(sqlalchemy.ForeignKey("MediaItem.id"), primary_key=True)
parent_id: Mapped[str] = mapped_column(sqlalchemy.ForeignKey("Show.id"), use_existing_column=True)
parent: Mapped["Show"] = relationship(lazy=False, back_populates="seasons", foreign_keys="Season.parent_id")
episodes: Mapped[List["Episode"]] = relationship(back_populates="parent", foreign_keys="Episode.parent_id", lazy="joined", cascade="all, delete-orphan", order_by="Episode.number")
__mapper_args__ = {
Expand Down Expand Up @@ -624,8 +633,8 @@ def get_top_title(self) -> str:
class Episode(MediaItem):
"""Episode class"""
__tablename__ = "Episode"
id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("MediaItem.id"), primary_key=True)
parent_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("Season.id"), use_existing_column=True)
id: Mapped[str] = mapped_column(sqlalchemy.ForeignKey("MediaItem.id"), primary_key=True)
parent_id: Mapped[str] = mapped_column(sqlalchemy.ForeignKey("Season.id"), use_existing_column=True)
parent: Mapped["Season"] = relationship(back_populates="episodes", foreign_keys="Episode.parent_id", lazy="joined")

__mapper_args__ = {
Expand Down
7 changes: 7 additions & 0 deletions src/program/services/indexers/trakt.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ def _add_seasons_to_show(show: Show, imdb_id: str):
show.add_season(season_item)


def _assign_item_type(item_type):
if item_type == "movie":
return "movie"
return "show"


def _map_item_from_data(data, item_type: str, show_genres: List[str] = None) -> Optional[MediaItem]:
"""Map trakt.tv API data to MediaItemContainer."""
if item_type not in ["movie", "show", "season", "episode"]:
Expand All @@ -146,6 +152,7 @@ def _map_item_from_data(data, item_type: str, show_genres: List[str] = None) ->
"country": getattr(data, "country", None),
"language": getattr(data, "language", None),
"requested_at": datetime.now(),
"type": _assign_item_type(item_type),
}

item["is_anime"] = (
Expand Down
2 changes: 1 addition & 1 deletion src/program/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ProcessedEvent:
@dataclass
class Event:
emitted_by: Service
item_id: Optional[int] = None
item_id: Optional[str] = None
content_item: Optional[MediaItem] = None
run_at: datetime = datetime.now()

Expand Down
42 changes: 21 additions & 21 deletions src/routers/secure/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
)


def handle_ids(ids: str) -> list[int]:
ids = [int(id) for id in ids.split(",")] if "," in ids else [int(ids)]
def handle_ids(ids: str) -> list[str]:
ids = [str(id) for id in ids.split(",")] if "," in ids else [str(ids)]
if not ids:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No item ID provided")
return ids
Expand Down Expand Up @@ -218,12 +218,12 @@ async def add_items(request: Request, imdb_ids: str = None) -> MessageResponse:
description="Fetch a single media item by ID",
operation_id="get_item",
)
async def get_item(_: Request, id: int, use_tmdb_id: Optional[bool] = False) -> dict:
async def get_item(_: Request, id: str, use_tmdb_id: Optional[bool] = False) -> dict:
with db.Session() as session:
try:
query = select(MediaItem)
if use_tmdb_id:
query = query.where(MediaItem.tmdb_id == str(id))
query = query.where(MediaItem.tmdb_id == id)
else:
query = query.where(MediaItem.imdb_id == id)
item = session.execute(query).unique().scalar_one()
Expand Down Expand Up @@ -255,7 +255,7 @@ async def get_items_by_imdb_ids(request: Request, imdb_ids: str) -> list[dict]:

class ResetResponse(BaseModel):
message: str
ids: list[int]
ids: list[str]


@router.post(
Expand Down Expand Up @@ -285,7 +285,7 @@ async def reset_items(request: Request, ids: str) -> ResetResponse:

class RetryResponse(BaseModel):
message: str
ids: list[int]
ids: list[str]


@router.post(
Expand Down Expand Up @@ -315,7 +315,7 @@ async def retry_items(request: Request, ids: str) -> RetryResponse:

class RemoveResponse(BaseModel):
message: str
ids: list[int]
ids: list[str]


@router.delete(
Expand All @@ -325,30 +325,30 @@ class RemoveResponse(BaseModel):
operation_id="remove_item",
)
async def remove_item(request: Request, ids: str) -> RemoveResponse:
ids: list[int] = handle_ids(ids)
ids: list[str] = handle_ids(ids)
try:
media_items: list[int] = db_functions.get_items_by_ids(ids, ["movie", "show"])
media_items: list[str] = db_functions.get_items_by_ids(ids, ["movie", "show"])
if not media_items:
return HTTPException(status_code=404, detail="Item(s) not found")
for item_id in media_items:
logger.debug(f"Removing item with ID {item_id}")
request.app.program.em.cancel_job(item_id)
for item in media_items:
logger.debug(f"Removing item with ID {item.id}")
request.app.program.em.cancel_job(item.id)
await asyncio.sleep(0.2) # Ensure cancellation is processed
db_functions.clear_streams_by_id(item_id)
db_functions.clear_streams_by_id(item.id)

symlink_service = request.app.program.services.get(Symlinker)
if symlink_service:
symlink_service.delete_item_symlinks_by_id(item_id)
symlink_service.delete_item_symlinks_by_id(item.id)

with db.Session() as session:
requested_id = session.execute(select(MediaItem.requested_id).where(MediaItem.id == item_id)).scalar_one()
requested_id = session.execute(select(MediaItem.requested_id).where(MediaItem.id == item.id)).scalar_one()
if requested_id:
logger.debug(f"Deleting request from Overseerr with ID {requested_id}")
Overseerr.delete_request(requested_id)

logger.debug(f"Deleting item from database with ID {item_id}")
db_functions.delete_media_item_by_id(item_id)
logger.info(f"Successfully removed item with ID {item_id}")
logger.debug(f"Deleting item from database with ID {item.id}")
db_functions.delete_media_item_by_id(item.id)
logger.info(f"Successfully removed item with ID {item.id}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

Expand All @@ -357,7 +357,7 @@ async def remove_item(request: Request, ids: str) -> RemoveResponse:
@router.get(
"/{item_id}/streams"
)
async def get_item_streams(_: Request, item_id: int, db: Session = Depends(get_db)):
async def get_item_streams(_: Request, item_id: str, db: Session = Depends(get_db)):
item: MediaItem = (
db.execute(
select(MediaItem)
Expand All @@ -379,7 +379,7 @@ async def get_item_streams(_: Request, item_id: int, db: Session = Depends(get_d
@router.post(
"/{item_id}/streams/{stream_id}/blacklist"
)
async def blacklist_stream(_: Request, item_id: int, stream_id: int, db: Session = Depends(get_db)):
async def blacklist_stream(_: Request, item_id: str, stream_id: int, db: Session = Depends(get_db)):
item: MediaItem = (
db.execute(
select(MediaItem)
Expand All @@ -402,7 +402,7 @@ async def blacklist_stream(_: Request, item_id: int, stream_id: int, db: Session
@router.post(
"{item_id}/streams/{stream_id}/unblacklist"
)
async def unblacklist_stream(_: Request, item_id: int, stream_id: int, db: Session = Depends(get_db)):
async def unblacklist_stream(_: Request, item_id: str, stream_id: int, db: Session = Depends(get_db)):
item: MediaItem = (
db.execute(
select(MediaItem)
Expand Down

0 comments on commit 2a0291c

Please sign in to comment.