Skip to content

Commit

Permalink
plugins for drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaapel committed Apr 2, 2024
1 parent e5d5fc4 commit 29ca25c
Show file tree
Hide file tree
Showing 22 changed files with 106 additions and 29 deletions.
2 changes: 1 addition & 1 deletion hydromt/data_source/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from hydromt._typing import DataType
from hydromt.data_adapter.caching import _uri_validator
from hydromt.data_adapter.data_adapter_base import DataAdapterBase
from hydromt.driver import BaseDriver
from hydromt.drivers import BaseDriver

logger: Logger = getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion hydromt/data_source/geodataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from hydromt._typing import Bbox, ErrorHandleMethod, Geom, NoDataStrategy, TotalBounds
from hydromt.data_adapter.geodataframe import GeoDataFrameAdapter
from hydromt.driver.geodataframe_driver import GeoDataFrameDriver
from hydromt.drivers.geodataframe_driver import GeoDataFrameDriver

from .data_source import DataSource

Expand Down
2 changes: 1 addition & 1 deletion hydromt/data_source/rasterdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from hydromt._typing import Bbox, ErrorHandleMethod, Geom, TimeRange, TotalBounds
from hydromt.data_adapter.rasterdataset import RasterDatasetAdapter
from hydromt.data_source.data_source import DataSource
from hydromt.driver.rasterdataset_driver import RasterDatasetDriver
from hydromt.drivers.rasterdataset_driver import RasterDatasetDriver

logger: Logger = getLogger(__name__)

Expand Down
File renamed without changes.
34 changes: 23 additions & 11 deletions hydromt/driver/base_driver.py → hydromt/drivers/base_driver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Base class for different drivers."""

from abc import ABC
from typing import Any, Callable, ClassVar
from typing import Any, Callable, ClassVar, Generator, List, Type

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from hydromt.metadata_resolver import MetaDataResolver
from hydromt.metadata_resolver.resolver_plugin import RESOLVERS
from hydromt.plugins import PLUGINS


class BaseDriver(BaseModel, ABC):
Expand Down Expand Up @@ -72,20 +73,31 @@ def _init_driver(cls, data: Any, handler: Callable):
return handler(data)

if name := data.get("name"):
try:
# Find which DataSource to instantiate.
target_cls: BaseDriver = next(
filter(lambda sc: sc.name == name, cls._find_all_possible_types())
) # subclasses should be loaded from __init__.py
return target_cls.model_validate(data)
except StopIteration:
raise ValueError(f"Unknown 'name': '{name}'")
# Load plugins, importing subclasses of BaseDriver
PLUGINS.driver_plugins # noqa: B018

# Find which Driver to instantiate.
possible_drivers: List[Type["BaseDriver"]] = list(
filter(lambda dr: dr.name == name, cls._find_all_possible_types())
)
if len(possible_drivers) == 0:
raise ValueError(f"Unknown 'name': '{name}'")
elif len(possible_drivers) > 1:
raise ValueError(
f"""Duplication between driver name {name} in classes:
{list(map(lambda dr: dr.__qualname__, possible_drivers))}"""
)
else:
return possible_drivers[0].model_validate(data)
raise ValueError(f"{cls.__name__} needs 'name'")

@classmethod
def _find_all_possible_types(cls):
"""Recursively generate all possible types for this object."""
def _find_all_possible_types(cls) -> Generator[None, None, Type["BaseDriver"]]:
"""Recursively generate all possible types for this object.
Logic relies on __bases__ and __subclass__() of the BaseDriver class,
which means that all drivers and plugins should be loaded in before.
"""
# any concrete class is a possible type
if ABC not in cls.__bases__:
yield cls
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from hydromt._typing import Bbox, Geom
from hydromt._typing.error import NoDataStrategy
from hydromt.driver import BaseDriver
from hydromt.drivers import BaseDriver

logger: Logger = getLogger(__name__)

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from hydromt._typing import Bbox, Geom, GpdShapeGeom
from hydromt._typing.error import NoDataStrategy
from hydromt.driver.geodataframe_driver import GeoDataFrameDriver
from hydromt.drivers.geodataframe_driver import GeoDataFrameDriver
from hydromt.gis import parse_geom_bbox_buffer

logger: Logger = getLogger(__name__)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from hydromt._typing import Bbox, Geom
from hydromt._typing.error import NoDataStrategy
from hydromt.driver.preprocessing import PREPROCESSORS
from hydromt.driver.rasterdataset_driver import RasterDatasetDriver
from hydromt.drivers.preprocessing import PREPROCESSORS
from hydromt.drivers.rasterdataset_driver import RasterDatasetDriver


class ZarrDriver(RasterDatasetDriver):
Expand Down
50 changes: 50 additions & 0 deletions hydromt/plugins.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Implementation of the mechanism to access the plugin entrypoints."""

