Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use common interface in joiners and loaders #235

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion srai/joiners/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ class Joiner(abc.ABC):

@abc.abstractmethod
def transform(
self, regions: gpd.GeoDataFrame, features: gpd.GeoDataFrame
self,
regions: gpd.GeoDataFrame,
features: gpd.GeoDataFrame,
) -> gpd.GeoDataFrame: # pragma: no cover
"""
Join features to regions.
Expand Down
3 changes: 2 additions & 1 deletion srai/joiners/intersection_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import pandas as pd

from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, REGIONS_INDEX
from srai.joiners import Joiner


class IntersectionJoiner:
class IntersectionJoiner(Joiner):
"""
Intersection Joiner.

Expand Down
9 changes: 4 additions & 5 deletions srai/loaders/_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Base class for loaders."""

import abc
from pathlib import Path
from typing import Union
from typing import Any

import geopandas as gpd

Expand All @@ -11,13 +10,13 @@ class Loader(abc.ABC):
"""Abstract class for loaders."""

@abc.abstractmethod
def load(self, area: Union[gpd.GeoDataFrame, Path]) -> gpd.GeoDataFrame: # pragma: no cover
def load(self, *args: Any, **kwargs: Any) -> gpd.GeoDataFrame: # pragma: no cover
Calychas marked this conversation as resolved.
Show resolved Hide resolved
"""
Load data for a given area.

Args:
area (gdf.GeoDataFrame | Path): GeoDataFrame with the area of interest or a path
to a file with a geometry.
*args: Positional arguments dependating on a specific loader.
**kwargs: Keyword arguments dependating on a specific loader.

Returns:
GeoDataFrame with the downloaded data.
Expand Down
3 changes: 2 additions & 1 deletion srai/loaders/geoparquet_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import geopandas as gpd

from srai.constants import GEOMETRY_COLUMN, WGS84_CRS
from srai.loaders import Loader


class GeoparquetLoader:
class GeoparquetLoader(Loader):
"""
GeoparquetLoader.

Expand Down
3 changes: 2 additions & 1 deletion srai/loaders/gtfs_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from shapely.geometry import Point

from srai.constants import GEOMETRY_COLUMN, WGS84_CRS
from srai.loaders import Loader
from srai.utils._optional import import_optional_dependencies

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -27,7 +28,7 @@
GTFS2VEC_TRIPS_PREFIX = "trips_at_"


class GTFSLoader:
class GTFSLoader(Loader):
"""
GTFSLoader.

Expand Down
3 changes: 2 additions & 1 deletion srai/loaders/osm_loaders/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
from tqdm import tqdm

from srai.loaders import Loader
from srai.loaders.osm_loaders.filters._typing import (
grouped_osm_tags_type,
merge_grouped_osm_tags_type,
Expand All @@ -16,7 +17,7 @@
from srai.utils.typing import is_expected_type


class OSMLoader(abc.ABC):
class OSMLoader(Loader, abc.ABC):
"""Abstract class for loaders."""

@abc.abstractmethod
Expand Down
5 changes: 4 additions & 1 deletion srai/loaders/osm_loaders/osm_online_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from srai.loaders.osm_loaders._base import OSMLoader
from srai.loaders.osm_loaders.filters._typing import grouped_osm_tags_type, osm_tags_type
from srai.loaders.osm_loaders.filters._typing import (
grouped_osm_tags_type,
osm_tags_type,
)
from srai.utils._optional import import_optional_dependencies


Expand Down
5 changes: 4 additions & 1 deletion srai/loaders/osm_loaders/osm_pbf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from srai.loaders.osm_loaders._base import OSMLoader
from srai.loaders.osm_loaders.filters._typing import grouped_osm_tags_type, osm_tags_type
from srai.loaders.osm_loaders.filters._typing import (
grouped_osm_tags_type,
osm_tags_type,
)
from srai.utils._optional import import_optional_dependencies


Expand Down
3 changes: 2 additions & 1 deletion srai/loaders/osm_way_loader/osm_way_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from srai.exceptions import LoadedDataIsEmptyException
from srai.loaders import Loader
from srai.utils._optional import import_optional_dependencies

from . import constants
Expand All @@ -41,7 +42,7 @@ class NetworkType(str, Enum):
WALK = "walk"


class OSMWayLoader:
class OSMWayLoader(Loader):
"""
OSMWayLoader downloads road infrastructure from OSM.

Expand Down