Skip to content

Commit

Permalink
fix: use common interface in joiners and loaders (#235)
Browse files Browse the repository at this point in the history
* fix: use common interface in joiners and loaders

* fix: add ABC to OSMLoader class
  • Loading branch information
piotrgramacki authored Apr 25, 2023
1 parent fa5c79a commit 553179d
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 13 deletions.
4 changes: 3 additions & 1 deletion srai/joiners/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ class Joiner(abc.ABC):

@abc.abstractmethod
def transform(
self, regions: gpd.GeoDataFrame, features: gpd.GeoDataFrame
self,
regions: gpd.GeoDataFrame,
features: gpd.GeoDataFrame,
) -> gpd.GeoDataFrame: # pragma: no cover
"""
Join features to regions.
Expand Down
3 changes: 2 additions & 1 deletion srai/joiners/intersection_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import pandas as pd

from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, REGIONS_INDEX
from srai.joiners import Joiner


class IntersectionJoiner:
class IntersectionJoiner(Joiner):
"""
Intersection Joiner.
Expand Down
9 changes: 4 additions & 5 deletions srai/loaders/_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Base class for loaders."""

import abc
from pathlib import Path
from typing import Union
from typing import Any

import geopandas as gpd

Expand All @@ -11,13 +10,13 @@ class Loader(abc.ABC):
"""Abstract class for loaders."""

@abc.abstractmethod
def load(self, area: Union[gpd.GeoDataFrame, Path]) -> gpd.GeoDataFrame: # pragma: no cover
def load(self, *args: Any, **kwargs: Any) -> gpd.GeoDataFrame: # pragma: no cover
"""
Load data for a given area.
Args:
area (gdf.GeoDataFrame | Path): GeoDataFrame with the area of interest or a path
to a file with a geometry.
*args: Positional arguments dependating on a specific loader.
**kwargs: Keyword arguments dependating on a specific loader.
Returns:
GeoDataFrame with the downloaded data.
Expand Down
3 changes: 2 additions & 1 deletion srai/loaders/geoparquet_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import geopandas as gpd

from srai.constants import GEOMETRY_COLUMN, WGS84_CRS
from srai.loaders import Loader


class GeoparquetLoader:
class GeoparquetLoader(Loader):
"""
GeoparquetLoader.
Expand Down
3 changes: 2 additions & 1 deletion srai/loaders/gtfs_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from shapely.geometry import Point

from srai.constants import GEOMETRY_COLUMN, WGS84_CRS
from srai.loaders import Loader
from srai.utils._optional import import_optional_dependencies

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -27,7 +28,7 @@
GTFS2VEC_TRIPS_PREFIX = "trips_at_"


class GTFSLoader:
class GTFSLoader(Loader):
"""
GTFSLoader.
Expand Down
3 changes: 2 additions & 1 deletion srai/loaders/osm_loaders/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
from tqdm import tqdm

from srai.loaders import Loader
from srai.loaders.osm_loaders.filters._typing import (
grouped_osm_tags_type,
merge_grouped_osm_tags_type,
Expand All @@ -16,7 +17,7 @@
from srai.utils.typing import is_expected_type


class OSMLoader(abc.ABC):
class OSMLoader(Loader, abc.ABC):
"""Abstract class for loaders."""

@abc.abstractmethod
Expand Down
5 changes: 4 additions & 1 deletion srai/loaders/osm_loaders/osm_online_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from srai.loaders.osm_loaders._base import OSMLoader
from srai.loaders.osm_loaders.filters._typing import grouped_osm_tags_type, osm_tags_type
from srai.loaders.osm_loaders.filters._typing import (
grouped_osm_tags_type,
osm_tags_type,
)
from srai.utils._optional import import_optional_dependencies


Expand Down
5 changes: 4 additions & 1 deletion srai/loaders/osm_loaders/osm_pbf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from srai.loaders.osm_loaders._base import OSMLoader
from srai.loaders.osm_loaders.filters._typing import grouped_osm_tags_type, osm_tags_type
from srai.loaders.osm_loaders.filters._typing import (
grouped_osm_tags_type,
osm_tags_type,
)
from srai.utils._optional import import_optional_dependencies


Expand Down
3 changes: 2 additions & 1 deletion srai/loaders/osm_way_loader/osm_way_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from srai.exceptions import LoadedDataIsEmptyException
from srai.loaders import Loader
from srai.utils._optional import import_optional_dependencies

from . import constants
Expand All @@ -41,7 +42,7 @@ class NetworkType(str, Enum):
WALK = "walk"


class OSMWayLoader:
class OSMWayLoader(Loader):
"""
OSMWayLoader downloads road infrastructure from OSM.
Expand Down

0 comments on commit 553179d

Please sign in to comment.