Skip to content

Commit

Permalink
[Core-531] Catalog support for on-behalf-of (#12666)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 807be42ec6ddac93ac2153bb1af783d7b409fa9f
  • Loading branch information
stephencpope authored and Descartes Labs Build committed Sep 18, 2024
1 parent 95d216a commit 47bb4a7
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 36 deletions.
6 changes: 4 additions & 2 deletions descarteslabs/core/catalog/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def __init__(self, **kwargs):
super(Band, self).__init__(**kwargs)

@classmethod
def search(cls, client=None, request_params=None):
def search(cls, client=None, request_params=None, headers=None):
"""A search query for all bands.
Returns an instance of the
Expand All @@ -656,7 +656,9 @@ def search(cls, client=None, request_params=None):
:py:class:`~descarteslabs.catalog.Search`
An instance of the :py:class:`~descarteslabs.catalog.Search` class
"""
search = super(Band, cls).search(client, request_params=request_params)
search = super(Band, cls).search(
client, request_params=request_params, headers=headers
)
if cls._derived_type:
search = search.filter(properties.type == cls._derived_type)
return search
Expand Down
11 changes: 8 additions & 3 deletions descarteslabs/core/catalog/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def get(
name=None,
client=None,
request_params=None,
headers=None,
):
"""Get an existing Blob from the Descartes Labs catalog.
Expand Down Expand Up @@ -400,7 +401,9 @@ def get(
raise TypeError("Must specify exactly one of id or name parameters")
if not id:
id = f"{storage_type}/{Blob.namespace_id(namespace)}/{name}"
return super(cls, Blob).get(id, client=client)
return super(cls, Blob).get(
id, client=client, request_params=request_params, headers=headers
)

@classmethod
def get_or_create(
Expand Down Expand Up @@ -464,7 +467,7 @@ def get_or_create(
return super(cls, Blob).get_or_create(id, client=client, **kwargs)

@classmethod
def search(cls, client=None, request_params=None):
def search(cls, client=None, request_params=None, headers=None):
"""A search query for all blobs.
Return an `~descarteslabs.catalog.BlobSearch` instance for searching
Expand Down Expand Up @@ -493,7 +496,9 @@ def search(cls, client=None, request_params=None):
... print(result.name) # doctest: +SKIP
"""
return BlobSearch(cls, client=client, request_params=request_params)
return BlobSearch(
cls, client=client, request_params=request_params, headers=headers
)

@check_deleted
def upload(self, file):
Expand Down
39 changes: 26 additions & 13 deletions descarteslabs/core/catalog/catalog_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def state(self):
return DocumentState.SAVED

@classmethod
def get(cls, id, client=None, request_params=None):
def get(cls, id, client=None, request_params=None, headers=None):
"""Get an existing object from the Descartes Labs catalog.
If the Descartes Labs catalog object is found, it will be returned in the
Expand Down Expand Up @@ -535,6 +535,7 @@ def get(cls, id, client=None, request_params=None):
id=id,
client=client,
request_params=request_params,
headers=headers,
)
except NotFoundError:
return None
Expand All @@ -553,7 +554,9 @@ def get(cls, id, client=None, request_params=None):
)

@classmethod
def get_or_create(cls, id, client=None, **kwargs):
def get_or_create(
cls, id, client=None, request_params=None, headers=None, **kwargs
):
"""Get an existing object from the Descartes Labs catalog or create a new object.
If the Descartes Labs catalog object is found, and the remainder of the
Expand Down Expand Up @@ -588,7 +591,7 @@ def get_or_create(cls, id, client=None, **kwargs):
The requested catalog object that was retrieved or created.
"""
obj = cls.get(id, client=client)
obj = cls.get(id, client=client, request_params=request_params, headers=headers)

if obj is None:
obj = cls(id=id, client=client, **kwargs)
Expand All @@ -598,7 +601,9 @@ def get_or_create(cls, id, client=None, **kwargs):
return obj

