diff --git a/tests/unit/utils/db/test_windowed_query.py b/tests/unit/utils/db/test_windowed_query.py deleted file mode 100644 index 9c750cc7cafc..000000000000 --- a/tests/unit/utils/db/test_windowed_query.py +++ /dev/null @@ -1,41 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import pytest - -from sqlalchemy import select - -from warehouse.packaging.models import Project -from warehouse.utils.db.windowed_query import windowed_query - -from ....common.db.packaging import ProjectFactory - - -@pytest.mark.parametrize("window_size", [1, 2]) -def test_windowed_query(db_session, query_recorder, window_size): - projects = ProjectFactory.create_batch(10) - project_set = {(project.name, project.id) for project in projects} - - expected = math.ceil(len(projects) / window_size) + 1 - - query = select(Project) - - result_set = set() - with query_recorder: - for result in windowed_query(db_session, query, Project.id, window_size): - for project in result.scalars(): - result_set.add((project.name, project.id)) - - assert result_set == project_set - assert len(query_recorder.queries) == expected diff --git a/warehouse/search/tasks.py b/warehouse/search/tasks.py index a675ee929393..9fd2a3d2d054 100644 --- a/warehouse/search/tasks.py +++ b/warehouse/search/tasks.py @@ -23,7 +23,6 @@ from opensearchpy.helpers import parallel_bulk from redis.lock import Lock from sqlalchemy import func, select, text -from sqlalchemy.orm import aliased from urllib3.util import parse_url from warehouse import tasks @@ -36,28 +35,10 @@ ) from warehouse.packaging.search import Project as ProjectDocument from warehouse.search.utils import get_index -from warehouse.utils.db import windowed_query -def _project_docs(db, project_name=None): - releases_list = ( - select(Release.id) - .filter(Release.yanked.is_(False), Release.files.any()) - .order_by( - Release.project_id, - Release.is_prerelease.nullslast(), - Release._pypi_ordering.desc(), - ) - .distinct(Release.project_id) - ) - - if project_name: - releases_list = releases_list.join(Project).filter(Project.name == project_name) - - releases_list = releases_list.subquery() - rlist = aliased(Release, releases_list) - - classifiers = ( +def _project_docs(db, project_name: str | None = None): + classifiers_subquery = ( select(func.array_agg(Classifier.classifier)) .select_from(ReleaseClassifiers) .join(Classifier, Classifier.id == ReleaseClassifiers.trove_id) @@ -66,8 +47,7 @@ def _project_docs(db, project_name=None): .scalar_subquery() .label("classifiers") ) - - release_data = ( + projects_to_index = ( select( Description.raw.label("description"), Release.author, @@ -80,18 +60,32 @@ def _project_docs(db, project_name=None): Release.platform, Release.download_url, Release.created, - classifiers, + classifiers_subquery, Project.normalized_name, Project.name, ) - .select_from(rlist) - .join(Release, Release.id == rlist.id) + .select_from(Release) .join(Description) - .outerjoin(Release.project) + .join(Project) + .filter( + Release.yanked.is_(False), + Release.files.any(), + # Filter by project_name if provided + Project.name == project_name if project_name else text("TRUE"), + ) + .order_by( + Project.name, + Release.is_prerelease.nullslast(), + Release._pypi_ordering.desc(), + ) + .distinct(Project.name) + .execution_options(yield_per=25000) ) - for chunk in windowed_query(db, release_data, Project.name, 25000): - for release in chunk: + results = db.execute(projects_to_index) + + for partition in results.partitions(): + for release in partition: p = ProjectDocument.from_db(release) p._index = None p.full_clean() diff --git a/warehouse/utils/db/__init__.py b/warehouse/utils/db/__init__.py index caca2a78ab00..ae40f369c161 100644 --- a/warehouse/utils/db/__init__.py +++ b/warehouse/utils/db/__init__.py @@ -10,6 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from warehouse.utils.db.windowed_query import windowed_query +from warehouse.utils.db.query_printer import print_query -__all__ = ["windowed_query"] +__all__ = ["print_query"] diff --git a/warehouse/utils/db/windowed_query.py b/warehouse/utils/db/windowed_query.py deleted file mode 100644 index bf25b383a3c0..000000000000 --- a/warehouse/utils/db/windowed_query.py +++ /dev/null @@ -1,77 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Taken from "Theatrum Chemicum" at -# https://github.com/sqlalchemy/sqlalchemy/wiki/RangeQuery-and-WindowedRangeQuery - -from __future__ import annotations - -import typing - -from collections.abc import Iterator -from typing import Any - -from sqlalchemy import and_, func, select -from sqlalchemy.orm import Session - -if typing.TYPE_CHECKING: - from sqlalchemy import Result, Select, SQLColumnExpression - - -def column_windows( - session: Session, - stmt: Select[Any], - column: SQLColumnExpression[Any], - windowsize: int, -) -> Iterator[SQLColumnExpression[bool]]: - """Return a series of WHERE clauses against - a given column that break it into windows. - - Result is an iterable of WHERE clauses that are packaged with - the individual ranges to select from. - - Requires a database that supports window functions. - """ - rownum = func.row_number().over(order_by=column).label("rownum") - - subq = stmt.add_columns(rownum).subquery() - subq_column = list(subq.columns)[-1] - - target_column = subq.corresponding_column(column) # type: ignore - new_stmt = select(target_column) # type: ignore - - if windowsize > 1: - new_stmt = new_stmt.filter(subq_column % windowsize == 1) - - intervals = list(session.scalars(new_stmt)) - - # yield out WHERE clauses for each range - while intervals: - start = intervals.pop(0) - if intervals: - yield and_(column >= start, column < intervals[0]) - else: - yield column >= start - - -def windowed_query( - session: Session, - stmt: Select[Any], - column: SQLColumnExpression[Any], - windowsize: int, -) -> Iterator[Result[Any]]: - """Given a Session and Select() object, organize and execute the statement - such that it is invoked for ordered chunks of the total result. yield - out individual Result objects for each chunk. - """ - for whereclause in column_windows(session, stmt, column, windowsize): - yield session.execute(stmt.filter(whereclause).order_by(column))