diff --git a/tests/test_cdx_utils.py b/tests/test_cdx_utils.py index 6a57e2f..ccea1b1 100644 --- a/tests/test_cdx_utils.py +++ b/tests/test_cdx_utils.py @@ -6,6 +6,7 @@ check_collapses, check_filters, check_match_type, + check_sort, full_url, get_response, get_total_pages, @@ -101,3 +102,12 @@ def test_check_match_type() -> None: with pytest.raises(WaybackError): check_match_type("not a valid type", "url") + + +def test_check_sort() -> None: + assert check_sort("default") + assert check_sort("closest") + assert check_sort("reverse") + + with pytest.raises(WaybackError): + assert check_sort("random crap") diff --git a/waybackpy/cdx_api.py b/waybackpy/cdx_api.py index fb7587a..db02bf5 100644 --- a/waybackpy/cdx_api.py +++ b/waybackpy/cdx_api.py @@ -16,6 +16,7 @@ check_collapses, check_filters, check_match_type, + check_sort, full_url, get_response, get_total_pages, @@ -44,6 +45,7 @@ def __init__( end_timestamp: Optional[str] = None, filters: Optional[List[str]] = None, match_type: Optional[str] = None, + sort: Optional[str] = None, gzip: Optional[str] = None, collapses: Optional[List[str]] = None, limit: Optional[str] = None, @@ -57,6 +59,8 @@ def __init__( check_filters(self.filters) self.match_type = None if match_type is None else str(match_type).strip() check_match_type(self.match_type, self.url) + self.sort = None if sort is None else str(sort).strip() + check_sort(self.sort) self.gzip = gzip self.collapses = [] if collapses is None else collapses check_collapses(self.collapses) @@ -165,6 +169,9 @@ def add_payload(self, payload: Dict[str, str]) -> None: if self.match_type: payload["matchType"] = self.match_type + if self.sort: + payload["sort"] = self.sort + if self.filters and len(self.filters) > 0: for i, _filter in enumerate(self.filters): payload["filter" + str(i)] = _filter diff --git a/waybackpy/cdx_utils.py b/waybackpy/cdx_utils.py index 583dd26..79d222e 100644 --- a/waybackpy/cdx_utils.py +++ b/waybackpy/cdx_utils.py @@ -151,3 +151,24 @@ def check_match_type(match_type: Optional[str], url: str) -> bool: raise WaybackError(exc_message) return True + + +def check_sort(sort: Optional[str]) -> bool: + """ + Check that the sort argument passed by the end-user is valid. + If not valid then raise WaybackError. + """ + + legal_sort = ["default", "closest", "reverse"] + + if not sort: + return True + + if sort not in legal_sort: + exc_message = ( + f"{sort} is not an allowed argument for sort.\n" + "Use one from 'default', 'closest' or 'reverse'" + ) + raise WaybackError(exc_message) + + return True