From 46b8acb18f3e58ba6deb6a0fab24683bb7c0cb0d Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Wed, 12 Apr 2023 18:05:54 +0200 Subject: [PATCH 01/30] feat: added contextual count embedder --- .gitignore | 303 ++++++++++---------- srai/embedders/contextual_count_embedder.py | 104 +++++++ srai/embedders/count_embedder.py | 14 +- 3 files changed, 268 insertions(+), 153 deletions(-) create mode 100644 srai/embedders/contextual_count_embedder.py diff --git a/.gitignore b/.gitignore index 2efe9e7d..31787bca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,151 +1,152 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -coverage.*.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pdm -.pdm.toml -requirements.txt - -# VSCode -.vscode/**/* -!.vscode/settings.json.default -!.vscode/extensions.json - -# osmnx -cache/ - -# pytorch lightning -lightning_logs/ - -# files_cache -files/ - -# ruff -.ruff_cache +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +coverage.*.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pdm +.pdm.toml +.pdm-python +requirements.txt + +# VSCode +.vscode/**/* +!.vscode/settings.json.default +!.vscode/extensions.json + +# osmnx +cache/ + +# pytorch lightning +lightning_logs/ + +# files_cache +files/ + +# ruff +.ruff_cache diff --git a/srai/embedders/contextual_count_embedder.py b/srai/embedders/contextual_count_embedder.py new file mode 100644 index 00000000..eca3680e --- /dev/null +++ b/srai/embedders/contextual_count_embedder.py @@ -0,0 +1,104 @@ +""" +Contextual Count Embedder. + +This module contains contextual count embedder implementation from ARIC@SIGSPATIAL 2021 paper[1]. + +References: + [1] https://doi.org/10.1145/3486626.3493434 + [1] https://arxiv.org/abs/2111.00990 +""" + +from typing import List, Optional, TypeVar + +import geopandas as gpd +import numpy as np +import pandas as pd +from tqdm import tqdm + +from srai.embedders import CountEmbedder +from srai.neighbourhoods import Neighbourhood + +T = TypeVar("T") + + +class ContextualCountEmbedder(CountEmbedder): + """ContextualCountEmbedder.""" + + def __init__( + self, + neighbourhood: Neighbourhood[T], + neighbourhood_distance: int, + squash_vectors: bool = True, + expected_output_features: Optional[List[str]] = None, + count_subcategories: bool = False, + ) -> None: + """TODO.""" + # """ + # Init ContextualCountEmbedder. + + # Args: + # expected_output_features (List[str], optional): The features that are expected + # to be found in the resulting embedding. If not None, the missing features are + # added + # and filled with 0. The unexpected features are removed. + # The resulting columns are sorted accordingly. Defaults to None. + # count_subcategories (bool, optional): Whether to count all subcategories individually + # or count features only on the highest level based on features column name. + # Defaults to True. 
+ # """ + super().__init__(expected_output_features, count_subcategories) + + self.neighbourhood = neighbourhood + self.neighbourhood_distance = neighbourhood_distance + self.squash_vectors = squash_vectors + + def transform( + self, + regions_gdf: gpd.GeoDataFrame, + features_gdf: gpd.GeoDataFrame, + joint_gdf: gpd.GeoDataFrame, + ) -> pd.DataFrame: + """TODO.""" + counts_df = super().transform(regions_gdf, features_gdf, joint_gdf) + + if self.squash_vectors: + return self._get_squashed_embeddings(counts_df) + else: + return self._get_concatenated_embeddings(counts_df) + + def _get_squashed_embeddings(self, counts_df: pd.DataFrame) -> pd.DataFrame: + base_columns = list(counts_df.columns) + + result_array = counts_df.values.astype(float) + for idx, region_id in tqdm(enumerate(counts_df.index), desc="Generating embeddings"): + for distance in range(1, self.neighbourhood_distance + 1): + neighbours = self.neighbourhood.get_neighbours_at_distance(region_id, distance) + matching_neighbours = counts_df.index.intersection(neighbours) + if not matching_neighbours.empty: + values = counts_df.loc[matching_neighbours].values + flattened_values = np.average(values, axis=0) + result_array[idx, :] += flattened_values / ((distance + 1) ** 2) + return pd.DataFrame(data=result_array, index=counts_df.index, columns=base_columns) + + def _get_concatenated_embeddings(self, counts_df: pd.DataFrame) -> pd.DataFrame: + base_columns = list(counts_df.columns) + no_base_columns = len(base_columns) + + columns = [ + f"{column}_{distance}" + for distance in range(self.neighbourhood_distance + 1) + for column in base_columns + ] + result_array = np.zeros(shape=(len(counts_df.index), len(columns))) + result_array[:, 0:no_base_columns] = counts_df.values + for idx, region_id in tqdm(enumerate(counts_df.index), desc="Generating embeddings"): + for distance in range(1, self.neighbourhood_distance + 1): + neighbours = self.neighbourhood.get_neighbours_at_distance(region_id, distance) + matching_neighbours = counts_df.index.intersection(neighbours) + if not matching_neighbours.empty: + values = counts_df.loc[matching_neighbours].values + flattened_values = np.average(values, axis=0) + result_array[ + idx, no_base_columns * distance : no_base_columns * (distance + 1) + ] = flattened_values + return pd.DataFrame(data=result_array, index=counts_df.index, columns=columns) diff --git a/srai/embedders/count_embedder.py b/srai/embedders/count_embedder.py index a32e9fbd..f51f0b1b 100644 --- a/srai/embedders/count_embedder.py +++ b/srai/embedders/count_embedder.py @@ -14,7 +14,9 @@ class CountEmbedder(Embedder): """Simple Embedder that counts occurences of feature values.""" - def __init__(self, expected_output_features: Optional[List[str]] = None) -> None: + def __init__( + self, expected_output_features: Optional[List[str]] = None, count_subcategories: bool = True + ) -> None: """ Init CountEmbedder. @@ -23,12 +25,17 @@ def __init__(self, expected_output_features: Optional[List[str]] = None) -> None to be found in the resulting embedding. If not None, the missing features are added and filled with 0. The unexpected features are removed. The resulting columns are sorted accordingly. Defaults to None. + count_subcategories (bool, optional): Whether to count all subcategories individually + or count features only on the highest level based on features column name. + Defaults to True. 
""" if expected_output_features is not None: self.expected_output_features = pd.Series(expected_output_features) else: self.expected_output_features = None + self.count_subcategories = count_subcategories + def transform( self, regions_gdf: gpd.GeoDataFrame, @@ -75,7 +82,10 @@ def transform( features_df = self._remove_geometry_if_present(features_gdf) joint_df = self._remove_geometry_if_present(joint_gdf) - feature_encodings = pd.get_dummies(features_df) + if self.count_subcategories: + feature_encodings = pd.get_dummies(features_df) + else: + feature_encodings = features_df.notna().astype(int) joint_with_encodings = joint_df.join(feature_encodings) region_embeddings = joint_with_encodings.groupby(level=0).sum() From 8a6923ac4813f596f3f59f157f097153d8a62d22 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Thu, 13 Apr 2023 15:25:29 +0200 Subject: [PATCH 02/30] feat: modify ContextualCountEmbedder --- .../embedders/contextual_count_embedder.ipynb | 215 ++++++++++++++++++ examples/embedders/highway2vec_embedder.ipynb | 2 +- examples/joiners/intersection_joiner.ipynb | 2 +- srai/embedders/__init__.py | 15 +- srai/embedders/contextual_count_embedder.py | 147 +++++++++--- srai/embedders/count_embedder.py | 4 +- srai/joiners/intersection_joiner.py | 5 +- srai/plotting/folium_wrapper.py | 20 +- 8 files changed, 360 insertions(+), 50 deletions(-) create mode 100644 examples/embedders/contextual_count_embedder.ipynb diff --git a/examples/embedders/contextual_count_embedder.ipynb b/examples/embedders/contextual_count_embedder.ipynb new file mode 100644 index 00000000..33d50856 --- /dev/null +++ b/examples/embedders/contextual_count_embedder.ipynb @@ -0,0 +1,215 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from srai.loaders.osm_loaders import OSMPbfLoader\n", + "from srai.regionizers import H3Regionizer\n", + "from srai.joiners import IntersectionJoiner\n", + "from srai.embedders import ContextualCountEmbedder\n", + "from srai.plotting.folium_wrapper import plot_regions, plot_numeric_data\n", + "from srai.neighbourhoods import H3Neighbourhood" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data preparation\n", + "\n", + "In order to use `ContextualCountEmbedder` we need to prepare some data. \n", + "Namely we need: `regions_gdf`, `features_gdf`, and `joint_gdf`. \n", + "These are the outputs of Regionizers, Loaders and Joiners respectively." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from srai.utils import geocode_to_region_gdf\n", + "\n", + "area_gdf = geocode_to_region_gdf(\"Lisboa, PT\")\n", + "plot_regions(area_gdf)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Regionize the area using an H3Regionizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "regionizer = H3Regionizer(resolution=9, buffer=True)\n", + "regions_gdf = regionizer.transform(area_gdf)\n", + "regions_gdf" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download some objects from OpenStreetMap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from srai.loaders.osm_loaders.filters.hex2vec import HEX2VEC_FILTER\n", + "\n", + "loader = OSMPbfLoader()\n", + "features_gdf = loader.load(area_gdf, tags=HEX2VEC_FILTER)\n", + "features_gdf" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Join the objects with the regions they belong to" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "joiner = IntersectionJoiner()\n", + "joint_gdf = joiner.transform(regions_gdf, features_gdf)\n", + "joint_gdf" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Embed using features existing in data\n", + "\n", + "`ContextualCountEmbedder` extends capabilities of basic `CountEmbedder` by incorporating the neighbourhood of embedded region. In this example we will use the `H3Neighbourhood`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "h3n = H3Neighbourhood()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Squashed vector version (default)\n", + "\n", + "Embedder will return vector of the same length as `CountEmbedder`, but will sum averaged values from the neighbourhoods diminished by the neighbour distance squared." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "squash_cce = ContextualCountEmbedder(\n", + " neighbourhood=h3n, neighbourhood_distance=10, squash_vectors=True\n", + ")\n", + "embeddings = squash_cce.transform(regions_gdf, features_gdf, joint_gdf)\n", + "embeddings" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Concatenated vector version (default)\n", + "\n", + "Embedder will return vector of length `n * distance` where `n` is number of features from the `CountEmbedder` and `distance` is number of neighbourhoods analysed.\n", + "\n", + "Each feature will be postfixed with `_n` string, where `n` is the current distance. Values are averaged from all neighbours." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wide_cce = ContextualCountEmbedder(\n", + " neighbourhood=h3n, neighbourhood_distance=10, squash_vectors=False\n", + ")\n", + "wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)\n", + "wide_embeddings" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting example features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_numeric_data(regions_gdf, embeddings, \"leisure\", tiles_style=\"CartoDB positron\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_numeric_data(regions_gdf, embeddings, \"building\", tiles_style=\"CartoDB positron\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/embedders/highway2vec_embedder.ipynb b/examples/embedders/highway2vec_embedder.ipynb index 3a462beb..a58648ac 100644 --- a/examples/embedders/highway2vec_embedder.ipynb +++ b/examples/embedders/highway2vec_embedder.ipynb @@ -105,7 +105,7 @@ "from srai.joiners import IntersectionJoiner\n", "\n", "joiner = IntersectionJoiner()\n", - "joint_gdf = joiner.transform(regions_gdf, edges_gdf, return_geom=False)\n", + "joint_gdf = joiner.transform(regions_gdf, edges_gdf)\n", "joint_gdf" ] }, diff --git a/examples/joiners/intersection_joiner.ipynb b/examples/joiners/intersection_joiner.ipynb index 0cd9f341..a52e662d 100644 --- a/examples/joiners/intersection_joiner.ipynb +++ b/examples/joiners/intersection_joiner.ipynb @@ -99,7 +99,7 @@ "from srai.joiners import IntersectionJoiner\n", "\n", "joiner = IntersectionJoiner()\n", - "joint = joiner.transform(regions, features)\n", + "joint = joiner.transform(regions, features, return_geom=True)\n", "\n", "joint" ] diff --git a/srai/embedders/__init__.py b/srai/embedders/__init__.py index 71947425..2086c95a 100644 --- a/srai/embedders/__init__.py +++ b/srai/embedders/__init__.py @@ -1,8 +1,17 @@ """Embedders.""" -from ._base import Embedder -from .count_embedder import CountEmbedder +# Force import of required base classes +from srai.embedders._base import Embedder +from srai.embedders.count_embedder import CountEmbedder + +from .contextual_count_embedder import ContextualCountEmbedder from .gtfs2vec_embedder import GTFS2VecEmbedder from .highway2vec import Highway2VecEmbedder -__all__ = ["Embedder", "CountEmbedder", "GTFS2VecEmbedder", "Highway2VecEmbedder"] +__all__ = [ + "Embedder", + "CountEmbedder", + "ContextualCountEmbedder", + "GTFS2VecEmbedder", + "Highway2VecEmbedder", +] diff --git a/srai/embedders/contextual_count_embedder.py b/srai/embedders/contextual_count_embedder.py index eca3680e..f62a37bd 100644 --- a/srai/embedders/contextual_count_embedder.py +++ b/srai/embedders/contextual_count_embedder.py @@ -8,7 +8,7 @@ [1] https://arxiv.org/abs/2111.00990 """ -from typing import List, Optional, TypeVar +from typing import List, Optional import geopandas as gpd import numpy as np @@ -17,8 +17,7 @@ from srai.embedders import CountEmbedder from 
srai.neighbourhoods import Neighbourhood
-
-T = TypeVar("T")
+from srai.neighbourhoods._base import IndexType
 
 
 class ContextualCountEmbedder(CountEmbedder):
@@ -26,31 +25,36 @@ class ContextualCountEmbedder(CountEmbedder):
 
     def __init__(
         self,
-        neighbourhood: Neighbourhood[T],
+        neighbourhood: Neighbourhood[IndexType],
         neighbourhood_distance: int,
-        squash_vectors: bool = True,
+        concatenate_vectors: bool = False,
         expected_output_features: Optional[List[str]] = None,
         count_subcategories: bool = False,
     ) -> None:
-        """TODO."""
-        # """
-        # Init ContextualCountEmbedder.
-
-        # Args:
-        #     expected_output_features (List[str], optional): The features that are expected
-        #     to be found in the resulting embedding. If not None, the missing features are
-        #     added
-        #     and filled with 0. The unexpected features are removed.
-        #     The resulting columns are sorted accordingly. Defaults to None.
-        #     count_subcategories (bool, optional): Whether to count all subcategories individually
-        #     or count features only on the highest level based on features column name.
-        #     Defaults to True.
-        # """
+        """
+        Init ContextualCountEmbedder.
+
+        Args:
+            neighbourhood (Neighbourhood[IndexType]): Neighbourhood object used to get neighbours
+                for the contextualization.
+            neighbourhood_distance (int): How many neighbours levels should be included in
+                the embedding.
+            concatenate_vectors (bool, optional): Whether to sum all neighbours into a single vector
+                with the same width as 1CountEmbedder1, or to concatenate them to the wide format
+                and keep all neighbour levels separate. Defaults to False.
+            expected_output_features (List[str], optional): The features that are expected
+                to be found in the resulting embedding. If not None, the missing features are
+                added and filled with 0. The unexpected features are removed.
+                The resulting columns are sorted accordingly. Defaults to None.
+            count_subcategories (bool, optional): Whether to count all subcategories individually
+                or count features only on the highest level based on features column name.
+                Defaults to False.
+        """
         super().__init__(expected_output_features, count_subcategories)
 
         self.neighbourhood = neighbourhood
         self.neighbourhood_distance = neighbourhood_distance
-        self.squash_vectors = squash_vectors
+        self.concatenate_vectors = concatenate_vectors
 
     def transform(
         self,
@@ -58,47 +62,116 @@ def transform(
         features_gdf: gpd.GeoDataFrame,
         joint_gdf: gpd.GeoDataFrame,
     ) -> pd.DataFrame:
-        """TODO."""
+        """
+        Embed a given GeoDataFrame.
+
+        Creates region embeddings by counting the frequencies of each feature value and applying
+        a contextualization based on neighbours of regions. For each region, features will be
+        altered based on the neighbours either by adding averaged values diminished based on distance,
+        or by adding new separate columns with neighbour distance postfix.
+        Expects features_gdf to be in wide format with each column being a separate type of
+        feature (e.g. amenity, leisure) and rows to hold values of these features for each object.
+        The rows will hold numbers of this type of feature in each region. Numbers can be
+        fractional because neighbourhoods are averaged to represent a single value from
+        all neighbours on a given leven.
+
+        Args:
+            regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
+            features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
+            joint_gdf (gpd.GeoDataFrame): Joiner result with region-feature multi-index.
+
+        Returns:
+            pd.DataFrame: Embedding for each region in regions_gdf.
+
+        Raises:
+            ValueError: If features_gdf is empty and self.expected_output_features is not set.
+            ValueError: If any of the gdfs index names is None.
+            ValueError: If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.
+            ValueError: If index levels in gdfs don't overlap correctly.
+        """
         counts_df = super().transform(regions_gdf, features_gdf, joint_gdf)
 
-        if self.squash_vectors:
-            return self._get_squashed_embeddings(counts_df)
+        result_df: pd.DataFrame
+        if self.concatenate_vectors:
+            result_df = self._get_concatenated_embeddings(counts_df)
         else:
-            return self._get_concatenated_embeddings(counts_df)
+            result_df = self._get_squashed_embeddings(counts_df)
+
+        return result_df
 
     def _get_squashed_embeddings(self, counts_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Generate embeddings for regions by summing all neighbourhood levels.
+
+        Creates embedding by getting an average of a neighbourhood at a given distance and adding it
+        to the base values with weight equal to 1 / (distance + 1) squared. This way, farther
+        neighbourhoods have lower impact on feature values.
+
+        Args:
+            counts_df (pd.DataFrame): Calculated features from CountEmbedder.
+
+        Returns:
+            pd.DataFrame: Embedding for each region in regions_gdf with the same number of
+                features as returned by the CountEmbedder.
+        """
         base_columns = list(counts_df.columns)
 
         result_array = counts_df.values.astype(float)
-        for idx, region_id in tqdm(enumerate(counts_df.index), desc="Generating embeddings"):
+        for idx, region_id in tqdm(
+            enumerate(counts_df.index), desc="Generating embeddings", total=len(counts_df.index)
+        ):
             for distance in range(1, self.neighbourhood_distance + 1):
                 neighbours = self.neighbourhood.get_neighbours_at_distance(region_id, distance)
                 matching_neighbours = counts_df.index.intersection(neighbours)
-                if not matching_neighbours.empty:
-                    values = counts_df.loc[matching_neighbours].values
-                    flattened_values = np.average(values, axis=0)
-                    result_array[idx, :] += flattened_values / ((distance + 1) ** 2)
+                if matching_neighbours.empty:
+                    continue
+
+                values = counts_df.loc[matching_neighbours].values
+                flattened_values = np.average(values, axis=0)
+                result_array[idx, :] += flattened_values / ((distance + 1) ** 2)
+
         return pd.DataFrame(data=result_array, index=counts_df.index, columns=base_columns)
 
     def _get_concatenated_embeddings(self, counts_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Generate embeddings for regions by concatenating different neighbourhood levels.
+
+        Creates embedding by getting an average of a neighbourhood at a given distance and adding
+        those features as separate columns with a postfix added to the feature name. This way,
+        all neighbourhoods can be analyzed separately, but the number of columns grows linearly
+        with the distance.
+
+        Args:
+            counts_df (pd.DataFrame): Calculated features from CountEmbedder.
+
+        Returns:
+            pd.DataFrame: Embedding for each region in regions_gdf with the number of features
+                equal to the number returned by the CountEmbedder multiplied
+                by (neighbourhood distance + 1).
+ """ base_columns = list(counts_df.columns) no_base_columns = len(base_columns) - columns = [ f"{column}_{distance}" for distance in range(self.neighbourhood_distance + 1) for column in base_columns ] + result_array = np.zeros(shape=(len(counts_df.index), len(columns))) result_array[:, 0:no_base_columns] = counts_df.values - for idx, region_id in tqdm(enumerate(counts_df.index), desc="Generating embeddings"): + for idx, region_id in tqdm( + enumerate(counts_df.index), desc="Generating embeddings", total=len(counts_df.index) + ): for distance in range(1, self.neighbourhood_distance + 1): neighbours = self.neighbourhood.get_neighbours_at_distance(region_id, distance) matching_neighbours = counts_df.index.intersection(neighbours) - if not matching_neighbours.empty: - values = counts_df.loc[matching_neighbours].values - flattened_values = np.average(values, axis=0) - result_array[ - idx, no_base_columns * distance : no_base_columns * (distance + 1) - ] = flattened_values + if matching_neighbours.empty: + continue + + values = counts_df.loc[matching_neighbours].values + flattened_values = np.average(values, axis=0) + result_array[idx, no_base_columns * distance : no_base_columns * (distance + 1)] = ( + flattened_values + ) + return pd.DataFrame(data=result_array, index=counts_df.index, columns=columns) diff --git a/srai/embedders/count_embedder.py b/srai/embedders/count_embedder.py index f51f0b1b..4bb9b5d6 100644 --- a/srai/embedders/count_embedder.py +++ b/srai/embedders/count_embedder.py @@ -49,7 +49,7 @@ def transform( Expects features_gdf to be in wide format with each column being a separate type of feature (e.g. amenity, leisure) and rows to hold values of these features for each object. - The resulting GeoDataFrame will have columns made by combining + The resulting DataFrame will have columns made by combining the feature name (column) and value (row) e.g. amenity_fuel or type_0. The rows will hold numbers of this type of feature in each region. @@ -59,7 +59,7 @@ def transform( joint_gdf (gpd.GeoDataFrame): Joiner result with region-feature multi-index. Returns: - pd.DataFrame: Embedding and geometry index for each region in regions_gdf. + pd.DataFrame: Embedding for each region in regions_gdf. Raises: ValueError: If features_gdf is empty and self.expected_output_features is not set. diff --git a/srai/joiners/intersection_joiner.py b/srai/joiners/intersection_joiner.py index 930fdf1c..be6ac438 100644 --- a/srai/joiners/intersection_joiner.py +++ b/srai/joiners/intersection_joiner.py @@ -19,7 +19,7 @@ class IntersectionJoiner: """ def transform( - self, regions: gpd.GeoDataFrame, features: gpd.GeoDataFrame, return_geom: bool = True + self, regions: gpd.GeoDataFrame, features: gpd.GeoDataFrame, return_geom: bool = False ) -> gpd.GeoDataFrame: """ Join features to regions based on an 'intersects' predicate. @@ -29,7 +29,8 @@ def transform( Args: regions (gpd.GeoDataFrame): regions with which features are joined features (gpd.GeoDataFrame): features to be joined - return_geom (bool): whether to return geometry of the joined features + return_geom (bool): whether to return geometry of the joined features. + Defaults to False. 
Returns: GeoDataFrame with an intersection of regions and features, which contains diff --git a/srai/plotting/folium_wrapper.py b/srai/plotting/folium_wrapper.py index 717549e3..15584a07 100644 --- a/srai/plotting/folium_wrapper.py +++ b/srai/plotting/folium_wrapper.py @@ -32,6 +32,7 @@ def plot_regions( width: Union[str, float] = "100%", colormap: Union[str, List[str]] = px.colors.qualitative.Bold, map: Optional[folium.Map] = None, + show_borders: bool = True, ) -> folium.Map: """ Plot regions shapes using Folium library. @@ -47,6 +48,8 @@ def plot_regions( Defaults to `px.colors.qualitative.Bold` from plotly library. map (folium.Map, optional): Existing map instance on which to draw the plot. Defaults to None. + show_borders (bool, optional): Whether to show borders between regions or not. + Defaults to True. Returns: folium.Map: Generated map. @@ -60,7 +63,7 @@ def plot_regions( legend=False, cmap=colormap, categorical=True, - style_kwds=dict(color="#444", opacity=0.5, fillOpacity=0.5), + style_kwds=dict(color="#444", opacity=0.5 if show_borders else 0, fillOpacity=0.5), m=map, ) @@ -74,6 +77,7 @@ def plot_numeric_data( width: Union[str, float] = "100%", colormap: Union[str, List[str]] = px.colors.sequential.Sunsetdark, map: Optional[folium.Map] = None, + show_borders: bool = False, ) -> folium.Map: """ Plot numerical data within regions shapes using Folium library. @@ -92,6 +96,8 @@ def plot_numeric_data( Defaults to px.colors.sequential.Sunsetdark. map (folium.Map, optional): Existing map instance on which to draw the plot. Defaults to None. + show_borders (bool, optional): Whether to show borders between regions or not. + Defaults to False. Returns: folium.Map: Generated map. @@ -115,7 +121,7 @@ def plot_numeric_data( legend=True, cmap=colormap, categorical=False, - style_kwds=dict(color="#444", opacity=0.5, fillOpacity=0.8), + style_kwds=dict(color="#444", opacity=0.5 if show_borders else 0, fillOpacity=0.8), m=map, ) @@ -128,6 +134,7 @@ def plot_neighbours( height: Union[str, float] = "100%", width: Union[str, float] = "100%", map: Optional[folium.Map] = None, + show_borders: bool = True, ) -> folium.Map: """ Plot neighbours on a map using Folium library. @@ -143,6 +150,8 @@ def plot_neighbours( width (Union[str, float], optional): Width of the plot. Defaults to "100%". map (folium.Map, optional): Existing map instance on which to draw the plot. Defaults to None. + show_borders (bool, optional): Whether to show borders between regions or not. + Defaults to True. Returns: folium.Map: Generated map. @@ -167,7 +176,7 @@ def plot_neighbours( ], categorical=True, categories=["selected", "neighbour", "other"], - style_kwds=dict(color="#444", opacity=0.5, fillOpacity=0.8), + style_kwds=dict(color="#444", opacity=0.5 if show_borders else 0, fillOpacity=0.8), m=map, ) @@ -182,6 +191,7 @@ def plot_all_neighbourhood( width: Union[str, float] = "100%", colormap: Union[str, List[str]] = px.colors.sequential.Agsunset_r, map: Optional[folium.Map] = None, + show_borders: bool = True, ) -> folium.Map: """ Plot full neighbourhood on a map using Folium library. @@ -203,6 +213,8 @@ def plot_all_neighbourhood( Defaults to `px.colors.sequential.Agsunset_r` from plotly library. map (folium.Map, optional): Existing map instance on which to draw the plot. Defaults to None. + show_borders (bool, optional): Whether to show borders between regions or not. + Defaults to True. Returns: folium.Map: Generated map. 
@@ -239,7 +251,7 @@ def plot_all_neighbourhood( cmap=colormap, categorical=True, categories=["selected", *list(range(distance))[1:], "other"], - style_kwds=dict(color="#444", opacity=0.5, fillOpacity=0.8), + style_kwds=dict(color="#444", opacity=0.5 if show_borders else 0, fillOpacity=0.8), legend=distance <= 11, m=map, ) From 7815a6a17f4b6729e6c17aaae3e185140226e092 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Thu, 13 Apr 2023 17:09:21 +0200 Subject: [PATCH 03/30] chore: change changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5086cfaa..c549aa80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Neighbourhood - H3Neighbourhood - AdjacencyNeighbourhood +- CountEmbedder +- ContextualCountEmbedder - (CI) Changelog Enforcer - Utility plotting module based on Folium and Plotly @@ -27,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Change embedders and joiners interface to have `.transform` method - Change linter to Ruff and removed flake8, isort, pydocstyle +- Change default value inside `transform` function of IntersectionJoiner to not return geometry. ### Deprecated From e2807e656d4b570f9808ba9a77b9d1323e4544a6 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 14 Apr 2023 09:36:14 +0200 Subject: [PATCH 04/30] docs: update srai/embedders/contextual_count_embedder.py Co-authored-by: Piotr Gramacki <37406231+piotrgramacki@users.noreply.github.com> --- srai/embedders/contextual_count_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/srai/embedders/contextual_count_embedder.py b/srai/embedders/contextual_count_embedder.py index f62a37bd..9e152869 100644 --- a/srai/embedders/contextual_count_embedder.py +++ b/srai/embedders/contextual_count_embedder.py @@ -73,7 +73,7 @@ def transform( feature (e.g. amenity, leisure) and rows to hold values of these features for each object. The rows will hold numbers of this type of feature in each region. Numbers can be fractional because neighbourhoods are averaged to represent a single value from - all neighbours on a given leven. + all neighbours on a given level. Args: regions_gdf (gpd.GeoDataFrame): Region indexes and geometries. 
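Editor's note: the contextualization weighting implemented in the patches above can be checked by hand. Below is a minimal, self-contained sketch of how `_get_squashed_embeddings` builds a single region's vector; the counts are invented toy numbers, not data from this changeset:

import numpy as np

# Toy inputs (hypothetical): base feature counts for one region, plus the
# averaged counts of its neighbours at distances 1 and 2.
base_counts = np.array([4.0, 1.0])
neighbour_averages = {1: np.array([2.0, 2.0]), 2: np.array([1.0, 0.0])}

# Mirrors `result_array[idx, :] += flattened_values / ((distance + 1) ** 2)`:
# each neighbour level is added with weight 1 / (distance + 1) ** 2, so farther
# levels contribute less to the final embedding.
embedding = base_counts.copy()
for distance, averaged in neighbour_averages.items():
    embedding += averaged / ((distance + 1) ** 2)

print(embedding)  # [4 + 2/4 + 1/9, 1 + 2/4 + 0] -> approx. [4.61, 1.5]

The same per-distance averaged vectors feed the concatenated variant, where they are written into separate `_1`, `_2`, ... column blocks instead of being summed into the base values.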
From 49e2569924be80838f5880fcc9a1324192d73e77 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 14 Apr 2023 09:51:17 +0200 Subject: [PATCH 05/30] refactor: changed embedders imports --- srai/embedders/__init__.py | 6 ++---- srai/embedders/contextual_count_embedder.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/srai/embedders/__init__.py b/srai/embedders/__init__.py index 2086c95a..248940f2 100644 --- a/srai/embedders/__init__.py +++ b/srai/embedders/__init__.py @@ -1,10 +1,8 @@ """Embedders.""" -# Force import of required base classes -from srai.embedders._base import Embedder -from srai.embedders.count_embedder import CountEmbedder - +from ._base import Embedder from .contextual_count_embedder import ContextualCountEmbedder +from .count_embedder import CountEmbedder from .gtfs2vec_embedder import GTFS2VecEmbedder from .highway2vec import Highway2VecEmbedder diff --git a/srai/embedders/contextual_count_embedder.py b/srai/embedders/contextual_count_embedder.py index 9e152869..13c4afec 100644 --- a/srai/embedders/contextual_count_embedder.py +++ b/srai/embedders/contextual_count_embedder.py @@ -15,7 +15,7 @@ import pandas as pd from tqdm import tqdm -from srai.embedders import CountEmbedder +from srai.embedders.count_embedder import CountEmbedder from srai.neighbourhoods import Neighbourhood from srai.neighbourhoods._base import IndexType From ec82d7f73271e6d24f629ad812c4a035ee8ac596 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 14 Apr 2023 09:52:48 +0200 Subject: [PATCH 06/30] chore: added additional info to pbf downloader --- .../osm_loaders/pbf_file_downloader.py | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/srai/loaders/osm_loaders/pbf_file_downloader.py b/srai/loaders/osm_loaders/pbf_file_downloader.py index 7a259185..66996ec5 100644 --- a/srai/loaders/osm_loaders/pbf_file_downloader.py +++ b/srai/loaders/osm_loaders/pbf_file_downloader.py @@ -20,12 +20,7 @@ from tqdm import tqdm from srai.constants import WGS84_CRS -from srai.utils import ( - buffer_geometry, - download_file, - flatten_geometry, - remove_interiors, -) +from srai.utils import buffer_geometry, download_file, flatten_geometry, remove_interiors class PbfFileDownloader: @@ -46,7 +41,7 @@ class PbfFileDownloader: PROTOMAPS_API_START_URL = "https://app.protomaps.com/downloads/osm" PROTOMAPS_API_DOWNLOAD_URL = "https://app.protomaps.com/downloads/{}/download" - _PBAR_FORMAT = "Downloading pbf file ({})" + _PBAR_FORMAT = "[{}] Downloading pbf file #{} ({})" SIMPLIFICATION_TOLERANCE_VALUES = [ 1e-07, @@ -100,12 +95,15 @@ def download_pbf_files_for_regions_gdf( for region_id, row in regions_gdf.iterrows(): polygons = flatten_geometry(row.geometry) regions_mapping[region_id] = [ - self.download_pbf_file_for_polygon(polygon) for polygon in polygons + self.download_pbf_file_for_polygon(polygon, region_id, polygon_id + 1) + for polygon_id, polygon in enumerate(polygons) ] return regions_mapping - def download_pbf_file_for_polygon(self, polygon: Polygon) -> Path: + def download_pbf_file_for_polygon( + self, polygon: Polygon, region_id: str = "OSM", polygon_id: int = 1 + ) -> Path: """ Download PBF file for a single Polygon. @@ -118,6 +116,10 @@ def download_pbf_file_for_polygon(self, polygon: Polygon) -> Path: Args: polygon (Polygon): Polygon boundary of an area to be extracted. + region_id (str, optional): Region name to be set in progress bar. + Defaults to "OSM". + polygon_id (int, optional): Polygon number to be set in progress bar. 
+ Defaults to 1. Returns: Path: Path to a downloaded `*.osm.pbf` file. @@ -179,15 +181,21 @@ def download_pbf_file_for_polygon(self, polygon: Polygon) -> Path: elems_prog = status_response.get("ElemsProg", None) if cells_total > 0 and cells_prog is not None and cells_prog < cells_total: - pbar.set_description(self._PBAR_FORMAT.format("Cells")) + pbar.set_description( + self._PBAR_FORMAT.format(region_id, polygon_id, "Cells") + ) pbar.total = cells_total + nodes_total + elems_total pbar.n = cells_prog elif nodes_total > 0 and nodes_prog is not None and nodes_prog < nodes_total: - pbar.set_description(self._PBAR_FORMAT.format("Nodes")) + pbar.set_description( + self._PBAR_FORMAT.format(region_id, polygon_id, "Nodes") + ) pbar.total = cells_total + nodes_total + elems_total pbar.n = cells_total + nodes_prog elif elems_total > 0 and elems_prog is not None and elems_prog < elems_total: - pbar.set_description(self._PBAR_FORMAT.format("Elements")) + pbar.set_description( + self._PBAR_FORMAT.format(region_id, polygon_id, "Elements") + ) pbar.total = cells_total + nodes_total + elems_total pbar.n = cells_total + nodes_total + elems_prog else: From 2a55fb150161da5830789ed8b553cf686367c566 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 14 Apr 2023 09:54:04 +0200 Subject: [PATCH 07/30] docs(ContextualCountEmbedder): changed notebook --- examples/embedders/contextual_count_embedder.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/embedders/contextual_count_embedder.ipynb b/examples/embedders/contextual_count_embedder.ipynb index 33d50856..5ac83a51 100644 --- a/examples/embedders/contextual_count_embedder.ipynb +++ b/examples/embedders/contextual_count_embedder.ipynb @@ -132,10 +132,10 @@ "metadata": {}, "outputs": [], "source": [ - "squash_cce = ContextualCountEmbedder(\n", - " neighbourhood=h3n, neighbourhood_distance=10, squash_vectors=True\n", + "cce = ContextualCountEmbedder(\n", + " neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False\n", ")\n", - "embeddings = squash_cce.transform(regions_gdf, features_gdf, joint_gdf)\n", + "embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)\n", "embeddings" ] }, @@ -158,7 +158,7 @@ "outputs": [], "source": [ "wide_cce = ContextualCountEmbedder(\n", - " neighbourhood=h3n, neighbourhood_distance=10, squash_vectors=False\n", + " neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True\n", ")\n", "wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)\n", "wide_embeddings" From c84e8b3223c6fd5ea99a1cff51394438ec27f90c Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 14 Apr 2023 09:56:20 +0200 Subject: [PATCH 08/30] docs(ContextualCountEmbedder): change typo --- srai/embedders/contextual_count_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/srai/embedders/contextual_count_embedder.py b/srai/embedders/contextual_count_embedder.py index 13c4afec..e6cacfc0 100644 --- a/srai/embedders/contextual_count_embedder.py +++ b/srai/embedders/contextual_count_embedder.py @@ -40,7 +40,7 @@ def __init__( neighbourhood_distance (int): How many neighbours levels should be included in the embedding. concatenate_vectors (bool, optional): Whether to sum all neighbours into a single vector - with the same width as 1CountEmbedder1, or to concatenate them to the wide format + with the same width as `CountEmbedder`, or to concatenate them to the wide format and keep all neighbour levels separate. Defaults to False. 
expected_output_features (List[str], optional): The features that are expected to be found in the resulting embedding. If not None, the missing features are From b27237ed7b805ee46728865c3c303a0ea05c10d0 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Sun, 16 Apr 2023 12:21:22 +0200 Subject: [PATCH 09/30] feat: add new `osm_tags` type for filtering --- .pre-commit-config.yaml | 4 +- pdm.lock | 21 +- pyproject.toml | 1 + srai/loaders/osm_loaders/filters/geofabrik.py | 370 ++++++++++++++++++ .../osm_loaders/filters/osm_tags_type.py | 2 + srai/utils/typing.py | 29 ++ 6 files changed, 422 insertions(+), 5 deletions(-) create mode 100644 srai/loaders/osm_loaders/filters/geofabrik.py create mode 100644 srai/utils/typing.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3b7bfaaf..3696b4ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: conventional-pre-commit stages: [commit-msg] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.0.259' + rev: "v0.0.259" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -28,7 +28,7 @@ repos: - id: mypy additional_dependencies: ["types-requests"] - repo: https://github.com/pdm-project/pdm - rev: 2.4.9 + rev: 2.5.2 hooks: - id: pdm-lock-check - id: pdm-export diff --git a/pdm.lock b/pdm.lock index 2c9d9b37..1f978619 100644 --- a/pdm.lock +++ b/pdm.lock @@ -1717,7 +1717,7 @@ summary = "Python for Window Extensions" name = "pywinpty" version = "2.0.10" requires_python = ">=3.7" -summary = "Pseudo terminal support for Windows from Python." +summary = "" [[package]] name = "pyyaml" @@ -2102,6 +2102,16 @@ dependencies = [ "torch", ] +[[package]] +name = "typeguard" +version = "3.0.2" +requires_python = ">=3.7.4" +summary = "Run-time type checker for Python" +dependencies = [ + "importlib-metadata>=3.6; python_version < \"3.10\"", + "typing-extensions>=4.4.0; python_version < \"3.11\"", +] + [[package]] name = "types-setuptools" version = "67.6.0.5" @@ -2219,8 +2229,9 @@ requires_python = ">=3.7" summary = "Backport of pathlib-compatible object wrapper for zip files" [metadata] -lock_version = "4.1" -content_hash = "sha256:962e57d64c978b93586dde39aafde37e2c322e5d5764b2af1be7b5d5fa101ef4" +lock_version = "4.2" +groups = ["default", "all", "dev", "docs", "gtfs", "license", "lint", "osm", "performance", "plotting", "test", "visualization", "voronoi"] +content_hash = "sha256:0a1cc165f764516cab6a99b789b18c110e0f797523d1bd4d79794dcde4156fd8" [metadata.files] "aiohttp 3.8.4" = [ @@ -4518,6 +4529,10 @@ content_hash = "sha256:962e57d64c978b93586dde39aafde37e2c322e5d5764b2af1be7b5d5f {url = "https://files.pythonhosted.org/packages/e1/1d/fd57199cf8be37f7e846408cacb0a466ae21cd1612e268ab168bc815b007/triton-2.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42a0d2c3fc2eab4ba71384f2e785fbfd47aa41ae05fa58bf12cb31dcbd0aeceb"}, {url = "https://files.pythonhosted.org/packages/ff/dd/606cb34d8060ab1768cefe0e1f4658c21121f9cae13e8c4444a8dc1665eb/triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"}, ] +"typeguard 3.0.2" = [ + {url = "https://files.pythonhosted.org/packages/af/40/3398497c6e6951c92abaf933492d6633e7ac4df0bfc9d81f304b3f977f15/typeguard-3.0.2.tar.gz", hash = "sha256:fee5297fdb28f8e9efcb8142b5ee219e02375509cd77ea9d270b5af826358d5a"}, + {url = 
"https://files.pythonhosted.org/packages/e2/62/7d206b0ac6fcbb163215ecc622a54eb747f85ad86d14bc513a834442d0f6/typeguard-3.0.2-py3-none-any.whl", hash = "sha256:bbe993854385284ab42fd5bd3bee6f6556577ce8b50696d6cb956d704f286c8e"}, +] "types-setuptools 67.6.0.5" = [ {url = "https://files.pythonhosted.org/packages/52/d9/d655de64223bc1ab4efcf39236ce6945121129bd716bfbc3ba64cda688ff/types-setuptools-67.6.0.5.tar.gz", hash = "sha256:3a708e66c7bdc620e4d0439f344c750c57a4340c895a4c3ed2d0fc4ae8eb9962"}, {url = "https://files.pythonhosted.org/packages/d4/d8/5ececf82707d6797808c2cea97e0eab0cddd26549ee5710f425cfa87ae35/types_setuptools-67.6.0.5-py3-none-any.whl", hash = "sha256:dae5a4a659dbb6dba57773440f6e2dbdd8ef282dc136a174a8a59bd33d949945"}, diff --git a/pyproject.toml b/pyproject.toml index a4c969a3..36ec009b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "s2>=0.1.9", "pytorch-lightning>=1.0.0", "torch", + "typeguard>=3.0.2", ] requires-python = ">=3.8" readme = "README.md" diff --git a/srai/loaders/osm_loaders/filters/geofabrik.py b/srai/loaders/osm_loaders/filters/geofabrik.py new file mode 100644 index 00000000..1952f870 --- /dev/null +++ b/srai/loaders/osm_loaders/filters/geofabrik.py @@ -0,0 +1,370 @@ +""" +Geofabrik layers filter. + +This module contains the grouped OSM tags filter that is defined by Geofabrik [1]. +Based on the document version `0.7.12`. + +Note: not all definitions from the document are implemented, such as boundaries or places. + +References: + 1. https://www.geofabrik.de/data/geofabrik-osm-gis-standard-0.7.pdf +""" +from srai.loaders.osm_loaders.filters.osm_tags_type import grouped_osm_tags_type + +GEOFABRIK_LAYERS: grouped_osm_tags_type = { + "public": { + "amenity": [ + "police", + "fire_station", + "post_box", + "post_office", + "telephone", + "library", + "townhall", + "courthouse", + "prison", + "embassy", + "community_centre", + "nursing_home", + "arts_centre", + "grave_yard", + "marketplace", + "recycling", + "public_building", + ], + "office": ["diplomatic"], + "landuse": ["cemetery"], + }, + "education": { + "amenity": [ + "university", + "school", + "kindergarten", + "college", + ] + }, + "health": { + "amenity": [ + "pharmacy", + "hospital", + "clinic", + "doctors", + "dentist", + "veterinary", + ] + }, + "leisure": { + "amenity": [ + "theatre", + "nightclub", + "cinema", + "swimming_pool", + "theatre", + "theatre", + ], + "leisure": [ + "park", + "playground", + "dog_park", + "sports_centre", + "swimming_pool", + "water_park", + "golf_course", + "stadium", + "ice_rink", + ], + "sport": [ + "swimming", + "tennis", + ], + }, + "catering": { + "amenity": [ + "restaurant", + "fast_food", + "cafe", + "pub", + "bar", + "food_court", + "biergarten", + ] + }, + "accommodation": { + "tourism": [ + "hotel", + "motel", + "bed_and_breakfast", + "guest_house", + "hostel", + "chalet", + "camp_site", + "alpine_hut", + "caravan_site", + ], + "amenity": ["shelter"], + }, + "shopping": { + "shop": [ + "supermarket", + "bakery", + "kiosk", + "mall", + "department_store", + "general", + "convenience", + "clothes", + "florist", + "chemist", + "books", + "butcher", + "shoes", + "alcohol", + "beverages", + "optician", + "jewelry", + "gift", + "sports", + "stationery", + "outdoor", + "mobile_phone", + "toys", + "newsagent", + "greengrocer", + "beauty", + "video", + "car", + "bicycle", + "doityourself", + "hardware", + "furniture", + "computer", + "garden_centre", + "hairdresser", + "car_repair", + "travel_agency", + "laundry", + "dry_cleaning", + ], + 
"amenity": ["car_rental", "car_wash", "car_sharing", "bicycle_rental", "vending_machine"], + "vending": [ + "cigarettes", + "parking_tickets", + ], + }, + "money": {"amenity": ["bank", "atm"]}, + "tourism": { + "tourism": [ + "information", + "attraction", + "museum", + "artwork", + "picnic_site", + "viewpoint", + "zoo", + "theme_park", + ], + "historic": [ + "monument", + "memorial", + "castle", + "ruins", + "archaeological_site", + "wayside_cross", + "wayside_shrine", + "battlefield", + "fort", + ], + }, + "miscpoi": { + "amenity": [ + "toilets", + "bench", + "drinking_water", + "fountain", + "hunting_stand", + "waste_basket", + "emergency_phone", + "fire_hydrant", + ], + "man_made": [ + "surveillance", + "tower", + "water_tower", + "windmill", + "lighthouse", + "wastewater_plant", + "water_well", + "watermill", + "water_works", + ], + "emergency": ["phone", "fire_hydrant"], + "highway": ["emergency_access_point"], + }, + "pofw": {"amenity": ["place_of_worship"]}, + "natural": { + "natural": [ + "spring", + "glacier", + "peak", + "cliff", + "volcano", + "tree", + "mine", + "cave_entrance", + "beach", + ] + }, + "traffic": { + "highway": [ + "traffic_signals", + "mini_roundabout", + "stop", + "crossing", + "ford", + "motorway_junction", + "turning_circle", + "speed_camera", + "street_lamp", + ], + "railway": ["level_crossing"], + }, + "fuel_parking": {"amenity": ["fuel", "parking", "bicycle_parking"], "highway": ["services"]}, + "water_traffic": { + "leisure": [ + "slipway", + "marina", + ], + "man_made": ["pier"], + "waterway": [ + "dam", + "waterfall", + "lock_gate", + "weir", + ], + }, + "transport": { + "railway": ["station", "halt", "tram_stop"], + "public_transport": ["stop_position"], + "highway": ["bus_stop"], + "amenity": ["bus_station", "taxi", "ferry_terminal"], + "aerialway": ["station"], + }, + "air_traffic": { + "amenity": ["airport"], + "aeroway": [ + "aerodrome", + "airfield", + "aeroway", + "helipad", + "apron", + ], + "military": ["airfield"], + }, + "major_roads": { + "highway": [ + "motorway", + "trunk", + "primary", + "secondary", + "tertiary", + ] + }, + "minor_roads": { + "highway": [ + "unclassified", + "residential", + "living_street", + "pedestrian", + "busway", + ] + }, + "highway_links": { + "highway": [ + "motorway_link", + "trunk_link", + "primary_link", + "secondary_link", + "tertiary_link", + ] + }, + "very_small_roads": { + "highway": [ + "service", + "track", + ] + }, + "paths_unsuitable_for_cars": { + "highway": ["bridleway", "path", "cycleway", "footway", "steps"], + "cycle": ["designated"], + "horse": ["designated"], + "foot": ["designated"], + }, + "unkown_roads": {"highway": ["road"]}, + "railways": { + "railway": [ + "rail", + "light_rail", + "subway", + "tram", + "monorail", + "narrow_gauge", + "miniature", + "funicular", + "rack", + ], + "aerialway": [ + "drag_lift", + "chair_lift", + "high_speed_chair_lift", + "cable_car", + "gondola", + "goods", + "platter", + "t-bar", + "j-bar", + "magic_carpet", + "zip_line", + "rope_tow", + "mixed_lift", + ], + }, + "waterways": { + "waterway": [ + "river", + "stream", + "canal", + "drain", + ] + }, + "buildings": {"building": True}, + "landuse": { + "landuse": [ + "forest", + "residential", + "industrial", + "cemetery", + "allotments", + "meadow", + "commercial", + "recreation_ground", + "retail", + "military", + "quarry", + "orchard", + "vineyard", + "scrub", + "grass", + "military", + "farmland", + "farmyard", + ], + "leisure": ["park", "common", "nature_reserve", "recreation_ground"], + "natural": 
["wood", "heath"], + "boundary": ["national_park"], + }, + "water": { + "natural": ["water", "glacier", "wetland"], + "landuse": ["reservoir"], + "waterway": ["riverbank", "dock"], + }, +} diff --git a/srai/loaders/osm_loaders/filters/osm_tags_type.py b/srai/loaders/osm_loaders/filters/osm_tags_type.py index d3086ba3..6de2b656 100644 --- a/srai/loaders/osm_loaders/filters/osm_tags_type.py +++ b/srai/loaders/osm_loaders/filters/osm_tags_type.py @@ -2,3 +2,5 @@ from typing import Dict, List, Union osm_tags_type = Dict[str, Union[List[str], str, bool]] + +grouped_osm_tags_type = Dict[str, Dict[str, Union[List[str], str, bool]]] diff --git a/srai/utils/typing.py b/srai/utils/typing.py new file mode 100644 index 00000000..2356af52 --- /dev/null +++ b/srai/utils/typing.py @@ -0,0 +1,29 @@ +"""Utility function for typing purposes.""" + +from typing import Any + +from typeguard import TypeCheckError, check_type + + +def is_expected_type(value: object, expected_type: Any) -> bool: + """ + Check if an object is a given type. + + Uses `typeguard` library to check objects using `typing` definitions. + + Args: + value (object): Value to be checked against `expected_type`. + expected_type (Any): A class or generic type instance. + + Returns: + bool: Flag whether the object is an instance of the required type. + """ + result = False + + try: + check_type(value, expected_type) + result = True + except TypeCheckError: + pass + + return result From 20e5861af26367db499e21e56cefe62412c8714f Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Mon, 17 Apr 2023 16:31:57 +0200 Subject: [PATCH 10/30] refactor: apply refurb suggestions --- srai/utils/typing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/srai/utils/typing.py b/srai/utils/typing.py index 2356af52..2759fd23 100644 --- a/srai/utils/typing.py +++ b/srai/utils/typing.py @@ -1,5 +1,6 @@ """Utility function for typing purposes.""" +from contextlib import suppress from typing import Any from typeguard import TypeCheckError, check_type @@ -20,10 +21,8 @@ def is_expected_type(value: object, expected_type: Any) -> bool: """ result = False - try: + with suppress(TypeCheckError): check_type(value, expected_type) result = True - except TypeCheckError: - pass return result From 809e44c606d5441a987341553c85c8af26dccd02 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Mon, 17 Apr 2023 16:41:35 +0200 Subject: [PATCH 11/30] feat: add utility function to merge grouped filter --- .../osm_loaders/filters/osm_tags_type.py | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/srai/loaders/osm_loaders/filters/osm_tags_type.py b/srai/loaders/osm_loaders/filters/osm_tags_type.py index 6de2b656..2b0d2b12 100644 --- a/srai/loaders/osm_loaders/filters/osm_tags_type.py +++ b/srai/loaders/osm_loaders/filters/osm_tags_type.py @@ -1,6 +1,48 @@ """Module contains a dedicated type alias for OSM tags filter.""" -from typing import Dict, List, Union +from typing import Dict, List, Union, cast osm_tags_type = Dict[str, Union[List[str], str, bool]] grouped_osm_tags_type = Dict[str, Dict[str, Union[List[str], str, bool]]] + + +def merge_grouped_osm_tags_type(grouped_filter: grouped_osm_tags_type) -> osm_tags_type: + """ + Merge grouped osm tags filter into a base one. + + Function merges all filter categories into a single one for an OSM loader to use. + + Args: + grouped_filter (grouped_osm_tags_type): Grouped filter to be merged into a single one. + + Returns: + osm_tags_type: Merged filter. 
+ """ + result: osm_tags_type = {} + for sub_filter in grouped_filter.values(): + for osm_tag_key, osm_tag_value in sub_filter.items(): + if osm_tag_key not in result: + result[osm_tag_key] = [] + + # If filter is already a positive boolean, skip + if isinstance(result[osm_tag_key], bool) and result[osm_tag_key]: + continue + + # Check bool + if isinstance(osm_tag_value, bool) and osm_tag_value: + result[osm_tag_key] = True + # Check string + elif isinstance(osm_tag_value, str) and osm_tag_value not in cast( + List[str], result[osm_tag_key] + ): + cast(List[str], result[osm_tag_key]).append(osm_tag_value) + # Check list + elif isinstance(osm_tag_value, list): + new_values = [ + value + for value in osm_tag_value + if value not in cast(List[str], result[osm_tag_key]) + ] + cast(List[str], result[osm_tag_key]).extend(new_values) + + return result From 8b9de5d5177e22864dcd4d7d1d4f0b8effc55ede Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Tue, 18 Apr 2023 19:37:37 +0200 Subject: [PATCH 12/30] chore: add merging test skeleton --- srai/loaders/osm_loaders/filters/__init__.py | 3 +- .../filters/test_merge_filter_types.py | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 tests/loaders/osm_loaders/filters/test_merge_filter_types.py diff --git a/srai/loaders/osm_loaders/filters/__init__.py b/srai/loaders/osm_loaders/filters/__init__.py index 69e97ccd..18322c8f 100644 --- a/srai/loaders/osm_loaders/filters/__init__.py +++ b/srai/loaders/osm_loaders/filters/__init__.py @@ -1,6 +1,7 @@ """Filters.""" +from .geofabrik import GEOFABRIK_LAYERS from .hex2vec import HEX2VEC_FILTER from .popular import get_popular_tags -__all__ = ["HEX2VEC_FILTER", "get_popular_tags"] +__all__ = ["GEOFABRIK_LAYERS", "HEX2VEC_FILTER", "get_popular_tags"] diff --git a/tests/loaders/osm_loaders/filters/test_merge_filter_types.py b/tests/loaders/osm_loaders/filters/test_merge_filter_types.py new file mode 100644 index 00000000..af456ea1 --- /dev/null +++ b/tests/loaders/osm_loaders/filters/test_merge_filter_types.py @@ -0,0 +1,30 @@ +"""Tests for merging OSM Loaders filters.""" +from unittest import TestCase + +import pytest + +from srai.loaders.osm_loaders.filters.osm_tags_type import ( + grouped_osm_tags_type, + merge_grouped_osm_tags_type, + osm_tags_type, +) + +ut = TestCase() + + +@pytest.fixture # type: ignore +def expected_result_min_count_8m() -> osm_tags_type: + """Get expected results when using `min_count=8_000_000`.""" + return { + "natural": ["wood"], + "landuse": ["farmland", "residential"], + } + + +def test_merge_grouped_filters() -> None: + """Test merging grouped tags filter into a base osm filter.""" + base_filter: osm_tags_type = {} + grouped_filter: grouped_osm_tags_type = {} + merged_filters = merge_grouped_osm_tags_type(grouped_filter) + + ut.assertDictEqual(base_filter, merged_filters) From 98098f2fc5950c10c771efc43633cdb0154e22ff Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Tue, 18 Apr 2023 23:02:19 +0200 Subject: [PATCH 13/30] feat: add osm filters grouping to the loaders --- srai/loaders/osm_loaders/__init__.py | 3 +- srai/loaders/osm_loaders/_base.py | 171 ++++++++++++++++++ srai/loaders/osm_loaders/filters/__init__.py | 10 +- .../filters/{osm_tags_type.py => _typing.py} | 0 srai/loaders/osm_loaders/filters/geofabrik.py | 2 +- srai/loaders/osm_loaders/filters/hex2vec.py | 2 +- srai/loaders/osm_loaders/filters/popular.py | 2 +- srai/loaders/osm_loaders/osm_online_loader.py | 16 +- srai/loaders/osm_loaders/osm_pbf_loader.py | 21 ++- 
srai/loaders/osm_loaders/pbf_file_handler.py | 6 +- .../filters/test_merge_filter_types.py | 2 +- .../osm_loaders/test_osm_online_loader.py | 2 +- .../osm_loaders/test_osm_pbf_loader.py | 2 +- 13 files changed, 212 insertions(+), 27 deletions(-) create mode 100644 srai/loaders/osm_loaders/_base.py rename srai/loaders/osm_loaders/filters/{osm_tags_type.py => _typing.py} (100%) diff --git a/srai/loaders/osm_loaders/__init__.py b/srai/loaders/osm_loaders/__init__.py index b65734ea..9eeba186 100644 --- a/srai/loaders/osm_loaders/__init__.py +++ b/srai/loaders/osm_loaders/__init__.py @@ -1,6 +1,7 @@ """OSM Loaders.""" +from ._base import OSMLoader from .osm_online_loader import OSMOnlineLoader from .osm_pbf_loader import OSMPbfLoader -__all__ = ["OSMOnlineLoader", "OSMPbfLoader"] +__all__ = ["OSMLoader", "OSMOnlineLoader", "OSMPbfLoader"] diff --git a/srai/loaders/osm_loaders/_base.py b/srai/loaders/osm_loaders/_base.py new file mode 100644 index 00000000..79a55259 --- /dev/null +++ b/srai/loaders/osm_loaders/_base.py @@ -0,0 +1,171 @@ +"""Base class for OSM loaders.""" + +import abc +from typing import Dict, Optional, Union, cast + +import geopandas as gpd +import pandas as pd +from tqdm import tqdm + +from srai.loaders.osm_loaders.filters._typing import ( + grouped_osm_tags_type, + merge_grouped_osm_tags_type, + osm_tags_type, +) +from srai.utils.typing import is_expected_type + + +class OSMLoader(abc.ABC): + """Abstract class for loaders.""" + + @abc.abstractmethod + def load( + self, + area: gpd.GeoDataFrame, + tags: Union[osm_tags_type, grouped_osm_tags_type], + ) -> gpd.GeoDataFrame: # pragma: no cover + """ + Load data for a given area. + + Args: + area (gpd.GeoDataFrame): GeoDataFrame with the area of interest. + tags (Union[osm_tags_type, grouped_osm_tags_type]): OSM tags filter. + + Returns: + gpd.GeoDataFrame: GeoDataFrame with the downloaded data. + """ + raise NotImplementedError + + def _merge_osm_tags_filter( + self, tags: Union[osm_tags_type, grouped_osm_tags_type] + ) -> osm_tags_type: + """ + Merge OSM tags filter into `osm_tags_type` type. + + Optionally merges `grouped_osm_tags_type` into `osm_tags_type` to allow loaders to load all + defined groups during single operation. + + Args: + tags (Union[osm_tags_type, grouped_osm_tags_type]): OSM tags filter definition. + + Raises: + AttributeError: When provided tags don't match both + `osm_tags_type` or `grouped_osm_tags_type`. + + Returns: + osm_tags_type: Merged filters. + """ + if is_expected_type(tags, osm_tags_type): + return cast(osm_tags_type, tags) + elif is_expected_type(tags, grouped_osm_tags_type): + return merge_grouped_osm_tags_type(cast(grouped_osm_tags_type, tags)) + + raise AttributeError( + "Provided tags don't match required type definitions" + " (osm_tags_type or grouped_osm_tags_type)." + ) + + def _parse_features_gdf_to_groups( + self, features_gdf: gpd.GeoDataFrame, tags: Union[osm_tags_type, grouped_osm_tags_type] + ) -> gpd.GeoDataFrame: + """ + Optionally group raw OSM features into groups defined in `grouped_osm_tags_type`. + + Args: + features_gdf (gpd.GeoDataFrame): Generated features from the loader. + tags (Union[osm_tags_type, grouped_osm_tags_type]): OSM tags filter definition. + + Returns: + gpd.GeoDataFrame: Parsed features_gdf. 
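+
+        Example:
+            With a grouped filter such as `{"greenery": {"leisure": ["park", "garden"]}}`
+            (an illustrative definition), the returned frame keeps the `geometry` column
+            plus a single `greenery` column holding values like `"leisure=park"`.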
+ """ + if is_expected_type(tags, grouped_osm_tags_type): + features_gdf = self._group_features_gdf(features_gdf, cast(grouped_osm_tags_type, tags)) + return features_gdf + + def _group_features_gdf( + self, features_gdf: gpd.GeoDataFrame, group_filter: grouped_osm_tags_type + ) -> gpd.GeoDataFrame: + """ + Group raw OSM features into groups defined in `grouped_osm_tags_type`. + + Creates new features based on definition from `grouped_osm_tags_type`. + Returns transformed GeoDataFrame with columns based on group names from the filter. + Values are built by concatenation of matching tag key and value with + an equal sign (eg. amenity=parking). Since many tags can match a definition + of a single group, a first match is used as a feature value. + + Args: + features_gdf (gpd.GeoDataFrame): Generated features from the loader. + group_filter (grouped_osm_tags_type): Grouped OSM tags filter definition. + + Returns: + gpd.GeoDataFrame: Parsed grouped features_gdf. + """ + for index, row in tqdm( + features_gdf.iterrows(), desc="Grouping features", total=len(features_gdf.index) + ): + grouped_features = self._get_osm_filter_groups(row=row, group_filter=group_filter) + for group_name, feature_value in grouped_features.items(): + features_gdf.loc[index, group_name] = feature_value + + matching_columns = [ + column for column in group_filter.keys() if column in features_gdf.columns + ] + + return features_gdf[["geometry", *matching_columns]] + + def _get_osm_filter_groups( + self, row: pd.Series, group_filter: grouped_osm_tags_type + ) -> Dict[str, str]: + """ + Get new group features for a single row. + + Args: + row (pd.Series): Row to be analysed. + group_filter (grouped_osm_tags_type): Grouped OSM tags filter definition. + + Returns: + Dict[str, str]: Dictionary with matching group names and values. + """ + result = {} + + for group_name, osm_filter in group_filter.items(): + matching_osm_tag = self._get_first_matching_osm_tag_value( + row=row, osm_filter=osm_filter + ) + if matching_osm_tag: + result[group_name] = matching_osm_tag + + return result + + def _get_first_matching_osm_tag_value( + self, row: pd.Series, osm_filter: osm_tags_type + ) -> Optional[str]: + """ + Find first matching OSM tag key and value pair for a subgroup filter. + + Returns a first matching pair of OSM tag key and value concatenated + with an equal sign (eg. amenity=parking). If none of the values + in the row matches the filter, `None` value is returned. + + Args: + row (pd.Series): Row to be analysed. + osm_filter (osm_tags_type): OSM tags filter definition. + + Returns: + Optional[str]: New feature value. 
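+
+        Example:
+            A sketch with made-up values, assuming `loader` is an instance of any
+            concrete `OSMLoader` subclass:
+
+            >>> row = pd.Series({"amenity": "bar", "shop": None})
+            >>> loader._get_first_matching_osm_tag_value(
+            ...     row=row, osm_filter={"shop": True, "amenity": ["bar", "cafe"]}
+            ... )
+            'amenity=bar'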
+ """ + for osm_tag_key, osm_tag_value in osm_filter.items(): + if osm_tag_key not in row or pd.isna(row[osm_tag_key]): + continue + + if isinstance(osm_tag_value, bool) and osm_tag_value: + return f"{osm_tag_key}={row[osm_tag_key]}" + + if isinstance(osm_tag_value, str) and row[osm_tag_key] == osm_tag_value: + return f"{osm_tag_key}={row[osm_tag_key]}" + + if isinstance(osm_tag_value, list) and row[osm_tag_key] in osm_tag_value: + return f"{osm_tag_key}={row[osm_tag_key]}" + + return None diff --git a/srai/loaders/osm_loaders/filters/__init__.py b/srai/loaders/osm_loaders/filters/__init__.py index 18322c8f..b1d266f6 100644 --- a/srai/loaders/osm_loaders/filters/__init__.py +++ b/srai/loaders/osm_loaders/filters/__init__.py @@ -1,7 +1,15 @@ """Filters.""" +from ._typing import grouped_osm_tags_type, merge_grouped_osm_tags_type, osm_tags_type from .geofabrik import GEOFABRIK_LAYERS from .hex2vec import HEX2VEC_FILTER from .popular import get_popular_tags -__all__ = ["GEOFABRIK_LAYERS", "HEX2VEC_FILTER", "get_popular_tags"] +__all__ = [ + "grouped_osm_tags_type", + "osm_tags_type", + "merge_grouped_osm_tags_type", + "GEOFABRIK_LAYERS", + "HEX2VEC_FILTER", + "get_popular_tags", +] diff --git a/srai/loaders/osm_loaders/filters/osm_tags_type.py b/srai/loaders/osm_loaders/filters/_typing.py similarity index 100% rename from srai/loaders/osm_loaders/filters/osm_tags_type.py rename to srai/loaders/osm_loaders/filters/_typing.py diff --git a/srai/loaders/osm_loaders/filters/geofabrik.py b/srai/loaders/osm_loaders/filters/geofabrik.py index 1952f870..6dcb7c1c 100644 --- a/srai/loaders/osm_loaders/filters/geofabrik.py +++ b/srai/loaders/osm_loaders/filters/geofabrik.py @@ -9,7 +9,7 @@ References: 1. https://www.geofabrik.de/data/geofabrik-osm-gis-standard-0.7.pdf """ -from srai.loaders.osm_loaders.filters.osm_tags_type import grouped_osm_tags_type +from srai.loaders.osm_loaders.filters._typing import grouped_osm_tags_type GEOFABRIK_LAYERS: grouped_osm_tags_type = { "public": { diff --git a/srai/loaders/osm_loaders/filters/hex2vec.py b/srai/loaders/osm_loaders/filters/hex2vec.py index a0c7a211..0f2ca9e8 100644 --- a/srai/loaders/osm_loaders/filters/hex2vec.py +++ b/srai/loaders/osm_loaders/filters/hex2vec.py @@ -6,7 +6,7 @@ References: 1. 
https://dl.acm.org/doi/10.1145/3486635.3491076 """ -from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type +from srai.loaders.osm_loaders.filters._typing import osm_tags_type HEX2VEC_FILTER: osm_tags_type = { "aeroway": [ diff --git a/srai/loaders/osm_loaders/filters/popular.py b/srai/loaders/osm_loaders/filters/popular.py index 2ac14d5c..2f4ecce1 100644 --- a/srai/loaders/osm_loaders/filters/popular.py +++ b/srai/loaders/osm_loaders/filters/popular.py @@ -12,7 +12,7 @@ import requests from functional import seq -from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type +from srai.loaders.osm_loaders.filters._typing import osm_tags_type _TAGINFO_API_ADDRESS = "https://taginfo.openstreetmap.org" _TAGINFO_API_TAGS = _TAGINFO_API_ADDRESS + "/api/4/tags/popular" diff --git a/srai/loaders/osm_loaders/osm_online_loader.py b/srai/loaders/osm_loaders/osm_online_loader.py index 24460fde..f7a8f2f5 100644 --- a/srai/loaders/osm_loaders/osm_online_loader.py +++ b/srai/loaders/osm_loaders/osm_online_loader.py @@ -12,11 +12,12 @@ from tqdm import tqdm from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS -from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type +from srai.loaders.osm_loaders._base import OSMLoader +from srai.loaders.osm_loaders.filters._typing import grouped_osm_tags_type, osm_tags_type from srai.utils._optional import import_optional_dependencies -class OSMOnlineLoader: +class OSMOnlineLoader(OSMLoader): """ OSMOnlineLoader. @@ -44,7 +45,7 @@ def __init__(self) -> None: def load( self, area: gpd.GeoDataFrame, - tags: osm_tags_type, + tags: Union[osm_tags_type, grouped_osm_tags_type], ) -> gpd.GeoDataFrame: """ Download OSM features with specified tags for a given area. @@ -57,7 +58,7 @@ def load( Args: area (gpd.GeoDataFrame): Area for which to download objects. - tags (osm_tags_type): A dictionary + tags (Union[osm_tags_type, grouped_osm_tags_type]): A dictionary specifying which tags to download. The keys should be OSM tags (e.g. `building`, `amenity`). 
The values should either be `True` for retrieving all objects with the tag, @@ -74,7 +75,9 @@ def load( area_wgs84 = area.to_crs(crs=WGS84_CRS) - _tags = self._flatten_tags(tags) + merged_tags = self._merge_osm_tags_filter(tags) + + _tags = self._flatten_tags(merged_tags) total_tags_num = len(_tags) total_queries = len(area) * total_tags_num @@ -92,8 +95,9 @@ def load( results.append(geometries[[GEOMETRY_COLUMN, key]]) result_gdf = self._group_gdfs(results).set_crs(WGS84_CRS) + result_gdf = self._flatten_index(result_gdf) - return self._flatten_index(result_gdf) + return self._parse_features_gdf_to_groups(result_gdf, tags) def _flatten_tags(self, tags: osm_tags_type) -> List[Tuple[str, Union[str, bool]]]: tags_flat: List[Tuple[str, Union[str, bool]]] = ( diff --git a/srai/loaders/osm_loaders/osm_pbf_loader.py b/srai/loaders/osm_loaders/osm_pbf_loader.py index f81717bc..c09f3448 100644 --- a/srai/loaders/osm_loaders/osm_pbf_loader.py +++ b/srai/loaders/osm_loaders/osm_pbf_loader.py @@ -9,12 +9,13 @@ import geopandas as gpd import pandas as pd -from srai.constants import FEATURES_INDEX, WGS84_CRS -from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type +from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS +from srai.loaders.osm_loaders._base import OSMLoader +from srai.loaders.osm_loaders.filters._typing import grouped_osm_tags_type, osm_tags_type from srai.utils._optional import import_optional_dependencies -class OSMPbfLoader: +class OSMPbfLoader(OSMLoader): """ OSMPbfLoader. @@ -54,7 +55,7 @@ def __init__( def load( self, area: gpd.GeoDataFrame, - tags: osm_tags_type, + tags: Union[osm_tags_type, grouped_osm_tags_type], ) -> gpd.GeoDataFrame: """ Load OSM features with specified tags for a given area from an `*.osm.pbf` file. @@ -74,7 +75,7 @@ def load( Args: area (gpd.GeoDataFrame): Area for which to download objects. - tags (osm_tags_type): A dictionary + tags (Union[osm_tags_type, grouped_osm_tags_type]): A dictionary specifying which tags to download. The keys should be OSM tags (e.g. `building`, `amenity`). 
                The values should either be `True` for retrieving all objects with the tag,
@@ -102,7 +103,9 @@ def load(
 
         clipping_polygon = area_wgs84.geometry.unary_union
 
-        pbf_handler = PbfFileHandler(tags=tags, region_geometry=clipping_polygon)
+        merged_tags = self._merge_osm_tags_filter(tags)
+
+        pbf_handler = PbfFileHandler(tags=merged_tags, region_geometry=clipping_polygon)
 
         results = []
         for region_id, pbf_files in downloaded_pbf_files.items():
@@ -113,10 +116,10 @@ def load(
 
         result_gdf = self._group_gdfs(results).set_crs(WGS84_CRS)
 
-        features_columns = result_gdf.columns.drop(labels=["geometry"]).sort_values()
-        result_gdf = result_gdf[["geometry", *features_columns]]
+        features_columns = result_gdf.columns.drop(labels=[GEOMETRY_COLUMN]).sort_values()
+        result_gdf = result_gdf[[GEOMETRY_COLUMN, *features_columns]]
 
-        return result_gdf
+        return self._parse_features_gdf_to_groups(result_gdf, tags)
 
     def _group_gdfs(self, gdfs: List[gpd.GeoDataFrame]) -> gpd.GeoDataFrame:
         if not gdfs:
diff --git a/srai/loaders/osm_loaders/pbf_file_handler.py b/srai/loaders/osm_loaders/pbf_file_handler.py
index 6a9d2a18..63d9e267 100644
--- a/srai/loaders/osm_loaders/pbf_file_handler.py
+++ b/srai/loaders/osm_loaders/pbf_file_handler.py
@@ -16,8 +16,7 @@
 from tqdm import tqdm
 
 from srai.constants import FEATURES_INDEX, WGS84_CRS
-from srai.loaders.osm_loaders.filters.hex2vec import HEX2VEC_FILTER
-from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type
+from srai.loaders.osm_loaders.filters._typing import osm_tags_type
 
 if TYPE_CHECKING:
     import os
@@ -47,7 +46,7 @@ class PbfFileHandler(osmium.SimpleHandler):  # type: ignore
 
     def __init__(
         self,
-        tags: Optional[osm_tags_type] = HEX2VEC_FILTER,
+        tags: osm_tags_type,
         region_geometry: Optional[BaseGeometry] = None,
     ) -> None:
         """
@@ -64,7 +63,6 @@ def __init__(
             `tags={'leisure': 'park', 'amenity': True, 'shop': ['bakery', 'bicycle']}`
             would return parks, all amenity types, bakeries and bicycle shops.
             If `None`, handler will allow all of the tags to be parsed.
-            Defaults to the predefined HEX2VEC_FILTER.
             region_geometry (BaseGeometry, optional): Region which can be used to filter
                 only intersecting OSM objects. Defaults to None.
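+
+        Example:
+            A minimal sketch (the file path is illustrative and `get_features_gdf`
+            refers to the parsing helper defined on this handler):
+
+            >>> handler = PbfFileHandler(tags={"amenity": ["bar", "cafe"]})
+            >>> features_gdf = handler.get_features_gdf(file_paths=["monaco.osm.pbf"])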
""" diff --git a/tests/loaders/osm_loaders/filters/test_merge_filter_types.py b/tests/loaders/osm_loaders/filters/test_merge_filter_types.py index af456ea1..de5c295d 100644 --- a/tests/loaders/osm_loaders/filters/test_merge_filter_types.py +++ b/tests/loaders/osm_loaders/filters/test_merge_filter_types.py @@ -3,7 +3,7 @@ import pytest -from srai.loaders.osm_loaders.filters.osm_tags_type import ( +from srai.loaders.osm_loaders.filters._typing import ( grouped_osm_tags_type, merge_grouped_osm_tags_type, osm_tags_type, diff --git a/tests/loaders/osm_loaders/test_osm_online_loader.py b/tests/loaders/osm_loaders/test_osm_online_loader.py index 28e31895..f604745d 100644 --- a/tests/loaders/osm_loaders/test_osm_online_loader.py +++ b/tests/loaders/osm_loaders/test_osm_online_loader.py @@ -7,7 +7,7 @@ from srai.constants import WGS84_CRS from srai.loaders.osm_loaders import OSMOnlineLoader -from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type +from srai.loaders.osm_loaders.filters._typing import osm_tags_type if TYPE_CHECKING: from shapely.geometry import Polygon diff --git a/tests/loaders/osm_loaders/test_osm_pbf_loader.py b/tests/loaders/osm_loaders/test_osm_pbf_loader.py index 33e8d928..60aaf803 100644 --- a/tests/loaders/osm_loaders/test_osm_pbf_loader.py +++ b/tests/loaders/osm_loaders/test_osm_pbf_loader.py @@ -9,8 +9,8 @@ from srai.constants import REGIONS_INDEX, WGS84_CRS from srai.loaders.osm_loaders import OSMPbfLoader +from srai.loaders.osm_loaders.filters._typing import osm_tags_type from srai.loaders.osm_loaders.filters.hex2vec import HEX2VEC_FILTER -from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type from srai.loaders.osm_loaders.pbf_file_downloader import PbfFileDownloader from srai.loaders.osm_loaders.pbf_file_handler import PbfFileHandler From 481aa5f43570f2ca600a9719818b2cbdc333edc1 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Tue, 18 Apr 2023 23:04:58 +0200 Subject: [PATCH 14/30] refactor: change condition syntax --- srai/loaders/osm_loaders/_base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/srai/loaders/osm_loaders/_base.py b/srai/loaders/osm_loaders/_base.py index 79a55259..d4ad1391 100644 --- a/srai/loaders/osm_loaders/_base.py +++ b/srai/loaders/osm_loaders/_base.py @@ -159,13 +159,15 @@ def _get_first_matching_osm_tag_value( if osm_tag_key not in row or pd.isna(row[osm_tag_key]): continue - if isinstance(osm_tag_value, bool) and osm_tag_value: - return f"{osm_tag_key}={row[osm_tag_key]}" - - if isinstance(osm_tag_value, str) and row[osm_tag_key] == osm_tag_value: - return f"{osm_tag_key}={row[osm_tag_key]}" + is_matching_bool_filter = isinstance(osm_tag_value, bool) and osm_tag_value + is_matching_string_filter = ( + isinstance(osm_tag_value, str) and row[osm_tag_key] == osm_tag_value + ) + is_matching_list_filter = ( + isinstance(osm_tag_value, list) and row[osm_tag_key] in osm_tag_value + ) - if isinstance(osm_tag_value, list) and row[osm_tag_key] in osm_tag_value: + if is_matching_bool_filter or is_matching_string_filter or is_matching_list_filter: return f"{osm_tag_key}={row[osm_tag_key]}" return None From d1a341bc401ad2d6bcafb2ddec3f7c69a652c59f Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Wed, 19 Apr 2023 08:27:47 +0200 Subject: [PATCH 15/30] chore: modified tests after removing default value --- tests/loaders/osm_loaders/test_osm_pbf_loader.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/loaders/osm_loaders/test_osm_pbf_loader.py 
b/tests/loaders/osm_loaders/test_osm_pbf_loader.py index 60aaf803..99767eb7 100644 --- a/tests/loaders/osm_loaders/test_osm_pbf_loader.py +++ b/tests/loaders/osm_loaders/test_osm_pbf_loader.py @@ -193,18 +193,8 @@ def test_pbf_handler_geometry_filtering(): # type: ignore @pytest.mark.parametrize( # type: ignore "test_geometries,pbf_file,query,expected_result_length,expected_features_columns_length", [ - ([Point([(-73.981883, 40.768081)])], None, None, 10, 38), - ([Polygon([(0, 0), (0, 1), (1, 1), (1, 0)])], None, None, 0, 0), + ([Polygon([(0, 0), (0, 1), (1, 1), (1, 0)])], None, HEX2VEC_FILTER, 0, 0), ([Point([(-73.981883, 40.768081)])], None, HEX2VEC_FILTER, 2, 3), - ( - [Point([(-73.981883, 40.768081)])], - Path(__file__).parent - / "test_files" - / "d17f922ed15e9609013a6b895e1e7af2d49158f03586f2c675d17b760af3452e.osm.pbf", - None, - 10, - 38, - ), ( [Point([(-73.981883, 40.768081)])], Path(__file__).parent @@ -219,7 +209,7 @@ def test_pbf_handler_geometry_filtering(): # type: ignore Path(__file__).parent / "test_files" / "d17f922ed15e9609013a6b895e1e7af2d49158f03586f2c675d17b760af3452e.osm.pbf", - None, + HEX2VEC_FILTER, 0, 0, ), From 9eace7cbb57cfc8313e27556a0f53842b17dfb87 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Wed, 19 Apr 2023 16:34:05 +0200 Subject: [PATCH 16/30] chore: add tests for merging filters --- srai/loaders/osm_loaders/filters/_typing.py | 7 ++ .../filters/test_merge_filter_types.py | 103 +++++++++++++++--- 2 files changed, 94 insertions(+), 16 deletions(-) diff --git a/srai/loaders/osm_loaders/filters/_typing.py b/srai/loaders/osm_loaders/filters/_typing.py index 2b0d2b12..bd480ba5 100644 --- a/srai/loaders/osm_loaders/filters/_typing.py +++ b/srai/loaders/osm_loaders/filters/_typing.py @@ -1,6 +1,8 @@ """Module contains a dedicated type alias for OSM tags filter.""" from typing import Dict, List, Union, cast +from srai.utils.typing import is_expected_type + osm_tags_type = Dict[str, Union[List[str], str, bool]] grouped_osm_tags_type = Dict[str, Dict[str, Union[List[str], str, bool]]] @@ -18,6 +20,11 @@ def merge_grouped_osm_tags_type(grouped_filter: grouped_osm_tags_type) -> osm_ta Returns: osm_tags_type: Merged filter. """ + if not is_expected_type(grouped_filter, grouped_osm_tags_type): + raise AttributeError( + "Provided filter doesn't match required `grouped_osm_tags_type` definition." 
+ ) + result: osm_tags_type = {} for sub_filter in grouped_filter.values(): for osm_tag_key, osm_tag_value in sub_filter.items(): diff --git a/tests/loaders/osm_loaders/filters/test_merge_filter_types.py b/tests/loaders/osm_loaders/filters/test_merge_filter_types.py index de5c295d..07894cfb 100644 --- a/tests/loaders/osm_loaders/filters/test_merge_filter_types.py +++ b/tests/loaders/osm_loaders/filters/test_merge_filter_types.py @@ -1,10 +1,11 @@ """Tests for merging OSM Loaders filters.""" +from contextlib import nullcontext as does_not_raise +from typing import Any from unittest import TestCase import pytest from srai.loaders.osm_loaders.filters._typing import ( - grouped_osm_tags_type, merge_grouped_osm_tags_type, osm_tags_type, ) @@ -12,19 +13,89 @@ ut = TestCase() -@pytest.fixture # type: ignore -def expected_result_min_count_8m() -> osm_tags_type: - """Get expected results when using `min_count=8_000_000`.""" - return { - "natural": ["wood"], - "landuse": ["farmland", "residential"], - } - - -def test_merge_grouped_filters() -> None: +@pytest.mark.parametrize( # type: ignore + "grouped_filter,expected_result_filter,expectation", + [ + ({"tag_a": True}, {"tag_a": True}, pytest.raises(AttributeError)), + ({"tag_a": "A"}, {"tag_a": "A"}, pytest.raises(AttributeError)), + ({"tag_a": ["A"]}, {"tag_a": ["A"]}, pytest.raises(AttributeError)), + ({}, {}, does_not_raise()), + ({"group_a": {}}, {}, does_not_raise()), + ({"group_a": {"tag_a": True}}, {"tag_a": True}, does_not_raise()), + ({"group_a": {"tag_a": "A"}}, {"tag_a": ["A"]}, does_not_raise()), + ({"group_a": {"tag_a": ["A"]}}, {"tag_a": ["A"]}, does_not_raise()), + ( + {"group_a": {"tag_a": "A", "tag_b": "B"}}, + {"tag_a": ["A"], "tag_b": ["B"]}, + does_not_raise(), + ), + ( + {"group_a": {"tag_a": ["A"], "tag_b": ["B"]}}, + {"tag_a": ["A"], "tag_b": ["B"]}, + does_not_raise(), + ), + ( + { + "group_a": {"tag_a": "A", "tag_b": "B"}, + "group_b": {"tag_a": "A", "tag_b": "B"}, + }, + {"tag_a": ["A"], "tag_b": ["B"]}, + does_not_raise(), + ), + ( + { + "group_a": {"tag_a": "A", "tag_b": "B"}, + "group_b": {"tag_c": "C", "tag_d": "D"}, + }, + {"tag_a": ["A"], "tag_b": ["B"], "tag_c": ["C"], "tag_d": ["D"]}, + does_not_raise(), + ), + ( + { + "group_a": {"tag_a": "A", "tag_b": "B"}, + "group_b": {"tag_a": "C", "tag_b": "D"}, + }, + {"tag_a": ["A", "C"], "tag_b": ["B", "D"]}, + does_not_raise(), + ), + ( + { + "group_a": {"tag_a": "A", "tag_b": "B"}, + "group_b": {"tag_a": ["C", "D"], "tag_b": "E"}, + }, + {"tag_a": ["A", "C", "D"], "tag_b": ["B", "E"]}, + does_not_raise(), + ), + ( + { + "group_a": {"tag_a": "A", "tag_b": "B"}, + "group_b": {"tag_a": ["C", "D"], "tag_b": True}, + }, + {"tag_a": ["A", "C", "D"], "tag_b": True}, + does_not_raise(), + ), + ( + { + "group_a": {"tag_a": ["A"], "tag_b": ["B"]}, + "group_b": {"tag_a": ["C", "D"], "tag_b": False}, + }, + {"tag_a": ["A", "C", "D"], "tag_b": ["B"]}, + does_not_raise(), + ), + ( + { + "group_a": {"tag_a": ["A", "C"], "tag_b": ["B", "E"]}, + "group_b": {"tag_a": ["C", "D"], "tag_b": ["B"]}, + }, + {"tag_a": ["A", "C", "D"], "tag_b": ["B", "E"]}, + does_not_raise(), + ), + ], +) +def test_merge_grouped_filters( + grouped_filter: Any, expected_result_filter: osm_tags_type, expectation: Any +) -> None: """Test merging grouped tags filter into a base osm filter.""" - base_filter: osm_tags_type = {} - grouped_filter: grouped_osm_tags_type = {} - merged_filters = merge_grouped_osm_tags_type(grouped_filter) - - ut.assertDictEqual(base_filter, merged_filters) + with expectation: + 
merged_filters = merge_grouped_osm_tags_type(grouped_filter) + ut.assertDictEqual(expected_result_filter, merged_filters) From aefc88b43880987b43a337fe55397c9b65495876 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Wed, 19 Apr 2023 18:57:28 +0200 Subject: [PATCH 17/30] style: changed references formatting --- srai/embedders/contextual_count_embedder.py | 6 +++--- srai/embedders/gtfs2vec_embedder.py | 2 +- srai/embedders/highway2vec/embedder.py | 2 +- srai/embedders/highway2vec/model.py | 2 +- srai/loaders/gtfs_loader.py | 4 ++-- srai/loaders/osm_way_loader/constants.py | 8 ++++---- srai/loaders/osm_way_loader/osm_way_loader.py | 2 +- srai/models/gtfs2vec_model.py | 2 +- 8 files changed, 14 insertions(+), 14 deletions(-) diff --git a/srai/embedders/contextual_count_embedder.py b/srai/embedders/contextual_count_embedder.py index e6cacfc0..b76460d8 100644 --- a/srai/embedders/contextual_count_embedder.py +++ b/srai/embedders/contextual_count_embedder.py @@ -1,11 +1,11 @@ """ Contextual Count Embedder. -This module contains contextual count embedder implementation from ARIC@SIGSPATIAL 2021 paper[1]. +This module contains contextual count embedder implementation from ARIC@SIGSPATIAL 2021 paper [1]. References: - [1] https://doi.org/10.1145/3486626.3493434 - [1] https://arxiv.org/abs/2111.00990 + 1. https://doi.org/10.1145/3486626.3493434 + 1. https://arxiv.org/abs/2111.00990 """ from typing import List, Optional diff --git a/srai/embedders/gtfs2vec_embedder.py b/srai/embedders/gtfs2vec_embedder.py index 876a7919..e08e7d1f 100644 --- a/srai/embedders/gtfs2vec_embedder.py +++ b/srai/embedders/gtfs2vec_embedder.py @@ -4,7 +4,7 @@ This module contains embedder from gtfs2vec paper [1]. References: - [1] https://doi.org/10.1145/3486640.3491392 + 1. https://doi.org/10.1145/3486640.3491392 """ diff --git a/srai/embedders/highway2vec/embedder.py b/srai/embedders/highway2vec/embedder.py index 8eb60767..cc81673b 100644 --- a/srai/embedders/highway2vec/embedder.py +++ b/srai/embedders/highway2vec/embedder.py @@ -4,7 +4,7 @@ This module contains the embedder from the `highway2vec` paper [1]. References: - [1] https://doi.org/10.1145/3557918.3565865 + 1. https://doi.org/10.1145/3557918.3565865 """ from typing import Any, Dict, Optional diff --git a/srai/embedders/highway2vec/model.py b/srai/embedders/highway2vec/model.py index f89e4d79..19e6e2a6 100644 --- a/srai/embedders/highway2vec/model.py +++ b/srai/embedders/highway2vec/model.py @@ -4,7 +4,7 @@ This module contains the embedding model from the `highway2vec` paper [1]. References: - [1] https://doi.org/10.1145/3557918.3565865 + 1. https://doi.org/10.1145/3557918.3565865 """ import pytorch_lightning as pl import torch diff --git a/srai/loaders/gtfs_loader.py b/srai/loaders/gtfs_loader.py index 1b669def..69509719 100644 --- a/srai/loaders/gtfs_loader.py +++ b/srai/loaders/gtfs_loader.py @@ -6,8 +6,8 @@ the gtfs2vec project [2]. References: - [1] https://gitlab.com/mrcagney/gtfs_kit - [2] https://doi.org/10.1145/3486640.3491392 + 1. https://gitlab.com/mrcagney/gtfs_kit + 2. https://doi.org/10.1145/3486640.3491392 """ from pathlib import Path diff --git a/srai/loaders/osm_way_loader/constants.py b/srai/loaders/osm_way_loader/constants.py index 7b7e7be3..a3045901 100644 --- a/srai/loaders/osm_way_loader/constants.py +++ b/srai/loaders/osm_way_loader/constants.py @@ -191,8 +191,8 @@ Assembled values that are officially defined in the wiki and are related to the `way` OSM type. 
References: - [1] https://taginfo.openstreetmap.org/ - [2] https://wiki.openstreetmap.org/wiki/Main_Page + 1. https://taginfo.openstreetmap.org/ + 2. https://wiki.openstreetmap.org/wiki/Main_Page """ @@ -340,6 +340,6 @@ Consider using better (full) table in machine readable format available in [2]. References: - [1] https://wiki.openstreetmap.org/wiki/Key:maxspeed#Implicit_maxspeed_values - [2] https://wiki.openstreetmap.org/wiki/Default_speed_limits + 1. https://wiki.openstreetmap.org/wiki/Key:maxspeed#Implicit_maxspeed_values + 2. https://wiki.openstreetmap.org/wiki/Default_speed_limits """ diff --git a/srai/loaders/osm_way_loader/osm_way_loader.py b/srai/loaders/osm_way_loader/osm_way_loader.py index d65b24c5..8a29a548 100644 --- a/srai/loaders/osm_way_loader/osm_way_loader.py +++ b/srai/loaders/osm_way_loader/osm_way_loader.py @@ -30,7 +30,7 @@ class NetworkType(str, Enum): See [1] for more details. References: - [1] https://osmnx.readthedocs.io/en/stable/osmnx.html#osmnx.graph.graph_from_place + 1. https://osmnx.readthedocs.io/en/stable/osmnx.html#osmnx.graph.graph_from_place """ ALL_PRIVATE = "all_private" diff --git a/srai/models/gtfs2vec_model.py b/srai/models/gtfs2vec_model.py index 79a891e8..4928037c 100644 --- a/srai/models/gtfs2vec_model.py +++ b/srai/models/gtfs2vec_model.py @@ -4,7 +4,7 @@ This module contains embedding model from gtfs2vec paper [1]. References: - [1] https://doi.org/10.1145/3486640.3491392 + 1. https://doi.org/10.1145/3486640.3491392 """ from typing import Any From 523784bc88ebbfb249dcdfccb62f010bbd360f9c Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Wed, 19 Apr 2023 19:02:20 +0200 Subject: [PATCH 18/30] feat: added base osm group filter --- .../osm_loaders/filters/base_osm_groups.py | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 srai/loaders/osm_loaders/filters/base_osm_groups.py diff --git a/srai/loaders/osm_loaders/filters/base_osm_groups.py b/srai/loaders/osm_loaders/filters/base_osm_groups.py new file mode 100644 index 00000000..7a35fd36 --- /dev/null +++ b/srai/loaders/osm_loaders/filters/base_osm_groups.py @@ -0,0 +1,209 @@ +""" +Base OSM groups filter. + +This module contains the grouped OSM tags filter that was used in ARIC@SIGSPATIAL 2021 paper [1]. + +References: + 1. https://doi.org/10.1145/3486626.3493434 + 1. 
https://arxiv.org/abs/2111.00990
+"""
+from srai.loaders.osm_loaders.filters._typing import grouped_osm_tags_type
+
+BASE_OSM_GROUPS_FILTER: grouped_osm_tags_type = {
+    "water": {"natural": ["water", "bay", "beach", "coastline"], "waterway": ["riverbank"]},
+    "aerialway": {
+        "aerialway": [
+            "cable_car",
+            "gondola",
+            "mixed_lift",
+            "chair_lift",
+            "drag_lift",
+            "t-bar",
+            "j-bar",
+            "platter",
+            "rope_tow",
+            "magic_carpet",
+            "zip_line",
+            "goods",
+            "station",
+        ]
+    },
+    "airports": {"aeroway": ["aerodrome", "heliport", "spaceport"]},
+    "sustenance": {
+        "amenity": [
+            "bar",
+            "bbq",
+            "biergarten",
+            "cafe",
+            "fast_food",
+            "food_court",
+            "ice_cream",
+            "pub",
+            "restaurant",
+        ]
+    },
+    "education": {
+        "amenity": [
+            "college",
+            "driving_school",
+            "kindergarten",
+            "language_school",
+            "library",
+            "toy_library",
+            "music_school",
+            "school",
+            "university",
+        ]
+    },
+    "transportation": {
+        "amenity": [
+            "bicycle_parking",
+            "bicycle_repair_station",
+            "bicycle_rental",
+            "boat_rental",
+            "boat_sharing",
+            "car_rental",
+            "car_sharing",
+            "car_wash",
+            "charging_station",
+            "bus_stop",
+            "ferry_terminal",
+            "fuel",
+            "motorcycle_parking",
+            "parking",
+            "taxi",
+            "bus_station",
+        ],
+        "public_transport": ["station", "stop_position"],
+        "railway": ["station", "subway_entrance", "tram_stop"],
+        "building": ["train_station"],
+        "highway": ["bus_stop"],
+    },
+    "finances": {"amenity": ["atm", "bank", "bureau_de_change"]},
+    "healthcare": {
+        "amenity": [
+            "baby_hatch",
+            "clinic",
+            "dentist",
+            "doctors",
+            "hospital",
+            "nursing_home",
+            "pharmacy",
+            "social_facility",
+            "veterinary",
+        ]
+    },
+    "culture_art_entertainment": {
+        "amenity": [
+            "arts_centre",
+            "brothel",
+            "casino",
+            "cinema",
+            "community_centre",
+            "gambling",
+            "nightclub",
+            "planetarium",
+            "public_bookcase",
+            "social_centre",
+            "stripclub",
+            "studio",
+            "theatre",
+        ]
+    },
+    "other": {
+        "amenity": [
+            "animal_boarding",
+            "animal_shelter",
+            "childcare",
+            "conference_centre",
+            "courthouse",
+            "crematorium",
+            "embassy",
+            "fire_station",
+            "grave_yard",
+            "internet_cafe",
+            "marketplace",
+            "monastery",
+            "place_of_worship",
+            "police",
+            "post_office",
+            "prison",
+            "ranger_station",
+            "refugee_site",
+            "townhall",
+        ]
+    },
+    "buildings": {
+        "building": ["commercial", "industrial", "warehouse"],
+        "office": True,
+        "waterway": ["dock", "boatyard"],
+    },
+    "emergency": {"emergency": ["ambulance_station", "defibrillator", "landing_site"]},
+    "historic": {
+        "historic": [
+            "aqueduct",
+            "battlefield",
+            "building",
+            "castle",
+            "church",
+            "citywalls",
+            "fort",
+            "memorial",
+            "monastery",
+            "monument",
+            "ruins",
+            "tower",
+        ]
+    },
+    "leisure": {
+        "leisure": [
+            "adult_gaming_centre",
+            "amusement_arcade",
+            "beach_resort",
+            "common",
+            "dance",
+            "dog_park",
+            "escape_game",
+            "fitness_centre",
+            "fitness_station",
+            "garden",
+            "hackerspace",
+            "horse_riding",
+            "ice_rink",
+            "marina",
+            "miniature_golf",
+            "nature_reserve",
+            "park",
+            "pitch",
+            "slipway",
+            "sports_centre",
+            "stadium",
+            "summer_camp",
+            "swimming_area",
+            "swimming_pool",
+            "track",
+            "water_park",
+        ],
+        "amenity": ["public_bath", "dive_centre"],
+    },
+    "shops": {"shop": True},
+    "sport": {"sport": True},
+    "tourism": {"tourism": True},
+    "greenery": {
+        "leisure": ["park"],
+        "natural": ["grassland", "scrub"],
+        "landuse": [
+            "grass",
+            "allotments",
+            "forest",
+            "flowerbed",
+            "meadow",
+            "village_green",
+            "grassland",
+            "scrub",
+            "garden",
+            "park",
+            "recreation_ground",
+        ],
+    },
+}
From 
16d44e33477adfa03985ff159407e7afc66838b6 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Wed, 19 Apr 2023 20:25:16 +0200 Subject: [PATCH 19/30] docs: change examples --- .../embedders/contextual_count_embedder.ipynb | 10 ++- examples/embedders/count_embedder.ipynb | 45 +++++++--- examples/loaders/osm_online_loader.ipynb | 34 +++++-- examples/loaders/osm_pbf_loader.ipynb | 89 ++++++++++++++++++- srai/loaders/osm_loaders/_base.py | 5 +- srai/loaders/osm_loaders/filters/__init__.py | 2 + 6 files changed, 160 insertions(+), 25 deletions(-) diff --git a/examples/embedders/contextual_count_embedder.ipynb b/examples/embedders/contextual_count_embedder.ipynb index 5ac83a51..c33dd68b 100644 --- a/examples/embedders/contextual_count_embedder.ipynb +++ b/examples/embedders/contextual_count_embedder.ipynb @@ -62,7 +62,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Download some objects from OpenStreetMap" + "### Download some objects from OpenStreetMap\n", + "\n", + "You can use both `osm_tags_type` and `grouped_osm_tags_type` filters. In this example, a predefined `grouped_osm_tags_type` filter `BASE_OSM_GROUPS_FILTER` is used." ] }, { @@ -71,10 +73,10 @@ "metadata": {}, "outputs": [], "source": [ - "from srai.loaders.osm_loaders.filters.hex2vec import HEX2VEC_FILTER\n", + "from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER\n", "\n", "loader = OSMPbfLoader()\n", - "features_gdf = loader.load(area_gdf, tags=HEX2VEC_FILTER)\n", + "features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)\n", "features_gdf" ] }, @@ -187,7 +189,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot_numeric_data(regions_gdf, embeddings, \"building\", tiles_style=\"CartoDB positron\")" + "plot_numeric_data(regions_gdf, embeddings, \"transportation\", tiles_style=\"CartoDB positron\")" ] } ], diff --git a/examples/embedders/count_embedder.ipynb b/examples/embedders/count_embedder.ipynb index 8a5d00c7..df2598de 100644 --- a/examples/embedders/count_embedder.ipynb +++ b/examples/embedders/count_embedder.ipynb @@ -65,7 +65,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Regionize the area using a H3Regionizer" + "### Regionize the area using an H3Regionizer" ] }, { @@ -85,7 +85,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Download some objects from OpenStreetMap" + "### Download some objects from OpenStreetMap\n", + "\n", + "You can use both `osm_tags_type` and `grouped_osm_tags_type` filters. In this example, a simple `osm_tags_type` filter is used." ] }, { @@ -128,7 +130,7 @@ "outputs": [], "source": [ "joiner = IntersectionJoiner()\n", - "joint_gdf = joiner.transform(regions_gdf, features_gdf)\n", + "joint_gdf = joiner.transform(regions_gdf, features_gdf, return_geom=True)\n", "joint_gdf" ] }, @@ -145,10 +147,24 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## Embed using features existing in data" + "## Embed using features existing in data\n", + "\n", + "Count Embedder can count features on a higher level (tag key) or separately for each value (tag key and value). Both examples are shown below." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "wide_embedder = CountEmbedder(count_subcategories=True)\n",
+    "wide_embedding = wide_embedder.transform(regions_gdf, features_gdf, joint_gdf)\n",
+    "wide_embedding"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "embedder = CountEmbedder()\n",
-    "embedding = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n",
-    "embedding"
+    "dense_embedder = CountEmbedder(count_subcategories=False)\n",
+    "dense_embedding = dense_embedder.transform(regions_gdf, features_gdf, joint_gdf)\n",
+    "dense_embedding"
    ]
   },
   {
@@ -176,7 +192,12 @@
    "outputs": [],
    "source": [
     "embedder = CountEmbedder(\n",
-    "    expected_output_features=[\"amenity_parking\", \"leisure_park\", \"amenity_pub\"]\n",
+    "    expected_output_features=[\n",
+    "        \"amenity_parking\",\n",
+    "        \"leisure_park\",\n",
+    "        \"leisure_playground\",\n",
+    "        \"amenity_pub\",\n",
+    "    ]\n",
     ")\n",
     "embedding_expected_features = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n",
     "embedding_expected_features"
@@ -189,7 +210,7 @@
    "outputs": [],
    "source": [
     "plot_numeric_data(\n",
-    "    regions_gdf, embedding_expected_features, \"amenity_pub\", tiles_style=\"CartoDB positron\"\n",
+    "    regions_gdf, embedding_expected_features, \"leisure_playground\", tiles_style=\"CartoDB positron\"\n",
     ")"
    ]
   },
@@ -197,10 +218,10 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "The resulting embedding contains only the columns specified in ```expected_output_features```. \n",
-    "The ones that were not present in the data (```leisure_park```, ```amenity_parking```) are added and filled with zeros. \n",
+    "The resulting embedding contains only the columns specified in `expected_output_features`. \n",
+    "The ones that were not present in the data (`leisure_park`, `amenity_parking`) are added and filled with zeros. \n",
     "The features that are both expected and present in the data are counted as usual. \n",
-    "The ones that are present in the data but are not expected (```leisure_adult_gaming_centre```, ```leisure_playground```) are discarded."
+    "The ones that are present in the data but are not expected (`leisure_adult_gaming_centre`) are discarded."
    ]
   }
 ],
diff --git a/examples/loaders/osm_online_loader.ipynb b/examples/loaders/osm_online_loader.ipynb
index 6a608ee1..b96f02a8 100644
--- a/examples/loaders/osm_online_loader.ipynb
+++ b/examples/loaders/osm_online_loader.ipynb
@@ -15,7 +15,7 @@
    "outputs": [],
    "source": [
     "from srai.loaders.osm_loaders.filters.popular import get_popular_tags\n",
-    "from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER\n",
+    "from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS, HEX2VEC_FILTER\n",
     "from srai.loaders.osm_loaders import OSMOnlineLoader\n",
     "from srai.utils import geocode_to_region_gdf\n",
     "from srai.plotting.folium_wrapper import plot_regions\n",
@@ -30,7 +30,10 @@
    "## Filters\n",
    "Filters are dictionaries used for specifying what type of objects one would like to download from OpenStreetMap. \n",
    "There is currently one predefined filter (from Hex2Vec paper) and one way to download a filter - using popular tags from taginfo API. \n",
-   "They can also be defined manually in code."
+   "They can also be defined manually in code.\n",
+   "\n",
+   "Additionally, a few predefined grouped filters are available (e.g. `BASE_OSM_GROUPS_FILTER` and `GEOFABRIK_LAYERS`).\n",
+   "Grouped filters categorize base filters into groups."
] }, { @@ -91,6 +94,19 @@ "f\"Unique keys: {hex_2_vec_keys}. Key/value pairs: {hex_2_vec_key_values}\"" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "geofabrik_layers_keys = len(GEOFABRIK_LAYERS)\n", + "geofabrik_layers_key_values = (\n", + " seq(GEOFABRIK_LAYERS.values()).flat_map(lambda filter: filter.items()).map(len).sum()\n", + ")\n", + "f\"Unique groups: {geofabrik_layers_keys}. Key/value pairs: {geofabrik_layers_key_values}\"" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -135,7 +151,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Download hotels, bars, cafes, pubs and soccer related objects in Barcelona" + "### Download hotels, bars, cafes, pubs and sport related objects in Barcelona\n", + "\n", + "Uses grouped filters as an example." ] }, { @@ -145,7 +163,11 @@ "outputs": [], "source": [ "barcelona_gdf = geocode_to_region_gdf(\"Barcelona\")\n", - "barcelona_filter = {\"building\": \"hotel\", \"amenity\": [\"bar\", \"cafe\", \"pub\"], \"sport\": \"soccer\"}\n", + "barcelona_filter = {\n", + " \"tourism\": {\"building\": \"hotel\", \"amenity\": [\"bar\", \"cafe\", \"pub\"]},\n", + " \"sport\": {\"sport\": \"soccer\", \"leisure\": [\"pitch\", \"sports_centre\", \"stadium\"]},\n", + "}\n", + "\n", "barcelona_objects_gdf = loader.load(barcelona_gdf, barcelona_filter)\n", "barcelona_objects_gdf" ] @@ -155,7 +177,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Amenities" + "### Tourism group" ] }, { @@ -165,7 +187,7 @@ "outputs": [], "source": [ "folium_map = plot_regions(barcelona_gdf, colormap=[\"lightgray\"], tiles_style=\"CartoDB positron\")\n", - "barcelona_objects_gdf.query(\"amenity.notna()\").explore(\n", + "barcelona_objects_gdf.query(\"tourism.notna()\").explore(\n", " m=folium_map,\n", " color=\"orangered\",\n", " marker_kwds=dict(radius=1),\n", diff --git a/examples/loaders/osm_pbf_loader.ipynb b/examples/loaders/osm_pbf_loader.ipynb index ee521c38..ad923990 100644 --- a/examples/loaders/osm_pbf_loader.ipynb +++ b/examples/loaders/osm_pbf_loader.ipynb @@ -20,13 +20,13 @@ "metadata": {}, "outputs": [], "source": [ - "from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER\n", + "from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER, GEOFABRIK_LAYERS\n", "from srai.loaders.osm_loaders.filters.popular import get_popular_tags\n", "from srai.loaders.osm_loaders import OSMPbfLoader\n", "from srai.constants import REGIONS_INDEX, WGS84_CRS\n", "from srai.utils import buffer_geometry, geocode_to_region_gdf\n", "\n", - "from shapely.geometry import Point\n", + "from shapely.geometry import Point, box\n", "import geopandas as gpd" ] }, @@ -222,6 +222,86 @@ "\n", "ax.set_axis_off()" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download all grouped features based on Geofabrik layers in New York, USA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "manhattan_bbox = box(-73.994551, 40.762396, -73.936872, 40.804239)\n", + "manhattan_bbox_gdf = gpd.GeoDataFrame(\n", + " geometry=[manhattan_bbox],\n", + " crs=WGS84_CRS,\n", + " index=gpd.pd.Index(data=[\"New York\"], name=REGIONS_INDEX),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loader = OSMPbfLoader()\n", + "new_york_features_gdf = loader.load(manhattan_bbox_gdf, GEOFABRIK_LAYERS)\n", + "new_york_features_gdf" + ] + }, + { + 
"attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot features\n", + "\n", + "Inspired by https://snazzymaps.com/style/14889/flat-pale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ax = manhattan_bbox_gdf.plot(color=\"#e7e7df\", figsize=(16, 16))\n", + "\n", + "# plot greenery\n", + "new_york_features_gdf[new_york_features_gdf[\"leisure\"] == \"leisure=park\"].plot(\n", + " ax=ax, color=\"#bae5ce\"\n", + ")\n", + "\n", + "# plot water\n", + "new_york_features_gdf.dropna(subset=[\"water\", \"waterways\"], how=\"all\").plot(ax=ax, color=\"#c7eced\")\n", + "\n", + "# plot streets\n", + "new_york_features_gdf.dropna(subset=[\"paths_unsuitable_for_cars\"], how=\"all\").plot(\n", + " ax=ax, color=\"#e7e7df\", linewidth=1\n", + ")\n", + "new_york_features_gdf.dropna(\n", + " subset=[\"very_small_roads\", \"highway_links\", \"minor_roads\"], how=\"all\"\n", + ").plot(ax=ax, color=\"#fff\", linewidth=2)\n", + "new_york_features_gdf.dropna(subset=[\"major_roads\"], how=\"all\").plot(\n", + " ax=ax, color=\"#fac9a9\", linewidth=3\n", + ")\n", + "\n", + "# plot buildings\n", + "new_york_features_gdf.dropna(subset=[\"buildings\"], how=\"all\").plot(ax=ax, color=\"#cecebd\")\n", + "\n", + "xmin, ymin, xmax, ymax = manhattan_bbox_gdf.total_bounds\n", + "ax.set_xlim(xmin, xmax)\n", + "ax.set_ylim(ymin, ymax)\n", + "\n", + "ax.set_axis_off()" + ] } ], "metadata": { @@ -241,6 +321,11 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" + }, + "vscode": { + "interpreter": { + "hash": "4153976b658cb8b76d04b10dc7a0c871c2dac1d3dcfe690ad61d83a61969a12e" + } } }, "nbformat": 4, diff --git a/srai/loaders/osm_loaders/_base.py b/srai/loaders/osm_loaders/_base.py index d4ad1391..bd90192a 100644 --- a/srai/loaders/osm_loaders/_base.py +++ b/srai/loaders/osm_loaders/_base.py @@ -4,6 +4,7 @@ from typing import Dict, Optional, Union, cast import geopandas as gpd +import numpy as np import pandas as pd from tqdm import tqdm @@ -112,7 +113,9 @@ def _group_features_gdf( column for column in group_filter.keys() if column in features_gdf.columns ] - return features_gdf[["geometry", *matching_columns]] + return features_gdf[["geometry", *matching_columns]].replace( + to_replace=[None], value=np.nan + ) def _get_osm_filter_groups( self, row: pd.Series, group_filter: grouped_osm_tags_type diff --git a/srai/loaders/osm_loaders/filters/__init__.py b/srai/loaders/osm_loaders/filters/__init__.py index b1d266f6..0801e89d 100644 --- a/srai/loaders/osm_loaders/filters/__init__.py +++ b/srai/loaders/osm_loaders/filters/__init__.py @@ -1,6 +1,7 @@ """Filters.""" from ._typing import grouped_osm_tags_type, merge_grouped_osm_tags_type, osm_tags_type +from .base_osm_groups import BASE_OSM_GROUPS_FILTER from .geofabrik import GEOFABRIK_LAYERS from .hex2vec import HEX2VEC_FILTER from .popular import get_popular_tags @@ -9,6 +10,7 @@ "grouped_osm_tags_type", "osm_tags_type", "merge_grouped_osm_tags_type", + "BASE_OSM_GROUPS_FILTER", "GEOFABRIK_LAYERS", "HEX2VEC_FILTER", "get_popular_tags", From 0fb81ca165aff86caecc17a93a54e3eff01c3080 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 21 Apr 2023 21:43:35 +0200 Subject: [PATCH 20/30] test: change count embedder tests --- .../test_contextual_count_embedder.py | 1 + tests/embedders/test_count_embedder.py | 121 ++++++++++++++---- 2 files changed, 97 insertions(+), 25 deletions(-) create mode 100644 tests/embedders/test_contextual_count_embedder.py diff 
--git a/tests/embedders/test_contextual_count_embedder.py b/tests/embedders/test_contextual_count_embedder.py new file mode 100644 index 00000000..f183af29 --- /dev/null +++ b/tests/embedders/test_contextual_count_embedder.py @@ -0,0 +1 @@ +"""ContextualCountEmbedder tests.""" diff --git a/tests/embedders/test_count_embedder.py b/tests/embedders/test_count_embedder.py index 2ebd6ca7..be0a28da 100644 --- a/tests/embedders/test_count_embedder.py +++ b/tests/embedders/test_count_embedder.py @@ -1,8 +1,7 @@ """CountEmbedder tests.""" from contextlib import nullcontext as does_not_raise -from typing import Any, List, Union +from typing import TYPE_CHECKING, Any, List, Union -import geopandas as gpd import pandas as pd import pytest from pandas.testing import assert_frame_equal, assert_index_equal @@ -10,9 +9,27 @@ from srai.constants import REGIONS_INDEX from srai.embedders import CountEmbedder +if TYPE_CHECKING: # pragma: no cover + import geopandas as gpd + @pytest.fixture # type: ignore def expected_embedding_df() -> pd.DataFrame: + """Get expected CountEmbedder output for the default case.""" + expected_df = pd.DataFrame( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure": [0, 1, 1], + "amenity": [1, 0, 1], + }, + ) + expected_df.set_index(REGIONS_INDEX, inplace=True) + + return expected_df + + +@pytest.fixture # type: ignore +def expected_subcategories_embedding_df() -> pd.DataFrame: """Get expected CountEmbedder output for the default case.""" expected_df = pd.DataFrame( { @@ -42,7 +59,7 @@ def specified_features_expected_embedding_df() -> pd.DataFrame: REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], "amenity_parking": [0, 0, 0], "leisure_park": [0, 0, 0], - "amenity_pub": [1, 0, 1], + "amenity_pub": [0, 0, 0], }, ) expected_df.set_index(REGIONS_INDEX, inplace=True) @@ -50,31 +67,85 @@ def specified_features_expected_embedding_df() -> pd.DataFrame: return expected_df -def test_correct_embedding( - gdf_regions: gpd.GeoDataFrame, - gdf_features: gpd.GeoDataFrame, - gdf_joint: gpd.GeoDataFrame, - expected_embedding_df: pd.DataFrame, -) -> None: - """Test if CountEmbedder returns correct result in the default case.""" - embedding_df = CountEmbedder().transform( - regions_gdf=gdf_regions, features_gdf=gdf_features, joint_gdf=gdf_joint +@pytest.fixture # type: ignore +def specified_subcategories_features_expected_embedding_df() -> pd.DataFrame: + """Get expected CountEmbedder output for the case with specified features.""" + expected_df = pd.DataFrame( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "amenity_parking": [0, 0, 0], + "leisure_park": [0, 0, 0], + "amenity_pub": [1, 0, 1], + }, ) - assert_frame_equal(embedding_df, expected_embedding_df, check_dtype=False) + expected_df.set_index(REGIONS_INDEX, inplace=True) + + return expected_df -def test_correct_embedding_expected_features( - gdf_regions: gpd.GeoDataFrame, - gdf_features: gpd.GeoDataFrame, - gdf_joint: gpd.GeoDataFrame, - expected_feature_names: List[str], - specified_features_expected_embedding_df: pd.DataFrame, +@pytest.mark.parametrize( # type: ignore + "regions_fixture,features_fixture,joint_fixture,expected_embedding_fixture,count_subcategories,expected_features_fixture", + [ + ( + "gdf_regions", + "gdf_features", + "gdf_joint", + "expected_embedding_df", + False, + None, + ), + ( + "gdf_regions", + "gdf_features", + "gdf_joint", + "expected_subcategories_embedding_df", + True, + None, + ), + ( + "gdf_regions", + 
"gdf_features", + "gdf_joint", + "specified_features_expected_embedding_df", + False, + "expected_feature_names", + ), + ( + "gdf_regions", + "gdf_features", + "gdf_joint", + "specified_subcategories_features_expected_embedding_df", + True, + "expected_feature_names", + ), + ], +) +def test_correct_embedding( + regions_fixture: str, + features_fixture: str, + joint_fixture: str, + expected_embedding_fixture: str, + count_subcategories: bool, + expected_features_fixture: Union[str, None], + request: Any, ) -> None: - """Test if CountEmbedder returns correct result in the specified features case.""" - embedding_df = CountEmbedder(expected_output_features=expected_feature_names).transform( + """Test if CountEmbedder returns correct result with different parameters.""" + expected_output_features = ( + None + if expected_features_fixture is None + else request.getfixturevalue(expected_features_fixture) + ) + embedder = CountEmbedder( + expected_output_features=expected_output_features, count_subcategories=count_subcategories + ) + gdf_regions: "gpd.GeoDataFrame" = request.getfixturevalue(regions_fixture) + gdf_features: "gpd.GeoDataFrame" = request.getfixturevalue(features_fixture) + gdf_joint: "gpd.GeoDataFrame" = request.getfixturevalue(joint_fixture) + embedding_df = embedder.transform( regions_gdf=gdf_regions, features_gdf=gdf_features, joint_gdf=gdf_joint ) - assert_frame_equal(embedding_df, specified_features_expected_embedding_df, check_dtype=False) + expected_result_df = request.getfixturevalue(expected_embedding_fixture) + assert_frame_equal(embedding_df, expected_result_df, check_dtype=False) @pytest.mark.parametrize( # type: ignore @@ -125,9 +196,9 @@ def test_empty( else request.getfixturevalue(expected_features_fixture) ) embedder = CountEmbedder(expected_output_features) - gdf_regions: gpd.GeoDataFrame = request.getfixturevalue(regions_fixture) - gdf_features: gpd.GeoDataFrame = request.getfixturevalue(features_fixture) - gdf_joint: gpd.GeoDataFrame = request.getfixturevalue(joint_fixture) + gdf_regions: "gpd.GeoDataFrame" = request.getfixturevalue(regions_fixture) + gdf_features: "gpd.GeoDataFrame" = request.getfixturevalue(features_fixture) + gdf_joint: "gpd.GeoDataFrame" = request.getfixturevalue(joint_fixture) with expectation: embedding = embedder.transform(gdf_regions, gdf_features, gdf_joint) From ee5322b787132a0da1832182f2772334f6c30b5c Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 21 Apr 2023 21:53:13 +0200 Subject: [PATCH 21/30] test: fix intersection joiner geom test --- tests/joiners/test_intersection_joiner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/joiners/test_intersection_joiner.py b/tests/joiners/test_intersection_joiner.py index 85aef481..df0770f8 100644 --- a/tests/joiners/test_intersection_joiner.py +++ b/tests/joiners/test_intersection_joiner.py @@ -47,10 +47,13 @@ def test_correct_multiindex_intersection_joiner( regions_gdf: gpd.GeoDataFrame, features_gdf: gpd.GeoDataFrame, joint_multiindex: pd.MultiIndex ) -> None: """Test checks if intersection joiner returns correct MultiIndex.""" - joint = IntersectionJoiner().transform(regions=regions_gdf, features=features_gdf) + joint = IntersectionJoiner().transform( + regions=regions_gdf, features=features_gdf, return_geom=True + ) ut.assertEqual(joint.index.names, joint_multiindex.names) ut.assertCountEqual(joint.index, joint_multiindex) + ut.assertIn(GEOMETRY_COLUMN, joint.columns) def test_correct_multiindex_intersection_joiner_without_geom( From 
1c3cf912e855edc5f7b41decbee7f925d7e600c4 Mon Sep 17 00:00:00 2001
From: Kamil Raczycki
Date: Fri, 21 Apr 2023 22:07:16 +0200
Subject: [PATCH 22/30] test: modify osm_pbf_loader tests

---
 srai/loaders/osm_loaders/pbf_file_handler.py  |  6 ++--
 .../osm_loaders/test_osm_pbf_loader.py        | 28 ++++++++++++++++---
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/srai/loaders/osm_loaders/pbf_file_handler.py b/srai/loaders/osm_loaders/pbf_file_handler.py
index 63d9e267..4cd2fef5 100644
--- a/srai/loaders/osm_loaders/pbf_file_handler.py
+++ b/srai/loaders/osm_loaders/pbf_file_handler.py
@@ -46,14 +46,14 @@ class PbfFileHandler(osmium.SimpleHandler):  # type: ignore
 
     def __init__(
         self,
-        tags: osm_tags_type,
+        tags: Optional[osm_tags_type] = None,
         region_geometry: Optional[BaseGeometry] = None,
     ) -> None:
         """
         Initialize PbfFileHandler.
 
         Args:
-            tags (osm_tags_type): A dictionary
+            tags (osm_tags_type, optional): A dictionary
                 specifying which tags to download.
                 The keys should be OSM tags (e.g. `building`, `amenity`).
                 The values should either be `True` for retrieving all objects with the tag,
@@ -62,7 +62,7 @@ def __init__(
             `tags={'leisure': 'park'}` would return parks from the area.
             `tags={'leisure': 'park', 'amenity': True, 'shop': ['bakery', 'bicycle']}`
             would return parks, all amenity types, bakeries and bicycle shops.
-            If `None`, handler will allow all of the tags to be parsed.
+            If `None`, handler will allow all of the tags to be parsed. Defaults to `None`.
             region_geometry (BaseGeometry, optional): Region which can be used to filter
                 only intersecting OSM objects. Defaults to None.
         """
diff --git a/tests/loaders/osm_loaders/test_osm_pbf_loader.py b/tests/loaders/osm_loaders/test_osm_pbf_loader.py
index 99767eb7..82b2c79e 100644
--- a/tests/loaders/osm_loaders/test_osm_pbf_loader.py
+++ b/tests/loaders/osm_loaders/test_osm_pbf_loader.py
@@ -1,6 +1,6 @@
 """Tests for OSMPbfLoader."""
 from pathlib import Path
-from typing import List
+from typing import List, Union
 
 import geopandas as gpd
 import pytest
@@ -9,8 +9,8 @@
 
 from srai.constants import REGIONS_INDEX, WGS84_CRS
 from srai.loaders.osm_loaders import OSMPbfLoader
-from srai.loaders.osm_loaders.filters._typing import osm_tags_type
-from srai.loaders.osm_loaders.filters.hex2vec import HEX2VEC_FILTER
+from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER, HEX2VEC_FILTER
+from srai.loaders.osm_loaders.filters._typing import grouped_osm_tags_type, osm_tags_type
 from srai.loaders.osm_loaders.pbf_file_downloader import PbfFileDownloader
 from srai.loaders.osm_loaders.pbf_file_handler import PbfFileHandler
 
@@ -194,7 +194,9 @@ def 
test_pbf_handler_geometry_filtering(): # type: ignore 0, 0, ), + ( + [Polygon([(0, 0), (0, 1), (1, 1), (1, 0)])], + Path(__file__).parent + / "test_files" + / "d17f922ed15e9609013a6b895e1e7af2d49158f03586f2c675d17b760af3452e.osm.pbf", + BASE_OSM_GROUPS_FILTER, + 0, + 0, + ), ], ) def test_osm_pbf_loader( test_geometries: List[BaseGeometry], pbf_file: Path, - query: osm_tags_type, + query: Union[osm_tags_type, grouped_osm_tags_type], expected_result_length: int, expected_features_columns_length: int, ): From cc94cd8fe09ee210994bda1562050ed81e10a947 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 21 Apr 2023 22:25:51 +0200 Subject: [PATCH 23/30] test: remove code coverage for pbf file download --- srai/loaders/osm_loaders/pbf_file_downloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/srai/loaders/osm_loaders/pbf_file_downloader.py b/srai/loaders/osm_loaders/pbf_file_downloader.py index 66996ec5..5c995205 100644 --- a/srai/loaders/osm_loaders/pbf_file_downloader.py +++ b/srai/loaders/osm_loaders/pbf_file_downloader.py @@ -127,7 +127,7 @@ def download_pbf_file_for_polygon( geometry_hash = self._get_geometry_hash(polygon) pbf_file_path = Path(self.download_directory).resolve() / f"{geometry_hash}.osm.pbf" - if not pbf_file_path.exists(): + if not pbf_file_path.exists(): # pragma: no cover boundary_polygon = self._prepare_polygon_for_download(polygon) geometry_geojson = mapping(boundary_polygon) From d3b238a0de2d66e367304b348df9d2a11da26bab Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 21 Apr 2023 22:44:40 +0200 Subject: [PATCH 24/30] docs: fix typo --- srai/loaders/osm_loaders/pbf_file_downloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/srai/loaders/osm_loaders/pbf_file_downloader.py b/srai/loaders/osm_loaders/pbf_file_downloader.py index 5c995205..787c19ae 100644 --- a/srai/loaders/osm_loaders/pbf_file_downloader.py +++ b/srai/loaders/osm_loaders/pbf_file_downloader.py @@ -218,7 +218,7 @@ def _prepare_polygon_for_download(self, polygon: Polygon) -> Polygon: Function buffers the polygon, closes internal holes and simplifies its boundary to 1000 points. - Makes sure that the generated polygon with fully cover the original one by increasing + Makes sure that the generated polygon will fully cover the original one by increasing the buffer size incrementally. """ is_fully_covered = False From 08ace83648972cd3203c557cd0251a591cec4ea5 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Sun, 23 Apr 2023 11:37:31 +0200 Subject: [PATCH 25/30] refactor: changed AttributeErrors into ValueErrors --- srai/embedders/contextual_count_embedder.py | 12 ++++++++++++ srai/loaders/osm_loaders/filters/_typing.py | 2 +- srai/neighbourhoods/adjacency_neighbourhood.py | 4 ++-- srai/plotting/folium_wrapper.py | 4 ++-- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/srai/embedders/contextual_count_embedder.py b/srai/embedders/contextual_count_embedder.py index b76460d8..db567d08 100644 --- a/srai/embedders/contextual_count_embedder.py +++ b/srai/embedders/contextual_count_embedder.py @@ -49,6 +49,9 @@ def __init__( count_subcategories (bool, optional): Whether to count all subcategories individually or count features only on the highest level based on features column name. Defaults to False. + + Raises: + ValueError: If `neighbourhood_distance` is negative. 
""" super().__init__(expected_output_features, count_subcategories) @@ -56,6 +59,9 @@ def __init__( self.neighbourhood_distance = neighbourhood_distance self.concatenate_vectors = concatenate_vectors + if self.neighbourhood_distance < 0: + raise ValueError("Neighbourhood distance must be positive.") + def transform( self, regions_gdf: gpd.GeoDataFrame, @@ -120,6 +126,9 @@ def _get_squashed_embeddings(self, counts_df: pd.DataFrame) -> pd.DataFrame: for idx, region_id in tqdm( enumerate(counts_df.index), desc="Generating embeddings", total=len(counts_df.index) ): + if self.neighbourhood_distance == 0: + continue + for distance in range(1, self.neighbourhood_distance + 1): neighbours = self.neighbourhood.get_neighbours_at_distance(region_id, distance) matching_neighbours = counts_df.index.intersection(neighbours) @@ -162,6 +171,9 @@ def _get_concatenated_embeddings(self, counts_df: pd.DataFrame) -> pd.DataFrame: for idx, region_id in tqdm( enumerate(counts_df.index), desc="Generating embeddings", total=len(counts_df.index) ): + if self.neighbourhood_distance == 0: + continue + for distance in range(1, self.neighbourhood_distance + 1): neighbours = self.neighbourhood.get_neighbours_at_distance(region_id, distance) matching_neighbours = counts_df.index.intersection(neighbours) diff --git a/srai/loaders/osm_loaders/filters/_typing.py b/srai/loaders/osm_loaders/filters/_typing.py index bd480ba5..ad6fdaf4 100644 --- a/srai/loaders/osm_loaders/filters/_typing.py +++ b/srai/loaders/osm_loaders/filters/_typing.py @@ -21,7 +21,7 @@ def merge_grouped_osm_tags_type(grouped_filter: grouped_osm_tags_type) -> osm_ta osm_tags_type: Merged filter. """ if not is_expected_type(grouped_filter, grouped_osm_tags_type): - raise AttributeError( + raise ValueError( "Provided filter doesn't match required `grouped_osm_tags_type` definition." ) diff --git a/srai/neighbourhoods/adjacency_neighbourhood.py b/srai/neighbourhoods/adjacency_neighbourhood.py index 529c67e8..fbe757b8 100644 --- a/srai/neighbourhoods/adjacency_neighbourhood.py +++ b/srai/neighbourhoods/adjacency_neighbourhood.py @@ -31,10 +31,10 @@ def __init__(self, regions_gdf: gpd.GeoDataFrame) -> None: regions_gdf (gpd.GeoDataFrame): regions for which a neighbourhood will be calculated. Raises: - AttributeError: If regions_gdf doesn't have geometry column. + ValueError: If regions_gdf doesn't have geometry column. """ if GEOMETRY_COLUMN not in regions_gdf.columns: - raise AttributeError("Regions must have a geometry column.") + raise ValueError("Regions must have a geometry column.") self.regions_gdf = regions_gdf self.lookup: Dict[Hashable, Set[Hashable]] = {} diff --git a/srai/plotting/folium_wrapper.py b/srai/plotting/folium_wrapper.py index 15584a07..2cd5beed 100644 --- a/srai/plotting/folium_wrapper.py +++ b/srai/plotting/folium_wrapper.py @@ -157,7 +157,7 @@ def plot_neighbours( folium.Map: Generated map. """ if region_id not in regions_gdf.index: - raise AttributeError(f"{region_id!r} doesn't exist in provided regions_gdf.") + raise ValueError(f"{region_id!r} doesn't exist in provided regions_gdf.") regions_gdf_copy = regions_gdf.copy() regions_gdf_copy["region"] = "other" @@ -220,7 +220,7 @@ def plot_all_neighbourhood( folium.Map: Generated map. 
""" if region_id not in regions_gdf.index: - raise AttributeError(f"{region_id!r} doesn't exist in provided regions_gdf.") + raise ValueError(f"{region_id!r} doesn't exist in provided regions_gdf.") regions_gdf_copy = regions_gdf.copy() regions_gdf_copy["region"] = "other" From 0fb808c64b83e064bb96201338837ba73ae1fa57 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Sun, 23 Apr 2023 11:39:35 +0200 Subject: [PATCH 26/30] refactor: change AttributeError to ValueError in tests --- .../loaders/osm_loaders/filters/test_merge_filter_types.py | 6 +++--- tests/neighbourhoods/test_adjacency_neighbourhood.py | 2 +- .../regionizers/test_administrative_boundary_regionizer.py | 2 +- tests/regionizers/test_h3_regionizer.py | 2 +- tests/regionizers/test_s2_regionizer.py | 2 +- tests/regionizers/test_voronoi_regionizer.py | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/loaders/osm_loaders/filters/test_merge_filter_types.py b/tests/loaders/osm_loaders/filters/test_merge_filter_types.py index 07894cfb..0f209adb 100644 --- a/tests/loaders/osm_loaders/filters/test_merge_filter_types.py +++ b/tests/loaders/osm_loaders/filters/test_merge_filter_types.py @@ -16,9 +16,9 @@ @pytest.mark.parametrize( # type: ignore "grouped_filter,expected_result_filter,expectation", [ - ({"tag_a": True}, {"tag_a": True}, pytest.raises(AttributeError)), - ({"tag_a": "A"}, {"tag_a": "A"}, pytest.raises(AttributeError)), - ({"tag_a": ["A"]}, {"tag_a": ["A"]}, pytest.raises(AttributeError)), + ({"tag_a": True}, {"tag_a": True}, pytest.raises(ValueError)), + ({"tag_a": "A"}, {"tag_a": "A"}, pytest.raises(ValueError)), + ({"tag_a": ["A"]}, {"tag_a": ["A"]}, pytest.raises(ValueError)), ({}, {}, does_not_raise()), ({"group_a": {}}, {}, does_not_raise()), ({"group_a": {"tag_a": True}}, {"tag_a": True}, does_not_raise()), diff --git a/tests/neighbourhoods/test_adjacency_neighbourhood.py b/tests/neighbourhoods/test_adjacency_neighbourhood.py index 1055c7a6..a12e5742 100644 --- a/tests/neighbourhoods/test_adjacency_neighbourhood.py +++ b/tests/neighbourhoods/test_adjacency_neighbourhood.py @@ -91,7 +91,7 @@ def rounded_regions_fixture() -> gpd.GeoDataFrame: def test_no_geometry_gdf_attribute_error(no_geometry_gdf: gpd.GeoDataFrame) -> None: """Test checks if GeoDataFrames without geometry are disallowed.""" - with pytest.raises(AttributeError): + with pytest.raises(ValueError): AdjacencyNeighbourhood(no_geometry_gdf) diff --git a/tests/regionizers/test_administrative_boundary_regionizer.py b/tests/regionizers/test_administrative_boundary_regionizer.py index 6f0480ba..9f271e54 100644 --- a/tests/regionizers/test_administrative_boundary_regionizer.py +++ b/tests/regionizers/test_administrative_boundary_regionizer.py @@ -46,7 +46,7 @@ def test_admin_level( def test_empty_gdf_attribute_error(gdf_empty) -> None: # type: ignore """Test checks if empty GeoDataFrames are disallowed.""" - with pytest.raises(AttributeError): + with pytest.raises(ValueError): abr = AdministrativeBoundaryRegionizer(admin_level=4) abr.transform(gdf_empty) diff --git a/tests/regionizers/test_h3_regionizer.py b/tests/regionizers/test_h3_regionizer.py index 6ea63667..84623e79 100644 --- a/tests/regionizers/test_h3_regionizer.py +++ b/tests/regionizers/test_h3_regionizer.py @@ -44,7 +44,7 @@ def expected_unbuffered_h3_indexes() -> List[str]: ("gdf_polygons", "expected_h3_indexes", H3_RESOLUTION, True, does_not_raise()), ("gdf_polygons", "expected_unbuffered_h3_indexes", H3_RESOLUTION, False, does_not_raise()), ("gdf_multipolygon", 
"expected_h3_indexes", H3_RESOLUTION, True, does_not_raise()), - ("gdf_empty", "expected_h3_indexes", H3_RESOLUTION, True, pytest.raises(AttributeError)), + ("gdf_empty", "expected_h3_indexes", H3_RESOLUTION, True, pytest.raises(ValueError)), ("gdf_polygons", "expected_h3_indexes", -1, True, pytest.raises(ValueError)), ("gdf_polygons", "expected_h3_indexes", 16, True, pytest.raises(ValueError)), ("gdf_no_crs", "expected_h3_indexes", H3_RESOLUTION, True, pytest.raises(ValueError)), diff --git a/tests/regionizers/test_s2_regionizer.py b/tests/regionizers/test_s2_regionizer.py index 90bb6891..ea85b4a6 100644 --- a/tests/regionizers/test_s2_regionizer.py +++ b/tests/regionizers/test_s2_regionizer.py @@ -38,7 +38,7 @@ def expected_s2_indexes() -> List[str]: [ ("gdf_polygons", "expected_s2_indexes", S2_RESOLUTION, does_not_raise()), ("gdf_multipolygon", "expected_s2_indexes", S2_RESOLUTION, does_not_raise()), - ("gdf_empty", "expected_s2_indexes", S2_RESOLUTION, pytest.raises(AttributeError)), + ("gdf_empty", "expected_s2_indexes", S2_RESOLUTION, pytest.raises(ValueError)), ("gdf_polygons", "expected_s2_indexes", -1, pytest.raises(ValueError)), ("gdf_polygons", "expected_s2_indexes", 31, pytest.raises(ValueError)), ("gdf_no_crs", "expected_s2_indexes", S2_RESOLUTION, pytest.raises(ValueError)), diff --git a/tests/regionizers/test_voronoi_regionizer.py b/tests/regionizers/test_voronoi_regionizer.py index 0d24622d..1abc181d 100644 --- a/tests/regionizers/test_voronoi_regionizer.py +++ b/tests/regionizers/test_voronoi_regionizer.py @@ -14,7 +14,7 @@ def test_empty_gdf_attribute_error(gdf_empty: gpd.GeoDataFrame) -> None: """Test checks if empty GeoDataFrames are disallowed.""" - with pytest.raises(AttributeError): + with pytest.raises(ValueError): VoronoiRegionizer(seeds=gdf_empty) From eee49a6a0c64b49c8f4197892cacb3faa1215aab Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Sun, 23 Apr 2023 11:52:32 +0200 Subject: [PATCH 27/30] chore: apply CR suggestions --- tests/embedders/test_count_embedder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/embedders/test_count_embedder.py b/tests/embedders/test_count_embedder.py index be0a28da..7bd3381b 100644 --- a/tests/embedders/test_count_embedder.py +++ b/tests/embedders/test_count_embedder.py @@ -30,7 +30,7 @@ def expected_embedding_df() -> pd.DataFrame: @pytest.fixture # type: ignore def expected_subcategories_embedding_df() -> pd.DataFrame: - """Get expected CountEmbedder output for the default case.""" + """Get expected CountEmbedder output with subcategories for the default case.""" expected_df = pd.DataFrame( { REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], @@ -69,7 +69,7 @@ def specified_features_expected_embedding_df() -> pd.DataFrame: @pytest.fixture # type: ignore def specified_subcategories_features_expected_embedding_df() -> pd.DataFrame: - """Get expected CountEmbedder output for the case with specified features.""" + """Get expected CountEmbedder output with subcategories for the case with specified features.""" expected_df = pd.DataFrame( { REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], From 606f4d43b800110e417aaf9563b55a372f93a787 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Sun, 23 Apr 2023 11:54:03 +0200 Subject: [PATCH 28/30] test: add tests for ContextualCountEmbedder --- .../test_contextual_count_embedder.py | 764 ++++++++++++++++++ 1 file changed, 764 insertions(+) diff --git a/tests/embedders/test_contextual_count_embedder.py 
b/tests/embedders/test_contextual_count_embedder.py index f183af29..861ff259 100644 --- a/tests/embedders/test_contextual_count_embedder.py +++ b/tests/embedders/test_contextual_count_embedder.py @@ -1 +1,765 @@ """ContextualCountEmbedder tests.""" +from contextlib import nullcontext as does_not_raise +from typing import TYPE_CHECKING, Any, Dict, List, Union + +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal, assert_index_equal +from parametrization import Parametrization as P + +from srai.constants import REGIONS_INDEX +from srai.embedders import ContextualCountEmbedder +from srai.neighbourhoods import H3Neighbourhood + +if TYPE_CHECKING: # pragma: no cover + import geopandas as gpd + + +def _create_features_dataframe(data: Dict[str, Any]) -> pd.DataFrame: + return pd.DataFrame(data).set_index(REGIONS_INDEX) + + +@pytest.fixture # type: ignore +def expected_embedding_df_squashed_distance_0() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count without subcategories. Squashed features, distance 0. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure": [0, 1, 1], + "amenity": [1, 0, 1], + } + ) + + +@pytest.fixture # type: ignore +def expected_embedding_df_squashed_distance_1() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count without subcategories. Squashed features, distance 1+. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure": [0.25, 1.125, 1.125], + "amenity": [1.125, 0.25, 1.125], + }, + ) + + +@pytest.fixture # type: ignore +def expected_embedding_df_concatenated_distance_0() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count without subcategories. Concatenated features, distance 0. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure_0": [0, 1, 1], + "amenity_0": [1, 0, 1], + } + ) + + +@pytest.fixture # type: ignore +def expected_embedding_df_concatenated_distance_1() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count without subcategories. Concatenated features, distance 1. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure_0": [0, 1, 1], + "amenity_0": [1, 0, 1], + "leisure_1": [1, 0.5, 0.5], + "amenity_1": [0.5, 1, 0.5], + }, + ) + + +@pytest.fixture # type: ignore +def expected_embedding_df_concatenated_distance_2() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count without subcategories. Concatenated features, distance 2. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure_0": [0, 1, 1], + "amenity_0": [1, 0, 1], + "leisure_1": [1, 0.5, 0.5], + "amenity_1": [0.5, 1, 0.5], + "leisure_2": [0, 0, 0], + "amenity_2": [0, 0, 0], + }, + ) + + +@pytest.fixture # type: ignore +def expected_subcategories_embedding_df_squashed_distance_0() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count with subcategories. Squashed features, distance 0. 
+ """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure_adult_gaming_centre": [0, 0, 1], + "leisure_playground": [0, 1, 0], + "amenity_pub": [1, 0, 1], + }, + ) + + +@pytest.fixture # type: ignore +def expected_subcategories_embedding_df_squashed_distance_1() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count with subcategories. Squashed features, distance 1+. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure_adult_gaming_centre": [0.125, 0.125, 1], + "leisure_playground": [0.125, 1, 0.125], + "amenity_pub": [1.125, 0.25, 1.125], + }, + ) + + +@pytest.fixture # type: ignore +def expected_subcategories_embedding_df_concatenated_distance_0() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count with subcategories. Concatenated features, distance 0. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure_adult_gaming_centre_0": [0, 0, 1], + "leisure_playground_0": [0, 1, 0], + "amenity_pub_0": [1, 0, 1], + }, + ) + + +@pytest.fixture # type: ignore +def expected_subcategories_embedding_df_concatenated_distance_1() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count with subcategories. Concatenated features, distance 1. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure_adult_gaming_centre_0": [0, 0, 1], + "leisure_playground_0": [0, 1, 0], + "amenity_pub_0": [1, 0, 1], + "leisure_adult_gaming_centre_1": [0.5, 0.5, 0], + "leisure_playground_1": [0.5, 0, 0.5], + "amenity_pub_1": [0.5, 1, 0.5], + }, + ) + + +@pytest.fixture # type: ignore +def expected_subcategories_embedding_df_concatenated_distance_2() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output. + + Count with subcategories. Concatenated features, distance 2. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "leisure_adult_gaming_centre_0": [0, 0, 1], + "leisure_playground_0": [0, 1, 0], + "amenity_pub_0": [1, 0, 1], + "leisure_adult_gaming_centre_1": [0.5, 0.5, 0], + "leisure_playground_1": [0.5, 0, 0.5], + "amenity_pub_1": [0.5, 1, 0.5], + "leisure_adult_gaming_centre_2": [0, 0, 0], + "leisure_playground_2": [0, 0, 0], + "amenity_pub_2": [0, 0, 0], + }, + ) + + +@pytest.fixture # type: ignore +def expected_feature_names() -> List[str]: + """Get expected feature names for ContextualCountEmbedder.""" + expected_feature_names = ["amenity_parking", "leisure_park", "amenity_pub"] + return expected_feature_names + + +@pytest.fixture # type: ignore +def specified_features_expected_embedding_df_squashed_empty() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output for the case with specified features. + + Count without subcategories. Squashed features, distance 0+. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "amenity_parking": [0, 0, 0], + "leisure_park": [0, 0, 0], + "amenity_pub": [0, 0, 0], + } + ) + + +@pytest.fixture # type: ignore +def specified_features_expected_subcategories_embedding_df_squashed_distance_0() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output for the case with specified features. + + Count with subcategories. 
Squashed features, distance 0. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "amenity_parking": [0, 0, 0], + "leisure_park": [0, 0, 0], + "amenity_pub": [1, 0, 1], + } + ) + + +@pytest.fixture # type: ignore +def specified_features_expected_subcategories_embedding_df_squashed_distance_1() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output for the case with specified features. + + Count with subcategories. Squashed features, distance 1+. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "amenity_parking": [0, 0, 0], + "leisure_park": [0, 0, 0], + "amenity_pub": [1.125, 0.25, 1.125], + } + ) + + +@pytest.fixture # type: ignore +def specified_features_expected_embedding_df_concatenated_distance_0() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output for the case with specified features. + + Count without subcategories. Concatenated features, distance 0. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "amenity_parking_0": [0, 0, 0], + "leisure_park_0": [0, 0, 0], + "amenity_pub_0": [0, 0, 0], + } + ) + + +@pytest.fixture # type: ignore +def specified_features_expected_embedding_df_concatenated_distance_1() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output for the case with specified features. + + Count without subcategories. Concatenated features, distance 1. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "amenity_parking_0": [0, 0, 0], + "leisure_park_0": [0, 0, 0], + "amenity_pub_0": [0, 0, 0], + "amenity_parking_1": [0, 0, 0], + "leisure_park_1": [0, 0, 0], + "amenity_pub_1": [0, 0, 0], + } + ) + + +@pytest.fixture # type: ignore +def specified_features_expected_embedding_df_concatenated_distance_2() -> pd.DataFrame: + """ + Get expected ContextualCountEmbedder output for the case with specified features. + + Count without subcategories. Concatenated features, distance 2. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "amenity_parking_0": [0, 0, 0], + "leisure_park_0": [0, 0, 0], + "amenity_pub_0": [0, 0, 0], + "amenity_parking_1": [0, 0, 0], + "leisure_park_1": [0, 0, 0], + "amenity_pub_1": [0, 0, 0], + "amenity_parking_2": [0, 0, 0], + "leisure_park_2": [0, 0, 0], + "amenity_pub_2": [0, 0, 0], + } + ) + + +@pytest.fixture # type: ignore +def specified_features_expected_subcategories_embedding_df_concatenated_distance_0() -> ( + pd.DataFrame +): + """ + Get expected ContextualCountEmbedder output for the case with specified features. + + Count with subcategories. Concatenated features, distance 0. + """ + return _create_features_dataframe( + { + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + "amenity_parking_0": [0, 0, 0], + "leisure_park_0": [0, 0, 0], + "amenity_pub_0": [1, 0, 1], + } + ) + + +@pytest.fixture # type: ignore +def specified_features_expected_subcategories_embedding_df_concatenated_distance_1() -> ( + pd.DataFrame +): + """ + Get expected ContextualCountEmbedder output for the case with specified features. + + Count with subcategories. Concatenated features, distance 1. 
+    """
+    return _create_features_dataframe(
+        {
+            REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"],
+            "amenity_parking_0": [0, 0, 0],
+            "leisure_park_0": [0, 0, 0],
+            "amenity_pub_0": [1, 0, 1],
+            "amenity_parking_1": [0, 0, 0],
+            "leisure_park_1": [0, 0, 0],
+            "amenity_pub_1": [0.5, 1, 0.5],
+        }
+    )
+
+
+@pytest.fixture  # type: ignore
+def specified_features_expected_subcategories_embedding_df_concatenated_distance_2() -> (
+    pd.DataFrame
+):
+    """
+    Get expected ContextualCountEmbedder output for the case with specified features.
+
+    Count with subcategories. Concatenated features, distance 2.
+    """
+    return _create_features_dataframe(
+        {
+            REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"],
+            "amenity_parking_0": [0, 0, 0],
+            "leisure_park_0": [0, 0, 0],
+            "amenity_pub_0": [1, 0, 1],
+            "amenity_parking_1": [0, 0, 0],
+            "leisure_park_1": [0, 0, 0],
+            "amenity_pub_1": [0.5, 1, 0.5],
+            "amenity_parking_2": [0, 0, 0],
+            "leisure_park_2": [0, 0, 0],
+            "amenity_pub_2": [0, 0, 0],
+        }
+    )
+
+
+@P.parameters(
+    "expected_embedding_fixture",
+    "neighbourhood_distance",
+    "concatenate_features",
+    "count_subcategories",
+    "expected_features_fixture",
+)  # type: ignore
+@P.case(  # type: ignore
+    "Squashed features, distance 0, without subcategories",
+    "expected_embedding_df_squashed_distance_0",
+    0,
+    False,
+    False,
+    None,
+)
+@P.case(  # type: ignore
+    "Squashed features, distance 1, without subcategories",
+    "expected_embedding_df_squashed_distance_1",
+    1,
+    False,
+    False,
+    None,
+)
+@P.case(  # type: ignore
+    "Squashed features, distance 2, without subcategories",
+    "expected_embedding_df_squashed_distance_1",
+    2,
+    False,
+    False,
+    None,
+)
+@P.case(  # type: ignore
+    "Concatenated features, distance 0, without subcategories",
+    "expected_embedding_df_concatenated_distance_0",
+    0,
+    True,
+    False,
+    None,
+)
+@P.case(  # type: ignore
+    "Concatenated features, distance 1, without subcategories",
+    "expected_embedding_df_concatenated_distance_1",
+    1,
+    True,
+    False,
+    None,
+)
+@P.case(  # type: ignore
+    "Concatenated features, distance 2, without subcategories",
+    "expected_embedding_df_concatenated_distance_2",
+    2,
+    True,
+    False,
+    None,
+)
+@P.case(  # type: ignore
+    "Squashed features, distance 0, with subcategories",
+    "expected_subcategories_embedding_df_squashed_distance_0",
+    0,
+    False,
+    True,
+    None,
+)
+@P.case(  # type: ignore
+    "Squashed features, distance 1, with subcategories",
+    "expected_subcategories_embedding_df_squashed_distance_1",
+    1,
+    False,
+    True,
+    None,
+)
+@P.case(  # type: ignore
+    "Squashed features, distance 2, with subcategories",
+    "expected_subcategories_embedding_df_squashed_distance_1",
+    2,
+    False,
+    True,
+    None,
+)
+@P.case(  # type: ignore
+    "Concatenated features, distance 0, with subcategories",
+    "expected_subcategories_embedding_df_concatenated_distance_0",
+    0,
+    True,
+    True,
+    None,
+)
+@P.case(  # type: ignore
+    "Concatenated features, distance 1, with subcategories",
+    "expected_subcategories_embedding_df_concatenated_distance_1",
+    1,
+    True,
+    True,
+    None,
+)
+@P.case(  # type: ignore
+    "Concatenated features, distance 2, with subcategories",
+    "expected_subcategories_embedding_df_concatenated_distance_2",
+    2,
+    True,
+    True,
+    None,
+)
+@P.case(  # type: ignore
+    "Squashed features, distance 0, without subcategories, specified features",
+    "specified_features_expected_embedding_df_squashed_empty",
+    0,
+    False,
+    False,
+    "expected_feature_names",
+)
+@P.case(  # type: ignore
+    "Squashed 
features, distance 1, without subcategories, specified features", + "specified_features_expected_embedding_df_squashed_empty", + 1, + False, + False, + "expected_feature_names", +) +@P.case( # type: ignore + "Squashed features, distance 2, without subcategories, specified features", + "specified_features_expected_embedding_df_squashed_empty", + 2, + False, + False, + "expected_feature_names", +) +@P.case( # type: ignore + "Squashed features, distance 0, with subcategories, specified features", + "specified_features_expected_subcategories_embedding_df_squashed_distance_0", + 0, + False, + True, + "expected_feature_names", +) +@P.case( # type: ignore + "Squashed features, distance 1, with subcategories, specified features", + "specified_features_expected_subcategories_embedding_df_squashed_distance_1", + 1, + False, + True, + "expected_feature_names", +) +@P.case( # type: ignore + "Squashed features, distance 2, with subcategories, specified features", + "specified_features_expected_subcategories_embedding_df_squashed_distance_1", + 2, + False, + True, + "expected_feature_names", +) +@P.case( # type: ignore + "Concatenated features, distance 0, without subcategories, specified features", + "specified_features_expected_embedding_df_concatenated_distance_0", + 0, + True, + False, + "expected_feature_names", +) +@P.case( # type: ignore + "Concatenated features, distance 1, without subcategories, specified features", + "specified_features_expected_embedding_df_concatenated_distance_1", + 1, + True, + False, + "expected_feature_names", +) +@P.case( # type: ignore + "Concatenated features, distance 2, without subcategories, specified features", + "specified_features_expected_embedding_df_concatenated_distance_2", + 2, + True, + False, + "expected_feature_names", +) +@P.case( # type: ignore + "Concatenated features, distance 0, with subcategories, specified features", + "specified_features_expected_subcategories_embedding_df_concatenated_distance_0", + 0, + True, + True, + "expected_feature_names", +) +@P.case( # type: ignore + "Concatenated features, distance 1, with subcategories, specified features", + "specified_features_expected_subcategories_embedding_df_concatenated_distance_1", + 1, + True, + True, + "expected_feature_names", +) +@P.case( # type: ignore + "Concatenated features, distance 2, with subcategories, specified features", + "specified_features_expected_subcategories_embedding_df_concatenated_distance_2", + 2, + True, + True, + "expected_feature_names", +) +def test_correct_embedding( + expected_embedding_fixture: str, + neighbourhood_distance: int, + concatenate_features: bool, + count_subcategories: bool, + expected_features_fixture: Union[str, None], + request: Any, +) -> None: + """Test if ContextualCountEmbedder returns correct result with different parameters.""" + expected_output_features = ( + None + if expected_features_fixture is None + else request.getfixturevalue(expected_features_fixture) + ) + gdf_regions: "gpd.GeoDataFrame" = request.getfixturevalue("gdf_regions") + gdf_features: "gpd.GeoDataFrame" = request.getfixturevalue("gdf_features") + gdf_joint: "gpd.GeoDataFrame" = request.getfixturevalue("gdf_joint") + + embedder = ContextualCountEmbedder( + neighbourhood=H3Neighbourhood(), + neighbourhood_distance=neighbourhood_distance, + expected_output_features=expected_output_features, + count_subcategories=count_subcategories, + concatenate_vectors=concatenate_features, + ) + embedding_df = embedder.transform( + regions_gdf=gdf_regions, features_gdf=gdf_features, 
joint_gdf=gdf_joint
+    )
+
+    expected_result_df = request.getfixturevalue(expected_embedding_fixture)
+    assert_frame_equal(embedding_df, expected_result_df, check_dtype=False)
+
+
+def test_negative_neighbourhood_distance() -> None:
+    """Test checks if negative neighbourhood distance is disallowed."""
+    with pytest.raises(ValueError):
+        ContextualCountEmbedder(neighbourhood=H3Neighbourhood(), neighbourhood_distance=-1)
+
+
+@pytest.mark.parametrize(  # type: ignore
+    "regions_fixture,features_fixture,joint_fixture,expected_features_fixture,expectation",
+    [
+        (
+            "gdf_regions_empty",
+            "gdf_features",
+            "gdf_joint",
+            None,
+            does_not_raise(),
+        ),
+        (
+            "gdf_regions",
+            "gdf_features_empty",
+            "gdf_joint",
+            None,
+            pytest.raises(ValueError),
+        ),
+        (
+            "gdf_regions",
+            "gdf_features_empty",
+            "gdf_joint",
+            "expected_feature_names",
+            does_not_raise(),
+        ),
+        (
+            "gdf_regions",
+            "gdf_features",
+            "gdf_joint_empty",
+            None,
+            does_not_raise(),
+        ),
+    ],
+)
+@pytest.mark.parametrize("concatenate_features", [False, True])  # type: ignore
+@pytest.mark.parametrize("count_subcategories", [False, True])  # type: ignore
+@pytest.mark.parametrize("neighbourhood_distance", [0, 1, 2])  # type: ignore
+def test_empty(
+    regions_fixture: str,
+    features_fixture: str,
+    joint_fixture: str,
+    concatenate_features: bool,
+    count_subcategories: bool,
+    neighbourhood_distance: int,
+    expected_features_fixture: Union[str, None],
+    expectation: Any,
+    request: Any,
+) -> None:
+    """Test ContextualCountEmbedder handling of empty input data frames."""
+    expected_output_features = (
+        None
+        if expected_features_fixture is None
+        else request.getfixturevalue(expected_features_fixture)
+    )
+    embedder = ContextualCountEmbedder(
+        neighbourhood=H3Neighbourhood(),
+        neighbourhood_distance=neighbourhood_distance,
+        expected_output_features=expected_output_features,
+        count_subcategories=count_subcategories,
+        concatenate_vectors=concatenate_features,
+    )
+    gdf_regions: "gpd.GeoDataFrame" = request.getfixturevalue(regions_fixture)
+    gdf_features: "gpd.GeoDataFrame" = request.getfixturevalue(features_fixture)
+    gdf_joint: "gpd.GeoDataFrame" = request.getfixturevalue(joint_fixture)
+
+    with expectation:
+        embedding = embedder.transform(gdf_regions, gdf_features, gdf_joint)
+        assert len(embedding) == len(gdf_regions)
+        assert_index_equal(embedding.index, gdf_regions.index)
+        if expected_output_features:
+            assert len(embedding.columns) == len(expected_output_features) * (
+                1 if not concatenate_features else 1 + neighbourhood_distance
+            )
+
+        assert (embedding == 0).all().all()
+
+
+@pytest.mark.parametrize(  # type: ignore
+    "regions_fixture,features_fixture,joint_fixture,expectation",
+    [
+        (
+            "gdf_unnamed_single_index",
+            "gdf_features",
+            "gdf_joint",
+            pytest.raises(ValueError),
+        ),
+        (
+            "gdf_regions",
+            "gdf_unnamed_single_index",
+            "gdf_joint",
+            pytest.raises(ValueError),
+        ),
+        (
+            "gdf_regions",
+            "gdf_features",
+            "gdf_unnamed_single_index",
+            pytest.raises(ValueError),
+        ),
+        (
+            "gdf_regions",
+            "gdf_features",
+            "gdf_three_level_multi_index",
+            pytest.raises(ValueError),
+        ),
+        (
+            "gdf_incorrectly_named_single_index",
+            "gdf_features",
+            "gdf_joint",
+            pytest.raises(ValueError),
+        ),
+        (
+            "gdf_regions",
+            "gdf_incorrectly_named_single_index",
+            "gdf_joint",
+            pytest.raises(ValueError),
+        ),
+    ],
+)
+@pytest.mark.parametrize("concatenate_features", [False, True])  # type: ignore
+@pytest.mark.parametrize("count_subcategories", [False, True])  # type: ignore
+@pytest.mark.parametrize("neighbourhood_distance", 
[0, 1, 2]) # type: ignore +def test_incorrect_indexes( + regions_fixture: str, + features_fixture: str, + joint_fixture: str, + concatenate_features: bool, + count_subcategories: bool, + neighbourhood_distance: int, + expectation: Any, + request: Any, +) -> None: + """Test if cannot embed with incorrect dataframe indexes.""" + regions_gdf = request.getfixturevalue(regions_fixture) + features_gdf = request.getfixturevalue(features_fixture) + joint_gdf = request.getfixturevalue(joint_fixture) + + with expectation: + ContextualCountEmbedder( + neighbourhood=H3Neighbourhood(), + count_subcategories=count_subcategories, + concatenate_vectors=concatenate_features, + neighbourhood_distance=neighbourhood_distance, + ).transform(regions_gdf=regions_gdf, features_gdf=features_gdf, joint_gdf=joint_gdf) From a060fe15caf30adbb788cbd4991dc63be4a81518 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Sun, 23 Apr 2023 11:57:29 +0200 Subject: [PATCH 29/30] fix: change errors in tests --- tests/regionizers/test_administrative_boundary_regionizer.py | 2 +- tests/regionizers/test_h3_regionizer.py | 2 +- tests/regionizers/test_s2_regionizer.py | 2 +- tests/regionizers/test_voronoi_regionizer.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/regionizers/test_administrative_boundary_regionizer.py b/tests/regionizers/test_administrative_boundary_regionizer.py index 9f271e54..6f0480ba 100644 --- a/tests/regionizers/test_administrative_boundary_regionizer.py +++ b/tests/regionizers/test_administrative_boundary_regionizer.py @@ -46,7 +46,7 @@ def test_admin_level( def test_empty_gdf_attribute_error(gdf_empty) -> None: # type: ignore """Test checks if empty GeoDataFrames are disallowed.""" - with pytest.raises(ValueError): + with pytest.raises(AttributeError): abr = AdministrativeBoundaryRegionizer(admin_level=4) abr.transform(gdf_empty) diff --git a/tests/regionizers/test_h3_regionizer.py b/tests/regionizers/test_h3_regionizer.py index 84623e79..6ea63667 100644 --- a/tests/regionizers/test_h3_regionizer.py +++ b/tests/regionizers/test_h3_regionizer.py @@ -44,7 +44,7 @@ def expected_unbuffered_h3_indexes() -> List[str]: ("gdf_polygons", "expected_h3_indexes", H3_RESOLUTION, True, does_not_raise()), ("gdf_polygons", "expected_unbuffered_h3_indexes", H3_RESOLUTION, False, does_not_raise()), ("gdf_multipolygon", "expected_h3_indexes", H3_RESOLUTION, True, does_not_raise()), - ("gdf_empty", "expected_h3_indexes", H3_RESOLUTION, True, pytest.raises(ValueError)), + ("gdf_empty", "expected_h3_indexes", H3_RESOLUTION, True, pytest.raises(AttributeError)), ("gdf_polygons", "expected_h3_indexes", -1, True, pytest.raises(ValueError)), ("gdf_polygons", "expected_h3_indexes", 16, True, pytest.raises(ValueError)), ("gdf_no_crs", "expected_h3_indexes", H3_RESOLUTION, True, pytest.raises(ValueError)), diff --git a/tests/regionizers/test_s2_regionizer.py b/tests/regionizers/test_s2_regionizer.py index ea85b4a6..90bb6891 100644 --- a/tests/regionizers/test_s2_regionizer.py +++ b/tests/regionizers/test_s2_regionizer.py @@ -38,7 +38,7 @@ def expected_s2_indexes() -> List[str]: [ ("gdf_polygons", "expected_s2_indexes", S2_RESOLUTION, does_not_raise()), ("gdf_multipolygon", "expected_s2_indexes", S2_RESOLUTION, does_not_raise()), - ("gdf_empty", "expected_s2_indexes", S2_RESOLUTION, pytest.raises(ValueError)), + ("gdf_empty", "expected_s2_indexes", S2_RESOLUTION, pytest.raises(AttributeError)), ("gdf_polygons", "expected_s2_indexes", -1, pytest.raises(ValueError)), ("gdf_polygons", "expected_s2_indexes", 31, 
pytest.raises(ValueError)),
         ("gdf_no_crs", "expected_s2_indexes", S2_RESOLUTION, pytest.raises(ValueError)),
diff --git a/tests/regionizers/test_voronoi_regionizer.py b/tests/regionizers/test_voronoi_regionizer.py
index 1abc181d..0d24622d 100644
--- a/tests/regionizers/test_voronoi_regionizer.py
+++ b/tests/regionizers/test_voronoi_regionizer.py
@@ -14,7 +14,7 @@

 def test_empty_gdf_attribute_error(gdf_empty: gpd.GeoDataFrame) -> None:
     """Test checks if empty GeoDataFrames are disallowed."""
-    with pytest.raises(ValueError):
+    with pytest.raises(AttributeError):
         VoronoiRegionizer(seeds=gdf_empty)


From d3edb7977c1657efc905c59cfd7104b3c4328add Mon Sep 17 00:00:00 2001
From: Kamil Raczycki
Date: Mon, 24 Apr 2023 22:36:53 +0200
Subject: [PATCH 30/30] docs: fix heading in example

---
 examples/embedders/contextual_count_embedder.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/embedders/contextual_count_embedder.ipynb b/examples/embedders/contextual_count_embedder.ipynb
index c33dd68b..e43b86db 100644
--- a/examples/embedders/contextual_count_embedder.ipynb
+++ b/examples/embedders/contextual_count_embedder.ipynb
@@ -146,7 +146,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Concatenated vector version (default)\n",
+    "### Concatenated vector version\n",
     "\n",
    "Embedder will return a vector of length `n * (distance + 1)`, where `n` is the number of features from the `CountEmbedder` and `distance` is the neighbourhood distance analysed (the region itself counts as distance 0).\n",
    "\n",
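
A minimal usage sketch of the concatenated variant documented in the notebook cell above. It is illustrative only: `regions_gdf`, `features_gdf` and `joint_gdf` are hypothetical placeholders for the GeoDataFrames produced earlier in the notebook's pipeline (regionizer, loader, joiner); the imports and parameters mirror those exercised by the tests added in PATCH 28.

from srai.embedders import ContextualCountEmbedder
from srai.neighbourhoods import H3Neighbourhood

# Sketch under assumptions: regions_gdf, features_gdf and joint_gdf are
# placeholders assumed to come from an earlier regionize/load/join step.
embedder = ContextualCountEmbedder(
    neighbourhood=H3Neighbourhood(),
    neighbourhood_distance=2,
    concatenate_vectors=True,  # yields n * (distance + 1) output columns
)
embeddings = embedder.transform(
    regions_gdf=regions_gdf, features_gdf=features_gdf, joint_gdf=joint_gdf
)
# With two base features and distance 2, the columns would look like:
# leisure_0, amenity_0, leisure_1, amenity_1, leisure_2, amenity_2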