from abc import ABC
from typing import TYPE_CHECKING, Dict, Optional, Type, TypedDict, cast

from importlib_metadata import entry_points

if TYPE_CHECKING:
from hydromt.components import ModelComponent # noqa
from hydromt.drivers import BaseDriver
from hydromt.models import Model # noqa

__all__ = ["PLUGINS"]
Expand Down Expand Up @@ -47,10 +49,12 @@ class Plugins:
def __init__(self):
"""Initiate the catalog object."""
self._component_plugins: Optional[Dict[str, Plugin]] = None
self._driver_plugins: Optional[Dict[str, Plugin]] = None
self._model_plugins: Optional[Dict[str, Plugin]] = None

def _initialize_plugins(self) -> None:
self._component_plugins = _discover_plugins(group="hydromt.components")
self._driver_plugins = _discover_plugins(group="hydromt.drivers")
self._model_plugins = _discover_plugins(group="hydromt.models")

@property
Expand All @@ -71,6 +75,25 @@ def component_plugins(self) -> dict[str, type["ModelComponent"]]:
},
)

@property
def driver_plugins(self) -> dict[str, Type["BaseDriver"]]:
"""Load and provide access to all known driver plugins."""
if self._driver_plugins is None:
self._initialize_plugins()

if self._driver_plugins is None:
# core itself exposes plugins so if we can't find anything, something is wrong
raise RuntimeError("Could not load any driver plugins")

drivers: dict[str, Type["BaseDriver"]] = cast(
Dict[str, Type["BaseDriver"]],
{name: value["type"] for name, value in self._driver_plugins.items()},
)
# Do not return ABCs, such as BaseDriver, or RasterDatasetDriver
return {
key: value for key, value in drivers.items() if ABC not in value.__bases__
}

@property
def model_plugins(self) -> dict[str, type["Model"]]:
"""Load and provide access to all known model plugins."""
Expand Down Expand Up @@ -116,6 +139,25 @@ def component_metadata(self) -> Dict[str, Dict[str, str]]:
{k: v for k, v in self._component_plugins.items() if k != "type"},
)

@property
def driver_metadata(self) -> Dict[str, Dict[str, str]]:
"""Load and provide access to all known driver plugin metadata."""
if self._driver_plugins is None:
self._initialize_plugins()

if self._driver_plugins is None:
# core itself exposes plugins so if we can't find anything, something is wrong
raise RuntimeError("Could not load any driver plugins")
else:
return cast(
Dict[str, Dict[str, str]],
{
k: v
for k, v in self._driver_plugins.items()
if ABC not in v["type"].__bases__ # filter ABCs from core
},
)

def model_summary(self) -> str:
"""Generate string representation containing the registered model entrypoints."""
s = ""
Expand All @@ -126,6 +168,14 @@ def model_summary(self) -> str:
)
return f"Model plugins:\n\t- {model_plugins}"

def driver_summary(self) -> str:
"""Generate string representation container the registered driver entrypoints."""
self._initialize_plugins()
driver_plugins = "\n\t ".join(
list(map(_format_metadata, self.driver_metadata.values()))
)
return f"Driver Plugins:\n\t- {driver_plugins}"

def component_summary(self) -> str:
"""Generate string representation containing the registered component entrypoints."""
self._initialize_plugins()
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ hydromt = "hydromt.cli.main:main"
[project.entry-points."hydromt.components"]
core = "hydromt.components"

[project.entry-points."hydromt.drivers"]
core = "hydromt.drivers"

[project.entry-points."hydromt.models"]
core = "hydromt.models"

Expand Down
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from hydromt.data_adapter.geodataframe import GeoDataFrameAdapter
from hydromt.data_catalog import DataCatalog
from hydromt.driver.geodataframe_driver import GeoDataFrameDriver
from hydromt.driver.rasterdataset_driver import RasterDatasetDriver
from hydromt.drivers.geodataframe_driver import GeoDataFrameDriver
from hydromt.drivers.rasterdataset_driver import RasterDatasetDriver
from hydromt.gis import raster, utils, vector
from hydromt.metadata_resolver import MetaDataResolver
from hydromt.models.model import Model
Expand Down Expand Up @@ -379,6 +379,8 @@ def mock_geodf_driver(
geodf: gpd.GeoDataFrame, mock_resolver: MetaDataResolver
) -> GeoDataFrameDriver:
class MockGeoDataFrameDriver(GeoDataFrameDriver):
name = "mock_geodf_driver"

def read(self, *args, **kwargs) -> gpd.GeoDataFrame:
return geodf

