Skip to content

Commit

Permalink
fix: improved removing items from database
Browse files Browse the repository at this point in the history
  • Loading branch information
dreulavelle committed Oct 10, 2024
1 parent ff14f85 commit e4b6e2b
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 70 deletions.
48 changes: 29 additions & 19 deletions src/controllers/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,28 @@
from typing import Optional

import Levenshtein
from RTN import RTN, Torrent
from RTN import Torrent
from fastapi import APIRouter, HTTPException, Request
from sqlalchemy import func, select
from sqlalchemy import delete, func, select
from sqlalchemy.exc import NoResultFound

from program.content import Overseerr
from program.db.db import db
from program.db.db_functions import (
clear_streams,
clear_streams_by_id,
delete_media_item,
delete_media_item_by_id,
get_media_items_by_ids,
get_parent_ids,
get_parent_items_by_ids,
reset_media_item,
)
from program.media.item import MediaItem
from program.media.item import Episode, MediaItem, Season
from program.media.state import States
from program.symlink import Symlinker
from program.downloaders import Downloader, get_needed_media
from program.downloaders.realdebrid import RealDebridDownloader, add_torrent_magnet, torrent_info
from program.settings.versions import models
from program.settings.manager import settings_manager
from program.downloaders.realdebrid import add_torrent_magnet, torrent_info
from program.media.stream import Stream
from program.scrapers.shared import rtn
from program.types import Event
Expand Down Expand Up @@ -246,29 +247,38 @@ async def retry_items(request: Request, ids: str):
"/remove",
summary="Remove Media Items",
description="Remove media items based on item IDs",
operation_id="remove_item",
)
async def remove_item(request: Request, ids: str):
ids = handle_ids(ids)
ids: list[int] = handle_ids(ids)
try:
media_items = get_parent_items_by_ids(ids)
media_items: list[int] = get_parent_ids(ids)
if not media_items:
raise ValueError("Invalid item ID(s) provided. Some items may not exist.")
return HTTPException(status_code=404, detail="Item(s) not found")

for media_item in media_items:
logger.debug(f"Removing item {media_item.title} with ID {media_item._id}")
request.app.program.em.cancel_job(media_item)
await asyncio.sleep(0.1) # Ensure cancellation is processed
clear_streams(media_item)
logger.debug(f"Removing item with ID {media_item}")
request.app.program.em.cancel_job_by_id(media_item)
await asyncio.sleep(0.2) # Ensure cancellation is processed
clear_streams_by_id(media_item)

symlink_service = request.app.program.services.get(Symlinker)
if symlink_service:
symlink_service.delete_item_symlinks(media_item)
if media_item.requested_by == "overseerr" and media_item.requested_id:
logger.debug(f"Item was originally requested by Overseerr, deleting request within Overseerr...")
Overseerr.delete_request(media_item.requested_id)
delete_media_item(media_item)
symlink_service.delete_item_symlinks_by_id(media_item)

with db.Session() as session:
requested_id = session.execute(select(MediaItem.requested_id).where(MediaItem._id == media_item)).scalar_one()
if requested_id:
logger.debug(f"Deleting request from Overseerr with ID {requested_id}")
Overseerr.delete_request(requested_id)

logger.debug(f"Deleting item from database with ID {media_item}")
delete_media_item_by_id(media_item)
logger.info(f"Successfully removed item with ID {media_item}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

return {"success": True, "message": f"Removed items with ids {ids}"}
return {"success": True, "message": f"Successfully removed items", "removed_ids": ids}

@router.post("/{id}/set_torrent_rd_magnet", description="Set a torrent for a media item using a magnet link.")
def add_torrent(request: Request, id: int, magnet: str):
Expand Down
143 changes: 122 additions & 21 deletions src/program/db/db_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import TYPE_CHECKING, List

import alembic
from sqlalchemy import delete, func, insert, select, text, union_all
from sqlalchemy.orm import Session, aliased, selectinload
from sqlalchemy import delete, func, insert, select, text
from sqlalchemy.orm import Session, selectinload

from program.libraries.symlink import fix_broken_symlinks
from program.media.stream import Stream, StreamBlacklistRelation, StreamRelation
Expand All @@ -15,7 +15,7 @@
from .db import alembic, db

if TYPE_CHECKING:
from program.media.item import MediaItem
from program.media.item import MediaItem, Episode, Season


def get_media_items_by_ids(media_item_ids: list[int]):
Expand Down Expand Up @@ -69,6 +69,20 @@ def get_parent_items_by_ids(media_item_ids: list[int]):
items.append(item)
return items

def get_parent_ids(media_item_ids: list[int]):
"""Retrieve the _ids of MediaItems of type 'movie' or 'show' by a list of MediaItem _ids."""
from program.media.item import MediaItem
with db.Session() as session:
parent_ids = []
for media_item_id in media_item_ids:
item_id = session.execute(
select(MediaItem._id)
.where(MediaItem._id == media_item_id, MediaItem.type.in_(["movie", "show"]))
).scalar_one()
if item_id:
parent_ids.append(item_id)
return parent_ids

def get_item_by_imdb_id(imdb_id: str):
"""Retrieve a MediaItem of type 'movie' or 'show' by an IMDb ID."""
from program.media.item import MediaItem
Expand All @@ -83,17 +97,68 @@ def delete_media_item(item: "MediaItem"):
session.delete(item)
session.commit()

def delete_media_item_by_id(media_item_id: int):
"""Delete a MediaItem and all its associated relationships by the MediaItem _id."""
from program.media.item import MediaItem
def delete_media_item_by_id(media_item_id: int, batch_size: int = 30):
"""Delete a Movie or Show by _id. If it's a Show, delete its Seasons and Episodes in batches, committing after each batch."""
from program.media.item import MediaItem, Show, Season, Episode
with db.Session() as session:
item = session.query(MediaItem).filter_by(_id=media_item_id).first()
# First, retrieve the media item's type
media_item = session.execute(
select(MediaItem._id, MediaItem.type)
.where(MediaItem._id == media_item_id)
).first()

if not media_item:
logger.error(f"No item found with ID {media_item_id}")
return False

media_item_id, media_item_type = media_item

if media_item_type == "movie":
# Directly delete movie
logger.debug(f"Deleting Movie with ID {media_item_id}")
session.execute(delete(MediaItem).where(MediaItem._id == media_item_id))
session.commit() # Commit movie deletion immediately

elif media_item_type == "show":
logger.debug(f"Deleting Show with ID {media_item_id}")

# Delete Seasons and Episodes in batches, committing after each batch
delete_seasons_and_episodes(session, media_item_id, batch_size)

# Ensure no references exist in the Show table
session.execute(delete(Show).where(Show._id == media_item_id))

# Delete the Show itself
session.execute(delete(MediaItem).where(MediaItem._id == media_item_id))
session.commit() # Commit show deletion at the end

return True

def delete_seasons_and_episodes(session, show_id: int, batch_size: int):
"""Delete seasons and episodes of a show in batches, committing after each batch."""
from program.media.item import Episode, Season

if item:
session.delete(item)
session.commit()
else:
raise ValueError(f"MediaItem with id {media_item_id} does not exist.")
# Delete seasons one by one
season_ids = session.execute(
select(Season._id).where(Season.parent_id == show_id)
).scalars().all()

for season_id in season_ids:
# Delete episodes in batches for each season
while True:
episode_ids = session.execute(
select(Episode._id)
.where(Episode.parent_id == season_id)
.limit(batch_size)
).scalars().all()

if not episode_ids:
break

session.execute(delete(Episode).where(Episode._id.in_(episode_ids)))
session.commit() # Commit after each batch of episodes
session.execute(delete(Season).where(Season._id == season_id))
session.commit() # Commit after deleting the season

def delete_media_item_by_item_id(item_id: str):
"""Delete a MediaItem and all its associated relationships by the MediaItem _id."""
Expand Down Expand Up @@ -156,6 +221,17 @@ def clear_streams(item: "MediaItem"):
)
session.commit()