@classmethod
def get_many(cls, ids, ignore_missing=False, client=None, request_params=None):
def get_many(
cls, ids, ignore_missing=False, client=None, request_params=None, headers=None
):
"""Get existing objects from the Descartes Labs catalog.
All returned Descartes Labs catalog objects will be in the
Expand Down Expand Up @@ -647,6 +652,7 @@ def get_many(cls, ids, ignore_missing=False, client=None, request_params=None):
client=client,
json={"filter": json.dumps([id_filter], separators=(",", ":"))},
request_params=request_params,
headers=headers,
)

if not ignore_missing:
Expand Down Expand Up @@ -676,7 +682,7 @@ def get_many(cls, ids, ignore_missing=False, client=None, request_params=None):

@classmethod
@check_derived
def exists(cls, id, client=None):
def exists(cls, id, client=None, headers=None):
"""Checks if an object exists in the Descartes Labs catalog.
Parameters
Expand Down Expand Up @@ -704,15 +710,15 @@ def exists(cls, id, client=None):
client = client or CatalogClient.get_default_client()
r = None
try:
r = client.session.head(cls._url + "/" + id)
r = client.session.head(cls._url + "/" + id, headers=headers)
except NotFoundError:
return False

return r and r.ok

@classmethod
@check_derived
def search(cls, client=None, request_params=None):
def search(cls, client=None, request_params=None, headers=None):
"""A search query for all objects of the type this class represents.
Parameters
Expand All @@ -736,11 +742,13 @@ def search(cls, client=None, request_params=None):
print(result.name) # doctest: +SKIP
"""
return Search(cls, client=client, request_params=request_params)
return Search(
cls, client=client, request_params=request_params, headers=headers
)

@check_deleted
@deprecate(renamed={"extra_attributes": "request_params"})
def save(self, request_params=None):
def save(self, request_params=None, headers=None):
"""Saves this object to the Descartes Labs catalog.
If this instance was created using the constructor, it will be in the
Expand Down Expand Up @@ -769,6 +777,8 @@ def save(self, request_params=None):
and the object is in the `~descarteslabs.catalog.DocumentState.SAVED`
state, it is updated in the Descartes Labs catalog even though no attributes
were modified.
headers : dict, optional
A dictionary of header keys and values to be sent with the request.
Raises
------
Expand Down Expand Up @@ -830,7 +840,7 @@ def save(self, request_params=None):
json["data"]["attributes"].update(request_params)

data, related_objects = self._send_data(
method=method, id=self.id, json=json, client=self._client
method=method, id=self.id, json=json, client=self._client, headers=headers
)

self._initialize(
Expand All @@ -842,7 +852,7 @@ def save(self, request_params=None):
)

@check_deleted
def reload(self, request_params=None):
def reload(self, request_params=None, headers=None):
"""Reload all attributes from the Descartes Labs catalog.
Refresh the state of this catalog object from the object in the Descartes Labs
Expand Down Expand Up @@ -896,6 +906,7 @@ def reload(self, request_params=None):
id=self.id,
client=self._client,
request_params=request_params,
headers=headers,
)

# this will effectively wipe all current state & caching
Expand Down Expand Up @@ -974,7 +985,9 @@ def _instance_delete(self):

@classmethod
@check_derived
def _send_data(cls, method, id=None, json=None, client=None, request_params=None):
def _send_data(
cls, method, id=None, json=None, client=None, request_params=None, headers=None
):
client = client or CatalogClient.get_default_client()
session_method = getattr(client.session, method.lower())
url = cls._url
Expand All @@ -996,7 +1009,7 @@ def _send_data(cls, method, id=None, json=None, client=None, request_params=None
if query_params:
url += "?" + urllib.parse.urlencode(query_params)

r = session_method(url, json=json).json()
r = session_method(url, json=json, headers=headers).json()
data = r["data"]
related_objects = cls._load_related_objects(r, client)

Expand Down
6 changes: 4 additions & 2 deletions descarteslabs/core/catalog/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def date(self):
return self.acquired

@classmethod
def search(cls, client=None, request_params=None):
def search(cls, client=None, request_params=None, headers=None):
"""A search query for all images.
Return an `~descarteslabs.catalog.ImageSearch` instance for searching
Expand Down Expand Up @@ -574,7 +574,9 @@ def search(cls, client=None, request_params=None):
... print(result.name) # doctest: +SKIP
"""
return ImageSearch(cls, client=client, request_params=request_params)
return ImageSearch(
cls, client=client, request_params=request_params, headers=headers
)

@check_deleted
def upload(self, files, upload_options=None, overwrite=False):
Expand Down
28 changes: 18 additions & 10 deletions descarteslabs/core/catalog/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def named_id(self, name):
return "{}:{}".format(self.id, name)

@check_deleted
def get_band(self, name, client=None, request_params=None):
def get_band(self, name, client=None, request_params=None, headers=None):
"""Retrieve the request band associated with this product by name.
Parameters
Expand All @@ -273,11 +273,14 @@ def get_band(self, name, client=None, request_params=None):
from .band import Band

return Band.get(
self.named_id(name), request_params=request_params, client=client
self.named_id(name),
client=client,
request_params=request_params,
headers=headers,
)

@check_deleted
def get_image(self, name, client=None, request_params=None):
def get_image(self, name, client=None, request_params=None, headers=None):
"""Retrieve the request image associated with this product by name.
Parameters
Expand All @@ -299,7 +302,10 @@ def get_image(self, name, client=None, request_params=None):
from .image import Image

return Image.get(
self.named_id(name), request_params=request_params, client=client
self.named_id(name),
client=client,
request_params=request_params,
headers=headers,
)

@check_deleted
Expand Down Expand Up @@ -367,7 +373,7 @@ def get_delete_status(self):
)