Expand All @@ -390,6 +392,8 @@ def mock_raster_ds_driver(
rasterds: xr.Dataset, mock_resolver: MetaDataResolver
) -> RasterDatasetDriver:
class MockRasterDatasetDriver(RasterDatasetDriver):
name = "mock_raster_ds_driver"

def read(self, *args, **kwargs) -> xr.Dataset:
return rasterds

Expand Down
4 changes: 3 additions & 1 deletion tests/data_sources/test_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from hydromt.data_adapter.geodataframe import GeoDataFrameAdapter
from hydromt.data_source import DataSource, GeoDataFrameSource, create_source
from hydromt.data_source.data_source import get_nested_var, set_nested_var
from hydromt.driver.geodataframe_driver import GeoDataFrameDriver
from hydromt.drivers.geodataframe_driver import GeoDataFrameDriver


class TestDataSource:
Expand Down Expand Up @@ -42,6 +42,8 @@ def test_reads_nested(
self,
):
class FakeGeoDfDriver(GeoDataFrameDriver):
name = "test_reads_nested"

def read(self, **kwargs):
pass

Expand Down
2 changes: 1 addition & 1 deletion tests/data_sources/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from hydromt.data_adapter import GeoDataFrameAdapter
from hydromt.data_source import DataSource, GeoDataFrameSource, create_source
from hydromt.driver import GeoDataFrameDriver
from hydromt.drivers import GeoDataFrameDriver


class TestCreateSource:
Expand Down
4 changes: 2 additions & 2 deletions tests/data_sources/test_geo_data_frame_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from hydromt.data_adapter.geodataframe import GeoDataFrameAdapter
from hydromt.data_catalog import DataCatalog
from hydromt.data_source.geodataframe import GeoDataFrameSource
from hydromt.driver.geodataframe_driver import GeoDataFrameDriver
from hydromt.driver.pyogrio_driver import PyogrioDriver
from hydromt.drivers.geodataframe_driver import GeoDataFrameDriver
from hydromt.drivers.pyogrio_driver import PyogrioDriver
from hydromt.metadata_resolver.convention_resolver import ConventionResolver


Expand Down
2 changes: 1 addition & 1 deletion tests/data_sources/test_raster_dataset_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from hydromt.data_adapter import RasterDatasetAdapter
from hydromt.data_source import RasterDatasetSource
from hydromt.driver.rasterdataset_driver import RasterDatasetDriver
from hydromt.drivers.rasterdataset_driver import RasterDatasetDriver


@pytest.fixture()
Expand Down
2 changes: 1 addition & 1 deletion tests/drivers/test_base_driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from hydromt.driver.base_driver import BaseDriver
from hydromt.drivers.base_driver import BaseDriver


class TestBaseDriver:
Expand Down
2 changes: 1 addition & 1 deletion tests/drivers/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import xarray as xr

from hydromt.driver.preprocessing import (
from hydromt.drivers.preprocessing import (
round_latlon,
)
from hydromt.gis.raster import full_from_transform
Expand Down
2 changes: 1 addition & 1 deletion tests/drivers/test_pyogrio_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from shapely import box

from hydromt._typing import Bbox
from hydromt.driver.pyogrio_driver import PyogrioDriver
from hydromt.drivers.pyogrio_driver import PyogrioDriver
from hydromt.metadata_resolver.convention_resolver import ConventionResolver


Expand Down
2 changes: 1 addition & 1 deletion tests/drivers/test_zarr_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import zarr

from hydromt.driver.zarr_driver import ZarrDriver
from hydromt.drivers.zarr_driver import ZarrDriver
from hydromt.metadata_resolver.convention_resolver import ConventionResolver


Expand Down
6 changes: 6 additions & 0 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@ def test_core_model_plugins():
def test_summary():
component_summary = PLUGINS.component_summary()
model_summary = PLUGINS.model_summary()
driver_summary = PLUGINS.driver_summary()
assert "Component plugins:" in component_summary
assert "Model plugins:" in model_summary
assert "ModelRegionComponent" in component_summary
assert "GridComponent" in component_summary
assert "ModelComponent" in component_summary
assert "Model" in model_summary
assert "Driver Plugins:" in driver_summary
assert "PyogrioDriver" in driver_summary
assert "RasterDatasetDriver" not in driver_summary


def _patch_plugin_entry_point(mocker: MockerFixture, component_names: List[str]):
Expand Down Expand Up @@ -68,9 +72,11 @@ def _patch_plugin_entry_point(mocker: MockerFixture, component_names: List[str])
@pytest.fixture()
def _reset_plugins():
PLUGINS._component_plugins = None
PLUGINS._driver_plugins = None
PLUGINS._model_plugins = None
yield
PLUGINS._component_plugins = None
PLUGINS._driver_plugins = None
PLUGINS._model_plugins = None


Expand Down

0 comments on commit 29ca25c

Please sign in to comment.