Replace return_meta by ResultList with meta attribute
J535D165 committed Dec 22, 2024
1 parent 08c4c23 commit 265f1ed
Showing 3 changed files with 116 additions and 93 deletions.
106 changes: 71 additions & 35 deletions pyalex/api.py
@@ -205,22 +205,20 @@ def __next__(self):
         else:
             raise ValueError()

-        results, meta = self.endpoint_class.get(
-            return_meta=True, per_page=self.per_page, **pagination_params
-        )
+        r = self.endpoint_class.get(per_page=self.per_page, **pagination_params)

         if self.method == "cursor":
-            self._next_value = meta["next_cursor"]
+            self._next_value = r.meta["next_cursor"]

         if self.method == "page":
-            if len(results) > 0:
-                self._next_value = meta["page"] + 1
+            if len(r) > 0:
+                self._next_value = r.meta["page"] + 1
             else:
                 self._next_value = None

-        self.n = self.n + len(results)
+        self.n = self.n + len(r)

-        return results
+        return r


 class OpenAlexAuth(AuthBase):
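With this change, each page yielded by the paginator is a ResultList, so page-level metadata travels with the results instead of requiring `return_meta=True`. A minimal usage sketch (the search query and printed fields are illustrative, not part of this commit):

```python
from pyalex import Works

# Each `page` is a ResultList: ordinary list access to the entities,
# plus the `meta` block from the same OpenAlex response.
for page in Works().search("open science").paginate(per_page=200, n_max=600):
    print(len(page), page.meta["count"])
```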
@@ -249,6 +247,54 @@ def __call__(self, r):
         return r


+class ResultList(list):
+    """A list of OpenAlexEntity objects with metadata.
+    Attributes:
+        meta: a dictionary with metadata about the results
+        data_attr: the key in the response dictionary that contains the results
+        resource_class: the class to use for each entity in the results
+    Arguments:
+        results: a list of OpenAlexEntity objects
+        meta: a dictionary with metadata about the results
+        resource_class: the class to use for each entity in the results
+        data_attr: the key in the response dictionary that contains the results
+    Returns:
+        a ResultList object
+    """
+
+    def __init__(self, results, meta, resource_class=None, data_attr="results"):
+        resource_class = resource_class or OpenAlexEntity
+        results_list = [resource_class(ent) for ent in results]
+
+        super().__init__(results_list)
+        self.meta = meta
+        self.data_attr = data_attr
+        self.resource_class = resource_class
+
+    @classmethod
+    def from_results(cls, response, resource_class=None, data_attr="results"):
+        """Create a ResultList from a response dictionary.
+        Arguments:
+            response: the response dictionary from the OpenAlex API
+            resource_class: the class to use for each entity in the results
+            data_attr: the key in the response dictionary that contains the results
+        Returns:
+            a ResultList object
+        """
+        return cls(
+            response[data_attr],
+            response["meta"],
+            resource_class=resource_class,
+            data_attr=data_attr,
+        )
+
+
 class BaseOpenAlex:
     """Base class for OpenAlex objects."""

@@ -286,7 +332,6 @@ def __getitem__(self, record_id):

         return self._get_from_url(
             f"{self._full_collection_name()}/{_quote_oa_value(record_id)}",
-            return_meta=False,
         )

     @property
@@ -311,14 +356,11 @@ def url(self):
         return self._full_collection_name()

     def count(self):
-        _, m = self.get(return_meta=True, per_page=1)
-
-        return m["count"]
+        return self.get(per_page=1).meta["count"]

-    def _get_from_url(self, url, return_meta=False):
+    def _get_from_url(self, url):
         res = _get_requests_session().get(url, auth=OpenAlexAuth(config))

-        # handle query errors
         if res.status_code == 403:
             if (
                 isinstance(res.json()["error"], str)
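`count()` now reads the total from the `meta` attribute of a one-item page instead of unpacking a tuple. Roughly equivalent calls, with an illustrative filter:

```python
from pyalex import Works

query = Works().filter(publication_year=2020)

total = query.count()                               # new shorthand
total_again = query.get(per_page=1).meta["count"]   # what count() does internally
```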
@@ -329,30 +371,30 @@ def _get_from_url(self, url, return_meta=False):
         res.raise_for_status()
         res_json = res.json()

-        # group-by or results page
         if self.params and "group-by" in self.params:
-            results = res_json["group_by"]
+            return ResultList.from_results(
+                res_json, self.resource_class, data_attr="group_by"
+            )
         elif "results" in res_json:
-            results = [self.resource_class(ent) for ent in res_json["results"]]
+            return ResultList.from_results(res_json, self.resource_class)
         elif "id" in res_json:
-            results = self.resource_class(res_json)
+            return self.resource_class(res_json)
         else:
             raise ValueError("Unknown response format")

-        # return result and metadata
-        if return_meta:
-            return results, res_json["meta"]
-        else:
-            return results
-
-    def get(self, return_meta=False, page=None, per_page=None, cursor=None):
+    def get(self, page=None, per_page=None, cursor=None, return_meta=False):
         if per_page is not None and (per_page < 1 or per_page > 200):
             raise ValueError("per_page should be a number between 1 and 200.")

+        if return_meta:
+            raise DeprecationWarning(
+                "return_meta is deprecated, call .meta on the result"
+            )
+
         self._add_params("per-page", per_page)
         self._add_params("page", page)
         self._add_params("cursor", cursor)
-        return self._get_from_url(self.url, return_meta=return_meta)
+        return self._get_from_url(self.url)

     def paginate(self, method="cursor", page=1, per_page=None, cursor="*", n_max=10000):
         if method == "cursor":
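After this hunk, `get()` returns either a single entity or a ResultList, and passing `return_meta` raises a DeprecationWarning instead of changing the return type; grouped queries come back as a ResultList built from the `group_by` key. A sketch of the new call patterns (the filter and group-by field are examples only):

```python
from pyalex import Works

# Plain query: results and metadata on one object.
works = Works().filter(publication_year=2020).get(per_page=25)
print(works.meta["count"], works[0]["id"])

# Grouped query: each item is a group record, meta is still attached.
groups = Works().filter(publication_year=2020).group_by("oa_status").get()
print(groups.meta["count"], groups[0])

# The old style is no longer supported and raises a DeprecationWarning:
# works, meta = Works().get(return_meta=True)
```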
@@ -448,19 +490,15 @@ def __getitem__(self, key):

         return super().__getitem__(key)

-    def ngrams(self, return_meta=False):
+    def ngrams(self):
         openalex_id = self["id"].split("/")[-1]
         n_gram_url = f"{config.openalex_url}/works/{openalex_id}/ngrams"

         res = _get_requests_session().get(n_gram_url, auth=OpenAlexAuth(config))
         res.raise_for_status()
         results = res.json()

-        # return result and metadata
-        if return_meta:
-            return results["ngrams"], results["meta"]
-        else:
-            return results["ngrams"]
+        return ResultList.from_results(results, data_attr="ngrams")


 class Works(BaseOpenAlex):
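`ngrams()` follows the same pattern: the list of n-grams is returned directly, with the response's `meta` block attached. A sketch, assuming the ngrams endpoint is still served for the work (the ID below is a placeholder):

```python
from pyalex import Works

work = Works()["W0000000000"]  # placeholder OpenAlex work ID
ngrams = work.ngrams()
print(ngrams.meta["count"], ngrams[0])
```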
@@ -549,9 +587,7 @@ class autocompletes(BaseOpenAlex):
     resource_class = Autocomplete

     def __getitem__(self, key):
-        return self._get_from_url(
-            f"{config.openalex_url}/autocomplete?q={key}", return_meta=False
-        )
+        return self._get_from_url(f"{config.openalex_url}/autocomplete?q={key}")


 class Concept(OpenAlexEntity):
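The autocomplete shortcut keeps its bracket syntax; only the internal plumbing changed. Usage stays along these lines (the query string and printed field are arbitrary examples, and the import path is the module shown in this diff):

```python
from pyalex.api import autocompletes

suggestions = autocompletes()["einstei"]
print(suggestions.meta["count"], suggestions[0]["display_name"])
```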
12 changes: 6 additions & 6 deletions tests/test_paging.py
@@ -18,13 +18,13 @@ def test_cursor():
     # loop till next_cursor is None
     while next_cursor is not None:
         # get the results
-        r, m = query.get(return_meta=True, per_page=200, cursor=next_cursor)
+        r = query.get(per_page=200, cursor=next_cursor)

         # results
         results.extend(r)

         # set the next cursor
-        next_cursor = m["next_cursor"]
+        next_cursor = r.meta["next_cursor"]

     assert len(results) > 200

@@ -41,17 +41,17 @@ def test_page():
     # loop till page is None
     while page is not None:
         # get the results
-        r, m = query.get(return_meta=True, per_page=200, page=page)
+        r = query.get(per_page=200, page=page)

         # results
         results.extend(r)
-        page = None if len(r) == 0 else m["page"] + 1
+        page = None if len(r) == 0 else r.meta["page"] + 1

     assert len(results) > 200


 def test_paginate_counts():
-    _, m = Authors().search_filter(display_name="einstein").get(return_meta=True)
+    r = Authors().search_filter(display_name="einstein").get()

     p_default = Authors().search_filter(display_name="einstein").paginate(per_page=200)
     n_p_default = sum(len(page) for page in p_default)
@@ -70,7 +70,7 @@ def test_paginate_counts():
     )
     n_p_page = sum(len(page) for page in p_page)

-    assert m["count"] == n_p_page >= n_p_default == n_p_cursor
+    assert r.meta["count"] == n_p_page >= n_p_default == n_p_cursor


 def test_paginate_instance():