def clear_streams_by_id(media_item_id: int):
"""Clear all streams for a media item by the MediaItem _id."""
with db.Session() as session:
session.execute(
delete(StreamRelation).where(StreamRelation.parent_id == media_item_id)
)
session.execute(
delete(StreamBlacklistRelation).where(StreamBlacklistRelation.media_item_id == media_item_id)
)
session.commit()

def blacklist_stream(item: "MediaItem", stream: Stream, session: Session = None) -> bool:
"""Blacklist a stream for a media item."""
close_session = False
Expand Down Expand Up @@ -233,19 +309,37 @@ def load_streams_in_pages(session: Session, media_item_id: int, page_number: int

def _get_item_ids(session, item):
from program.media.item import Episode, Season

if item.type == "show":
show_id = item._id

season_alias = aliased(Season, flat=True)
season_query = select(Season._id.label('id')).where(Season.parent_id == show_id)
episode_query = (
select(Episode._id.label('id'))
.join(season_alias, Episode.parent_id == season_alias._id)
.where(season_alias.parent_id == show_id)
)
# season_alias = aliased(Season, flat=True)
# season_query = select(Season._id.label('id')).where(Season.parent_id == show_id)
# episode_query = (
# select(Episode._id.label('id'))
# .join(season_alias, Episode.parent_id == season_alias._id)
# .where(season_alias.parent_id == show_id)
# )

# combined_query = union_all(season_query, episode_query)
# related_ids = session.execute(combined_query).scalars().all()
# return show_id, related_ids

# Fetch season IDs
season_ids = session.execute(
select(Season._id).where(Season.parent_id == show_id)
).scalars().all()

# Fetch episode IDs for each season
episode_ids = []
for season_id in season_ids:
episode_ids.extend(
session.execute(
select(Episode._id).where(Episode.parent_id == season_id)
).scalars().all()
)

combined_query = union_all(season_query, episode_query)
related_ids = session.execute(combined_query).scalars().all()
related_ids = season_ids + episode_ids
return show_id, related_ids

elif item.type == "season":
Expand All @@ -263,6 +357,13 @@ def _get_item_ids(session, item):

return item._id, []

def _get_item_ids_from_item_id(session, media_item_id: int):
from program.media.item import MediaItem
item = session.execute(select(MediaItem).where(MediaItem._id == media_item_id)).unique().scalar_one_or_none()
if not item:
return None
return _get_item_ids(session, item)

def _ensure_item_exists_in_db(item: "MediaItem") -> bool:
from program.media.item import MediaItem, Movie, Show
if isinstance(item, (Movie, Show)):
Expand Down
11 changes: 7 additions & 4 deletions src/program/media/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,13 @@ def copy_other_media_attr(self, other):

def is_scraped(self):
session = object_session(self)
if session:
session.refresh(self, attribute_names=['blacklisted_streams']) # Prom: Ensure these reflect the state of whats in the db.
return (len(self.streams) > 0
and any(not stream in self.blacklisted_streams for stream in self.streams))
if session and session.is_active:
try:
session.refresh(self, attribute_names=['blacklisted_streams'])
return (len(self.streams) > 0 and any(not stream in self.blacklisted_streams for stream in self.streams))
except (sqlalchemy.exc.InvalidRequestError, sqlalchemy.orm.exc.DetachedInstanceError):
return False
return False

def to_dict(self):
"""Convert item to dictionary (API response)"""
Expand Down
Loading

0 comments on commit e4b6e2b

Please sign in to comment.