Skip to content
This repository has been archived by the owner on Feb 22, 2023. It is now read-only.

Add option to sort search results by created_on #916

Merged
merged 18 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/catalog/api/constants/field_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

json_fields = [
"id",
"created_on",
dhruvkb marked this conversation as resolved.
Show resolved Hide resolved
"title",
"foreign_landing_url",
"url",
Expand Down
13 changes: 13 additions & 0 deletions api/catalog/api/constants/sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Fields by which search results can be ordered. Each ``(value, label)`` pair
# is usable as a choice; the first entry of each list is the default.
RELEVANCE = "relevance"  # default
INDEXED_ON = "indexed_on"  # date on which media was indexed into Openverse
SORT_FIELDS = [(RELEVANCE, "Relevance"), (INDEXED_ON, "Indexing date")]

# Directions in which the sorted results can be returned.
DESCENDING = "desc"  # default
ASCENDING = "asc"
SORT_DIRECTIONS = [(DESCENDING, "Descending"), (ASCENDING, "Ascending")]
15 changes: 10 additions & 5 deletions api/catalog/api/controllers/search_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pprint
from itertools import accumulate
from math import ceil
from typing import Any, Literal
from typing import Literal

from django.conf import settings
from django.core.cache import cache
Expand All @@ -17,6 +17,8 @@
from elasticsearch_dsl.response import Hit, Response

import catalog.api.models as models
from catalog.api.constants.sorting import INDEXED_ON
from catalog.api.serializers import media_serializers
from catalog.api.utils import tallies
from catalog.api.utils.dead_link_mask import get_query_hash, get_query_mask
from catalog.api.utils.validate_images import validate_images
Expand Down Expand Up @@ -220,8 +222,7 @@ def _post_process_results(

def _apply_filter(
s: Search,
# Any is used here to avoid a circular import
search_params: Any, # MediaSearchRequestSerializer
search_params: media_serializers.MediaSearchRequestSerializer,
serializer_field: str,
es_field: str | None = None,
behaviour: Literal["filter", "exclude"] = "filter",
Expand Down Expand Up @@ -278,8 +279,7 @@ def _exclude_mature_by_param(s: Search, search_params):


def search(
# Any is used here to avoid a circular import
search_params: Any, # MediaSearchRequestSerializer
search_params: media_serializers.MediaSearchRequestSerializer,
index: Literal["image", "audio"],
page_size: int,
ip: int,
Expand Down Expand Up @@ -390,6 +390,11 @@ def search(
# Route users to the same Elasticsearch worker node to reduce
# pagination inconsistencies and increase cache hits.
s = s.params(preference=str(ip), request_timeout=7)

# Sort by new
if search_params.validated_data["sort_by"] == INDEXED_ON:
s = s.sort({"created_on": {"order": search_params.validated_data["sort_dir"]}})

# Paginate
start, end = _get_query_slice(s, page_size, page, filter_dead)
s = s[start:end]
Expand Down
1 change: 1 addition & 0 deletions api/catalog/api/examples/audio_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

base_audio = {
"id": identifier,
"created_on": "2022-12-06",
"title": "Wish You Were Here",
"foreign_landing_url": "https://www.jamendo.com/track/1214935",
"url": "https://mp3d.jamendo.com/download/track/1214935/mp32",
Expand Down
1 change: 1 addition & 0 deletions api/catalog/api/examples/image_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

base_image = {
"id": identifier,
"created_on": "2022-08-27",
"title": "Tree Bark Photo",
"foreign_landing_url": "https://stocksnap.io/photo/XNVBVXO3B7",
"url": "https://cdn.stocksnap.io/img-thumbs/960w/XNVBVXO3B7.jpg",
Expand Down
46 changes: 46 additions & 0 deletions api/catalog/api/serializers/media_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
from rest_framework.exceptions import NotAuthenticated

from catalog.api.constants.licenses import LICENSE_GROUPS
from catalog.api.constants.sorting import (
DESCENDING,
RELEVANCE,
SORT_DIRECTIONS,
SORT_FIELDS,
)
from catalog.api.controllers import search_controller
from catalog.api.models.media import AbstractMedia
from catalog.api.serializers.base import BaseModelSerializer
Expand Down Expand Up @@ -42,6 +48,8 @@ class MediaSearchRequestSerializer(serializers.Serializer):
"extension",
"mature",
"qa",
# "unstable__sort_by", # excluding unstable fields
# "unstable__sort_dir", # excluding unstable fields
dhruvkb marked this conversation as resolved.
Show resolved Hide resolved
"page_size",
"page",
]
Expand Down Expand Up @@ -109,6 +117,28 @@ class MediaSearchRequestSerializer(serializers.Serializer):
required=False,
default=False,
)

# The sorting parameters carry the ``unstable__`` prefix in the query string,
# whereas ``source`` drops that prefix from the keys of the validated data.
# When renaming these fields, keep the following references in sync:
# - ``field_names`` in ``MediaSearchRequestSerializer``
# - validators for these fields in ``MediaSearchRequestSerializer``
unstable__sort_by = serializers.ChoiceField(
    required=False,
    default=RELEVANCE,
    choices=SORT_FIELDS,
    source="sort_by",
    help_text="The field which should be the basis for sorting results.",
)
unstable__sort_dir = serializers.ChoiceField(
    required=False,
    default=DESCENDING,
    choices=SORT_DIRECTIONS,
    source="sort_dir",
    help_text=(
        "The direction of sorting. Cannot be applied when sorting by `relevance`."
    ),
)

page_size = serializers.IntegerField(
label="page_size",
help_text="Number of results to return per page.",
Expand Down Expand Up @@ -170,6 +200,16 @@ def validate_tags(self, value):
def validate_title(self, value):
return self._truncate(value)

def validate_unstable__sort_by(self, value):
    """
    Restrict custom sort fields to authenticated requests.

    :param value: the sort field requested through ``unstable__sort_by``
    :return: ``RELEVANCE`` when the request's user is anonymous, else ``value``
    """

    request = self.context.get("request")
    # A missing request or missing user is not treated as anonymous.
    if request and request.user and request.user.is_anonymous:
        return RELEVANCE
    return value

def validate_unstable__sort_dir(self, value):
    """
    Restrict custom sort directions to authenticated requests.

    :param value: the direction requested through ``unstable__sort_dir``
    :return: ``DESCENDING`` when the request's user is anonymous, else ``value``
    """

    request = self.context.get("request")
    # A missing request or missing user is not treated as anonymous.
    if request and request.user and request.user.is_anonymous:
        return DESCENDING
    return value

def validate_page_size(self, value):
request = self.context.get("request")
is_anonymous = bool(request and request.user and request.user.is_anonymous)
Expand Down Expand Up @@ -314,6 +354,7 @@ class Meta:
model = AbstractMedia
fields = [
"id",
"created_on",
"title",
"foreign_landing_url",
"url",
Expand Down Expand Up @@ -345,6 +386,11 @@ class Meta:
source="identifier",
)

created_on = serializers.DateTimeField(
dhruvkb marked this conversation as resolved.
Show resolved Hide resolved
format="%Y-%m-%d",
dhruvkb marked this conversation as resolved.
Show resolved Hide resolved
help_text="The timestamp of when the media was indexed by Openverse.",
)

tags = TagSerializer(
allow_null=True, # replaced with ``[]`` in ``to_representation`` below
many=True,
Expand Down
17 changes: 11 additions & 6 deletions api/catalog/api/views/media_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,17 @@ def _get_request_serializer(self, request):
return req_serializer

def get_db_results(self, results):
    """
    Replace search hits with their database records, keeping the hit order.

    Each returned record gets a ``fields_matched`` attribute copied from the
    hit with the same identifier (``None`` if the hit has none). Hits that
    have no matching database record are dropped from the output.

    :param results: iterable of search hits, each exposing ``identifier``
    :return: list of database records ordered like ``results``
    """

    # Map each identifier to its position and hit. ``setdefault`` keeps the
    # first occurrence, should an identifier appear more than once.
    hits_by_identifier = {}
    for position, hit in enumerate(results):
        hits_by_identifier.setdefault(hit.identifier, (position, hit))

    identifiers = list(hits_by_identifier)
    records = list(self.get_queryset().filter(identifier__in=identifiers))

    # The ``in`` lookup does not promise any ordering, so restore the order of
    # the hits (which may reflect a user-requested sort) explicitly.
    records.sort(key=lambda record: hits_by_identifier[str(record.identifier)][0])

    # Pair records with hits by identifier rather than by position: if a hit
    # has no database record, positional pairing would shift every later
    # ``fields_matched`` onto the wrong record.
    for record in records:
        _, hit = hits_by_identifier[str(record.identifier)]
        record.fields_matched = getattr(hit, "fields_matched", None)

    return records

# Standard actions
Expand Down
6 changes: 6 additions & 0 deletions api/catalog/templates/drf-yasg/redoc.html
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,11 @@
img[alt="logo"] {
padding: 20px; /* same as other sidebar items */
}

/* Hide fields that are unstable and likely to change */
td[kind="field"][title^="unstable__"],
td[kind="field"][title^="unstable__"] ~ td {
display: none
}
</style>
{% endblock %}
25 changes: 25 additions & 0 deletions api/test/auth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,31 @@ def test_auth_rate_limit_reporting(
assert res_data["verified"] is False


@pytest.mark.django_db
@pytest.mark.parametrize(
    "sort_dir, exp_created_on",
    [
        ("desc", "2022-12-31"),
        ("asc", "2022-01-01"),
    ],
)
def test_sorting_authed(
    client, monkeypatch, test_auth_token_exchange, sort_dir, exp_created_on
):
    """
    Check that an authenticated request can sort image results by
    ``indexed_on`` in either direction.

    The expected dates presumably match the extremes of the test data set.
    """

    # The DB is empty, so results must be served without a DB lookup.
    monkeypatch.setattr("catalog.api.views.image_views.ImageSerializer.needs_db", False)

    time.sleep(1)
    access_token = test_auth_token_exchange["access_token"]
    response = client.get(
        "/v1/images/",
        {"unstable__sort_by": "indexed_on", "unstable__sort_dir": sort_dir},
        HTTP_AUTHORIZATION=f"Bearer {access_token}",
    )
    assert response.status_code == 200

    first_result = response.json()["results"][0]
    # Only the leading ``YYYY-MM-DD`` part of ``created_on`` is compared.
    assert first_result["created_on"][:10] == exp_created_on


@pytest.mark.django_db
def test_page_size_limit_unauthed(client):
query_params = {"page_size": 20}
Expand Down
Loading