Skip to content

Commit

Permalink
feat: support multiple extension uris
Browse files Browse the repository at this point in the history
  • Loading branch information
jsignell authored and gadomski committed Apr 12, 2023
1 parent 0f90e7a commit 051152e
Show file tree
Hide file tree
Showing 13 changed files with 166 additions and 37 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

- Include a copy of the `fields.json` file (for summaries) with each distribution of PySTAC ([#1045](https://github.com/stac-utils/pystac/pull/1045))
- Removed documentation references to `to_dict` methods returning JSON ([#1074](https://github.com/stac-utils/pystac/pull/1074))
- Expand support for previous extension schema URIs ([#1091](https://github.com/stac-utils/pystac/pull/1091))

### Deprecated

Expand Down
4 changes: 4 additions & 0 deletions pystac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@
import pystac.extensions.datacube
import pystac.extensions.eo
import pystac.extensions.file
import pystac.extensions.grid
import pystac.extensions.item_assets
import pystac.extensions.label
import pystac.extensions.pointcloud
import pystac.extensions.projection
import pystac.extensions.raster
import pystac.extensions.sar
import pystac.extensions.sat
import pystac.extensions.scientific
Expand All @@ -105,10 +107,12 @@
pystac.extensions.datacube.DATACUBE_EXTENSION_HOOKS,
pystac.extensions.eo.EO_EXTENSION_HOOKS,
pystac.extensions.file.FILE_EXTENSION_HOOKS,
pystac.extensions.grid.GRID_EXTENSION_HOOKS,
pystac.extensions.item_assets.ITEM_ASSETS_EXTENSION_HOOKS,
pystac.extensions.label.LABEL_EXTENSION_HOOKS,
pystac.extensions.pointcloud.POINTCLOUD_EXTENSION_HOOKS,
pystac.extensions.projection.PROJECTION_EXTENSION_HOOKS,
pystac.extensions.raster.RASTER_EXTENSION_HOOKS,
pystac.extensions.sar.SAR_EXTENSION_HOOKS,
pystac.extensions.sat.SAT_EXTENSION_HOOKS,
pystac.extensions.scientific.SCIENTIFIC_EXTENSION_HOOKS,
Expand Down
10 changes: 7 additions & 3 deletions pystac/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ def get_schema_uri(cls) -> str:
"""Gets the schema URI associated with this extension."""
raise NotImplementedError

@classmethod
def get_schema_uris(cls) -> List[str]:
"""Gets a list of schema URIs associated with this extension."""
return [cls.get_schema_uri()]

@classmethod
def add_to(cls, obj: S) -> None:
"""Add the schema URI for this extension to the
Expand All @@ -135,9 +140,8 @@ def remove_from(cls, obj: S) -> None:
def has_extension(cls, obj: S) -> bool:
"""Check if the given object implements this extension by checking
:attr:`pystac.STACObject.stac_extensions` for this extension's schema URI."""
return (
obj.stac_extensions is not None
and cls.get_schema_uri() in obj.stac_extensions
return obj.stac_extensions is not None and any(
uri in obj.stac_extensions for uri in cls.get_schema_uris()
)

@classmethod
Expand Down
14 changes: 11 additions & 3 deletions pystac/extensions/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
from __future__ import annotations

import re
from typing import Any, Dict, Optional, Pattern, Set, Union
from typing import Any, Dict, List, Optional, Pattern, Set, Union

import pystac
from pystac.extensions.base import ExtensionManagementMixin, PropertiesExtension
from pystac.extensions.hooks import ExtensionHooks

SCHEMA_URI: str = "https://stac-extensions.github.io/grid/v1.1.0/schema.json"
SCHEMA_URIS: List[str] = [
"https://stac-extensions.github.io/grid/v1.0.0/schema.json",
SCHEMA_URI,
]
PREFIX: str = "grid:"

# Field names
Expand Down Expand Up @@ -80,6 +84,10 @@ def code(self, v: str) -> None:
def get_schema_uri(cls) -> str:
return SCHEMA_URI

@classmethod
def get_schema_uris(cls) -> List[str]:
return SCHEMA_URIS

@classmethod
def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> GridExtension:
"""Extends the given STAC Object with properties from the :stac-ext:`Grid
Expand All @@ -102,8 +110,8 @@ def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> GridExtension:

class GridExtensionHooks(ExtensionHooks):
schema_uri: str = SCHEMA_URI
prev_extension_ids: Set[str] = set()
prev_extension_ids: Set[str] = {*[uri for uri in SCHEMA_URIS if uri != SCHEMA_URI]}
stac_object_types = {pystac.STACObjectType.ITEM}


Grid_EXTENSION_HOOKS: ExtensionHooks = GridExtensionHooks()
GRID_EXTENSION_HOOKS: ExtensionHooks = GridExtensionHooks()
11 changes: 9 additions & 2 deletions pystac/extensions/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from pystac.utils import StringEnum, get_required, map_opt

SCHEMA_URI = "https://stac-extensions.github.io/label/v1.0.1/schema.json"

SCHEMA_URIS = [
"https://stac-extensions.github.io/label/v1.0.0/schema.json",
SCHEMA_URI,
]
PREFIX = "label:"

PROPERTIES_PROP = PREFIX + "properties"
Expand Down Expand Up @@ -691,6 +694,10 @@ def add_geojson_labels(
def get_schema_uri(cls) -> str:
return SCHEMA_URI

@classmethod
def get_schema_uris(cls) -> List[str]:
return SCHEMA_URIS

@classmethod
def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> LabelExtension:
"""Extends the given STAC Object with properties from the :stac-ext:`Label
Expand Down Expand Up @@ -791,7 +798,7 @@ class LabelExtensionHooks(ExtensionHooks):
schema_uri: str = SCHEMA_URI
prev_extension_ids = {
"label",
"https://stac-extensions.github.io/label/v1.0.0/schema.json",
*[uri for uri in SCHEMA_URIS if uri != SCHEMA_URI],
}
stac_object_types = {pystac.STACObjectType.ITEM}

Expand Down
19 changes: 9 additions & 10 deletions pystac/extensions/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ def transform(self, v: Optional[List[float]]) -> None:
def get_schema_uri(cls) -> str:
return SCHEMA_URI

@classmethod
def get_schema_uris(cls) -> List[str]:
return SCHEMA_URIS

@classmethod
def ext(cls, obj: T, add_if_missing: bool = False) -> ProjectionExtension[T]:
"""Extends the given STAC Object with properties from the :stac-ext:`Projection
Expand Down Expand Up @@ -294,15 +298,6 @@ def summaries(
cls.validate_has_extension(obj, add_if_missing)
return SummariesProjectionExtension(obj)

@classmethod
def has_extension(cls, obj: Union[pystac.Item, pystac.Collection]) -> bool:
if isinstance(obj, pystac.Item) or isinstance(obj, pystac.Collection):
return obj.stac_extensions is not None and any(
uri in obj.stac_extensions for uri in SCHEMA_URIS
)
else:
return False


class ItemProjectionExtension(ProjectionExtension[pystac.Item]):
"""A concrete implementation of :class:`ProjectionExtension` on an
Expand Down Expand Up @@ -376,7 +371,11 @@ def epsg(self, v: Optional[List[int]]) -> None:

class ProjectionExtensionHooks(ExtensionHooks):
schema_uri: str = SCHEMA_URI
prev_extension_ids = {"proj", "projection"}
prev_extension_ids = {
"proj",
"projection",
*[uri for uri in SCHEMA_URIS if uri != SCHEMA_URI],
}
stac_object_types = {pystac.STACObjectType.ITEM}


Expand Down
21 changes: 19 additions & 2 deletions pystac/extensions/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@

from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Union

import pystac
from pystac.extensions.base import (
ExtensionManagementMixin,
PropertiesExtension,
SummariesExtension,
)
from pystac.extensions.hooks import ExtensionHooks
from pystac.utils import StringEnum, get_opt, get_required, map_opt

SCHEMA_URI = "https://stac-extensions.github.io/raster/v1.1.0/schema.json"

SCHEMA_URIS = [
"https://stac-extensions.github.io/raster/v1.0.0/schema.json",
SCHEMA_URI,
]
BANDS_PROP = "raster:bands"


Expand Down Expand Up @@ -706,6 +710,10 @@ def _get_bands(self) -> Optional[List[RasterBand]]:
def get_schema_uri(cls) -> str:
return SCHEMA_URI

@classmethod
def get_schema_uris(cls) -> List[str]:
return SCHEMA_URIS

@classmethod
def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> RasterExtension:
"""Extends the given STAC Object with properties from the :stac-ext:`Raster
Expand Down Expand Up @@ -752,3 +760,12 @@ def bands(self) -> Optional[List[RasterBand]]:
@bands.setter
def bands(self, v: Optional[List[RasterBand]]) -> None:
self._set_summary(BANDS_PROP, map_opt(lambda x: [b.to_dict() for b in x], v))


class RasterExtensionHooks(ExtensionHooks):
schema_uri: str = SCHEMA_URI
prev_extension_ids: Set[str] = {*[uri for uri in SCHEMA_URIS if uri != SCHEMA_URI]}
stac_object_types = {pystac.STACObjectType.ITEM, pystac.STACObjectType.COLLECTION}


RASTER_EXTENSION_HOOKS: ExtensionHooks = RasterExtensionHooks()
19 changes: 9 additions & 10 deletions pystac/serialization/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,20 +188,19 @@ def migrate_to_latest(
# Force stac_extensions property, as it makes
# downstream migration less complex
result["stac_extensions"] = []
pystac.EXTENSION_HOOKS.migrate(result, version, info)

for ext in result["stac_extensions"][:]:
if ext in removed_extension_migrations:
object_types, migration_fn = removed_extension_migrations[ext]
if object_types is None or info.object_type in object_types:
if migration_fn:
migration_fn(result, version, info)
result["stac_extensions"].remove(ext)

result["stac_version"] = STACVersion.DEFAULT_STAC_VERSION
else:
# Ensure stac_extensions property for consistency
if "stac_extensions" not in result:
result["stac_extensions"] = []

pystac.EXTENSION_HOOKS.migrate(result, version, info)
for ext in result["stac_extensions"][:]:
if ext in removed_extension_migrations:
object_types, migration_fn = removed_extension_migrations[ext]
if object_types is None or info.object_type in object_types:
if migration_fn:
migration_fn(result, version, info)
result["stac_extensions"].remove(ext)

return result
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# TODO move all test case code to this file

from pathlib import Path
from datetime import datetime

import pytest
Expand All @@ -9,6 +10,9 @@
from .utils import ARBITRARY_BBOX, ARBITRARY_EXTENT, ARBITRARY_GEOM, TestCases


here = Path(__file__).resolve().parent


@pytest.fixture
def catalog() -> Catalog:
return Catalog("test-catalog", "A test catalog")
Expand Down Expand Up @@ -38,3 +42,7 @@ def test_case_8_collection() -> Collection:
def projection_landsat8_item() -> Item:
path = TestCases.get_path("data-files/projection/example-landsat8.json")
return Item.from_file(path)


def get_data_file(rel_path: str) -> str:
return str(here / "data-files" / rel_path)
26 changes: 26 additions & 0 deletions tests/extensions/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from datetime import datetime
from typing import Any, Dict

import pytest
import pystac
from pystac import ExtensionTypeError
from pystac.extensions import grid
from pystac.extensions.grid import GridExtension
from tests.utils import TestCases
from tests.conftest import get_data_file

code = "MGRS-4CFJ"

Expand Down Expand Up @@ -137,3 +139,27 @@ def test_should_raise_exception_when_passing_invalid_extension_object(
GridExtension.ext,
object(),
)


@pytest.fixture
def ext_item() -> pystac.Item:
ext_item_uri = get_data_file("grid/example-sentinel2.json")
return pystac.Item.from_file(ext_item_uri)


def test_older_extension_version(ext_item: pystac.Item) -> None:
old = "https://stac-extensions.github.io/grid/v1.0.0/schema.json"
new = "https://stac-extensions.github.io/grid/v1.1.0/schema.json"

stac_extensions = set(ext_item.stac_extensions)
stac_extensions.remove(new)
stac_extensions.add(old)
item_as_dict = ext_item.to_dict(include_self_link=False, transform_hrefs=False)
item_as_dict["stac_extensions"] = list(stac_extensions)
item = pystac.Item.from_dict(item_as_dict)
assert GridExtension.has_extension(item)
assert old in item.stac_extensions

migrated_item = pystac.Item.from_dict(item_as_dict, migrate=True)
assert GridExtension.has_extension(migrated_item)
assert new in migrated_item.stac_extensions
26 changes: 26 additions & 0 deletions tests/extensions/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
)
from pystac.utils import get_opt
from tests.utils import TestCases, assert_to_from_dict
from tests.conftest import get_data_file
import pytest


class LabelTypeTest(unittest.TestCase):
Expand Down Expand Up @@ -576,3 +578,27 @@ def test_should_raise_exception_when_passing_invalid_extension_object(
LabelExtension.ext,
object(),
)


@pytest.fixture
def ext_item() -> pystac.Item:
ext_item_uri = get_data_file("label/label-example-1.json")
return pystac.Item.from_file(ext_item_uri)


def test_older_extension_version(ext_item: pystac.Item) -> None:
old = "https://stac-extensions.github.io/label/v1.0.0/schema.json"
new = "https://stac-extensions.github.io/label/v1.0.1/schema.json"

stac_extensions = set(ext_item.stac_extensions)
stac_extensions.remove(new)
stac_extensions.add(old)
item_as_dict = ext_item.to_dict(include_self_link=False, transform_hrefs=False)
item_as_dict["stac_extensions"] = list(stac_extensions)
item = pystac.Item.from_dict(item_as_dict)
assert LabelExtension.has_extension(item)
assert old in item.stac_extensions

migrated_item = pystac.Item.from_dict(item_as_dict, migrate=True)
assert LabelExtension.has_extension(migrated_item)
assert new in migrated_item.stac_extensions
18 changes: 11 additions & 7 deletions tests/extensions/test_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,16 +539,20 @@ def test_summaries_adds_uri(self) -> None:


def test_older_extension_version(projection_landsat8_item: Item) -> None:
old = "https://stac-extensions.github.io/projection/v1.0.0/schema.json"
new = "https://stac-extensions.github.io/projection/v1.1.0/schema.json"

stac_extensions = set(projection_landsat8_item.stac_extensions)
stac_extensions.remove(
"https://stac-extensions.github.io/projection/v1.1.0/schema.json"
)
stac_extensions.add(
"https://stac-extensions.github.io/projection/v1.0.0/schema.json"
)
stac_extensions.remove(new)
stac_extensions.add(old)
item_as_dict = projection_landsat8_item.to_dict(
include_self_link=False, transform_hrefs=False
)
item_as_dict["stac_extensions"] = stac_extensions
item_as_dict["stac_extensions"] = list(stac_extensions)
item = Item.from_dict(item_as_dict)
assert ProjectionExtension.has_extension(item)
assert old in item.stac_extensions

migrated_item = pystac.Item.from_dict(item_as_dict, migrate=True)
assert ProjectionExtension.has_extension(migrated_item)
assert new in migrated_item.stac_extensions
Loading

0 comments on commit 051152e

Please sign in to comment.