@check_deleted
def bands(self, request_params=None):
def bands(self, request_params=None, headers=None):
"""A search query for all bands for this product, sorted by default band
``sort_order``.
Expand All @@ -386,13 +392,15 @@ def bands(self, request_params=None):
from .band import Band

return (
Band.search(client=self._client, request_params=request_params)
Band.search(
client=self._client, request_params=request_params, headers=headers
)
.filter(properties.product_id == self.id)
.sort("sort_order")
)

@check_deleted
def images(self, request_params=None):
def images(self, request_params=None, headers=None):
"""A search query for all images in this product.
Returns
Expand All @@ -409,9 +417,9 @@ def images(self, request_params=None):
"""
from .image import Image

return Image.search(client=self._client, request_params=request_params).filter(
properties.product_id == self.id
)
return Image.search(
client=self._client, request_params=request_params, headers=headers
).filter(properties.product_id == self.id)

@check_deleted
def image_uploads(self):
Expand Down
29 changes: 23 additions & 6 deletions descarteslabs/core/catalog/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,22 @@ class Search(object):
"""

def __init__(
self, model, client=None, url=None, includes=True, request_params=None
self,
model,
client=None,
url=None,
includes=True,
request_params=None,
headers=None,
):
self._url = url or model._url
self._model_cls = model
self._request_params = {}
if request_params:
self._request_params.update(request_params)
self._headers = {}
if headers:
self._headers.update(headers)

self._filter_properties = None
self._client = client or CatalogClient.get_default_client()
Expand Down Expand Up @@ -281,7 +290,7 @@ def count(self):
# modify query to return 0 results, and just get the object count
s = self.limit(0)
url, params = s._to_request()
r = self._client.session.put(url, json=params)
r = self._client.session.put(url, json=params, headers=s._headers)
response = r.json()
return response["meta"]["count"]

Expand Down Expand Up @@ -330,7 +339,7 @@ def __iter__(self):
"""
url_next, params = self._to_request()
while url_next is not None:
r = self._client.session.put(url_next, json=params)
r = self._client.session.put(url_next, json=params, headers=self._headers)
response = r.json()
if not response["data"]:
break
Expand Down Expand Up @@ -433,10 +442,16 @@ class GeoSearch(Search):
geometries."""

def __init__(
self, model, client=None, url=None, includes=True, request_params=None
self,
model,
client=None,
url=None,
includes=True,
request_params=None,
headers=None,
):
super(GeoSearch, self).__init__(
model, client, url, includes, request_params=request_params
model, client, url, includes, request_params=request_params, headers=headers
)
self._intersects = None

Expand Down Expand Up @@ -603,7 +618,9 @@ def summary_interval(
else:
s._request_params["_end"] = "" # Unbounded

r = self._client.session.put(summary_url, json=s._summary_request())
r = self._client.session.put(
summary_url, json=s._summary_request(), headers=s._headers
)
response = r.json()

return [self.SummaryResult(**d["attributes"]) for d in response["data"]]
Loading

0 comments on commit 47bb4a7

Please sign in to comment.