From ba4a3937fd40bc5abfbe2809bedf200a95a9f852 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20Wo=C5=BAniak?= Date: Thu, 23 Feb 2023 00:52:41 +0100 Subject: [PATCH] fix: properly use osmnx as an optional dependency in OSMTagLoader, chore: utilize other library parts in CountEmbedder example, chore: simplify OSMTagLoader example, fix: wrong print in OSMTagLoader example --- examples/embedders/count_embedder.ipynb | 237 ++++++------------ examples/loaders/osm_tag_loader.ipynb | 92 ++----- srai/loaders/osm_tag_loader/osm_tag_loader.py | 8 +- 3 files changed, 102 insertions(+), 235 deletions(-) diff --git a/examples/embedders/count_embedder.ipynb b/examples/embedders/count_embedder.ipynb index 57c86ed4..fc8e6fe9 100644 --- a/examples/embedders/count_embedder.ipynb +++ b/examples/embedders/count_embedder.ipynb @@ -8,7 +8,11 @@ "source": [ "from shapely import geometry\n", "import geopandas as gpd\n", - "import pandas as pd" + "from srai.utils.constants import WGS84_CRS\n", + "from srai.loaders.osm_tag_loader import OSMTagLoader\n", + "from srai.regionizers import H3Regionizer\n", + "from srai.joiners import IntersectionJoiner\n", + "from srai.embedders import CountEmbedder" ] }, { @@ -19,62 +23,21 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "In order to use ```CountEmbedder``` we need to prepare some example data. \n", + "In order to use ```CountEmbedder``` we need to prepare some data. \n", "Namely we need: ```regions_gdf```, ```features_gdf```, and ```joint_gdf```. \n", - "NOTE: These are normally output by the previous pipeline steps." + "These are the outputs of Regionizers, Loaders and Joiners respectively." 
] }, { - "cell_type": "code", - "execution_count": null, + "attachments": {}, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "regions_gdf = gpd.GeoDataFrame(\n", - " geometry=[\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.02710946531851, 51.110065389823305),\n", - " (17.029634931698617, 51.1092989279356),\n", - " (17.03212452567607, 51.11021450606774),\n", - " (17.032088692873092, 51.11189657169522),\n", - " (17.029563145936592, 51.11266305206119),\n", - " (17.02707351236059, 51.11174744831988),\n", - " (17.02710946531851, 51.110065389823305),\n", - " ],\n", - " ),\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.03212452567607, 51.11021450606774),\n", - " (17.034649970341516, 51.109447934020366),\n", - " (17.037139662738255, 51.11036340911803),\n", - " (17.037103950094387, 51.11204548186887),\n", - " (17.03457842489355, 51.11281207240022),\n", - " (17.032088692873092, 51.11189657169522),\n", - " (17.03212452567607, 51.11021450606774),\n", - " ],\n", - " ),\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.02952725046974, 51.114345051613405),\n", - " (17.029563145936592, 51.11266305206119),\n", - " (17.032088692873092, 51.11189657169522),\n", - " (17.03457842489355, 51.11281207240022),\n", - " (17.03454264959235, 51.11449407907883),\n", - " (17.03201702210393, 51.115260577927586),\n", - " (17.02952725046974, 51.114345051613405),\n", - " ],\n", - " ),\n", - " ],\n", - " index=pd.Index(\n", - " data=[\"891e2040897ffff\", \"891e2040d4bffff\", \"891e2040d5bffff\"], name=\"region_id\"\n", - " ),\n", - " crs=\"epsg:4326\",\n", - ")\n", - "regions_gdf" + "### Define the bounding box polygon" ] }, { @@ -83,54 +46,25 @@ "metadata": {}, "outputs": [], "source": [ - "features_gdf = gpd.GeoDataFrame(\n", - " {\n", - " \"leisure\": [\"playground\", None, \"adult_gaming_centre\", None],\n", - " \"amenity\": [None, \"pub\", \"pub\", None],\n", - " },\n", - " geometry=[\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.0360858, 
51.1103927),\n", - " (17.0358804, 51.1104389),\n", - " (17.0357855, 51.1105503),\n", - " (17.0359451, 51.1105907),\n", - " (17.0361589, 51.1105402),\n", - " (17.0360858, 51.1103927),\n", - " ]\n", - " ),\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.0317168, 51.1114868),\n", - " (17.0320, 51.1114868),\n", - " (17.0320, 51.1117503),\n", - " (17.0317168, 51.1117503),\n", - " ]\n", - " ),\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.0317168, 51.1124868),\n", - " (17.0320, 51.1124868),\n", - " (17.0320, 51.1127503),\n", - " (17.0317168, 51.1127503),\n", - " ]\n", - " ),\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.0307168, 51.1104868),\n", - " (17.0310, 51.1104868),\n", - " (17.0310, 51.1107503),\n", - " (17.0307168, 51.1107503),\n", - " ]\n", - " ),\n", - " ],\n", - " index=pd.Index(\n", - " data=[\"way/312457804\", \"way/1533817161\", \"way/312457812\", \"way/312457834\"],\n", - " name=\"feature_id\",\n", - " ),\n", - " crs=\"epsg:4326\",\n", + "bbox_polygon = geometry.Polygon(\n", + " [\n", + " [17.0198822, 51.1191217],\n", + " [17.017436, 51.105004],\n", + " [17.0485067, 51.1027944],\n", + " [17.0511246, 51.1175054],\n", + " [17.0198822, 51.1191217],\n", + " ]\n", ")\n", - "features_gdf" + "bbox_gdf = gpd.GeoDataFrame(geometry=[bbox_polygon], crs=WGS84_CRS)\n", + "bbox_gdf" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Regionize the area using an H3Regionizer" ] }, { @@ -139,8 +73,18 @@ "metadata": {}, "outputs": [], "source": [ - "ax = regions_gdf.plot()\n", - "features_gdf.plot(ax=ax, color=\"red\")" + "regionizer = H3Regionizer(resolution=8, buffer=True)\n", + "regions_gdf = regionizer.transform(bbox_gdf)\n", + "ax = bbox_gdf.plot()\n", + "regions_gdf.plot(ax=ax, color=\"red\", alpha=0.5)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download some objects from OpenStreetMap" ] }, { @@ -149,56 +93,13 @@ "metadata": {}, "outputs": 
[], "source": [ - "joint_gdf = gpd.GeoDataFrame(\n", - " geometry=[\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.0358804, 51.1104389),\n", - " (17.0357855, 51.1105503),\n", - " (17.0359451, 51.1105907),\n", - " (17.0361589, 51.1105402),\n", - " (17.0360858, 51.1103927),\n", - " (17.0358804, 51.1104389),\n", - " ]\n", - " ),\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.0317168, 51.1117503),\n", - " (17.032, 51.1117503),\n", - " (17.032, 51.1114868),\n", - " (17.0317168, 51.1114868),\n", - " (17.0317168, 51.1117503),\n", - " ]\n", - " ),\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.0307168, 51.1107503),\n", - " (17.031, 51.1107503),\n", - " (17.031, 51.1104868),\n", - " (17.0307168, 51.1104868),\n", - " (17.0307168, 51.1107503),\n", - " ]\n", - " ),\n", - " geometry.Polygon(\n", - " shell=[\n", - " (17.0317168, 51.1127503),\n", - " (17.032, 51.1127503),\n", - " (17.032, 51.1124868),\n", - " (17.0317168, 51.1124868),\n", - " (17.0317168, 51.1127503),\n", - " ]\n", - " ),\n", - " ],\n", - " index=pd.MultiIndex.from_arrays(\n", - " arrays=[\n", - " [\"891e2040d4bffff\", \"891e2040897ffff\", \"891e2040897ffff\", \"891e2040d5bffff\"],\n", - " [\"way/312457804\", \"way/1533817161\", \"way/312457834\", \"way/312457812\"],\n", - " ],\n", - " names=[\"region_id\", \"feature_id\"],\n", - " ),\n", - " crs=\"epsg:4326\",\n", - ")\n", - "joint_gdf" + "loader = OSMTagLoader()\n", + "tags = {\n", + " \"leisure\": [\"playground\", \"adult_gaming_centre\"],\n", + " \"amenity\": \"pub\",\n", + "}\n", + "features_gdf = loader.load(bbox_gdf, tags=tags)\n", + "features_gdf" ] }, { @@ -207,14 +108,19 @@ "metadata": {}, "outputs": [], "source": [ - "joint_gdf.plot()" + "ax = regions_gdf.plot()\n", + "features_gdf.plot(\n", + " ax=ax,\n", + " color=\"red\",\n", + ")" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## Embed using features existing in data" + "### Join the objects with the regions they belong to" ] }, { @@ -223,9 
+129,9 @@ "metadata": {}, "outputs": [], "source": [ - "from srai.embedders import CountEmbedder\n", - "\n", - "embedder = CountEmbedder()" + "joiner = IntersectionJoiner()\n", + "joint_gdf = joiner.transform(regions_gdf, features_gdf)\n", + "joint_gdf" ] }, { @@ -234,15 +140,14 @@ "metadata": {}, "outputs": [], "source": [ - "embedding = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n", - "embedding" + "joint_gdf.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Embed with specifying expected output features" + "## Embed using features existing in data" ] }, { @@ -251,11 +156,16 @@ "metadata": {}, "outputs": [], "source": [ - "from srai.embedders import CountEmbedder\n", - "\n", - "embedder = CountEmbedder(\n", - " expected_output_features=[\"amenity_parking\", \"leisure_park\", \"amenity_pub\"]\n", - ")" + "embedder = CountEmbedder()\n", + "embedding = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n", + "embedding" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Embed with specifying expected output features" ] }, { @@ -264,6 +174,9 @@ "metadata": {}, "outputs": [], "source": [ + "embedder = CountEmbedder(\n", + " expected_output_features=[\"amenity_parking\", \"leisure_park\", \"amenity_pub\"]\n", + ")\n", "embedding_expected_features = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n", "embedding_expected_features" ] @@ -295,11 +208,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.14" + "version": "3.8.16" }, "vscode": { "interpreter": { - "hash": "f39c7279c85c8be5d827e53eddb5011e966102d239fe8b81ca4bd9f0123eda8f" + "hash": "cdb8aaadc0decb944250d2ac9b06f485d1fc395bd22d4875475c731b86175a8b" } } }, diff --git a/examples/loaders/osm_tag_loader.ipynb b/examples/loaders/osm_tag_loader.ipynb index 3b8f3d81..56026371 100644 --- a/examples/loaders/osm_tag_loader.ipynb +++ b/examples/loaders/osm_tag_loader.ipynb @@ -8,6 +8,19 @@ "# 
OSM Tag Loader" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from srai.loaders.osm_tag_loader.filters.popular import get_popular_tags\n", + "from srai.loaders.osm_tag_loader.filters import HEX2VEC_FILTER\n", + "from srai.loaders.osm_tag_loader import OSMTagLoader\n", + "from functional import seq\n", + "import osmnx as ox" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -33,26 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "from srai.loaders.osm_tag_loader.filters.popular import get_popular_tags" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "all_popular_tags = get_popular_tags()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from functional import seq\n", - "\n", + "all_popular_tags = get_popular_tags()\n", "num_keys = len(all_popular_tags)\n", "num_values = seq(all_popular_tags.values()).map(len).sum()\n", "f\"Unique keys: {num_keys}. Key/value pairs: {num_values}\"" @@ -91,20 +85,9 @@ "metadata": {}, "outputs": [], "source": [ - "from srai.loaders.osm_tag_loader.filters import HEX2VEC_FILTER" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from functional import seq\n", - "\n", "hex_2_vec_keys = len(HEX2VEC_FILTER)\n", "hex_2_vec_key_values = seq(HEX2VEC_FILTER.values()).map(len).sum()\n", - "f\"Unique keys: {num_keys}. Key/value pairs: {num_values}\"" + "f\"Unique keys: {hex_2_vec_keys}. 
Key/value pairs: {hex_2_vec_key_values}\"" ] }, { @@ -115,18 +98,6 @@ "## Using OSMTagLoader to download data for a specific area" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import osmnx as ox\n", - "from srai.loaders.osm_tag_loader import OSMTagLoader\n", - "\n", - "loader = OSMTagLoader()" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -141,24 +112,9 @@ "metadata": {}, "outputs": [], "source": [ - "parks_filter = {\"leisure\": \"park\"}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wroclaw_gdf = ox.geocode_to_gdf(\"Wrocław, Poland\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "loader = OSMTagLoader()\n", + "parks_filter = {\"leisure\": \"park\"}\n", + "wroclaw_gdf = ox.geocode_to_gdf(\"Wrocław, Poland\")\n", "parks_gdf = loader.load(wroclaw_gdf, parks_filter)\n", "parks_gdf" ] @@ -187,15 +143,7 @@ "metadata": {}, "outputs": [], "source": [ - "barcelona_gdf = ox.geocode_to_gdf(\"Barcelona\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "barcelona_gdf = ox.geocode_to_gdf(\"Barcelona\")\n", "barcelona_filter = {\"building\": \"hotel\", \"amenity\": [\"bar\", \"cafe\", \"pub\"], \"sport\": \"soccer\"}\n", "barcelona_objects_gdf = loader.load(barcelona_gdf, barcelona_filter)\n", "barcelona_objects_gdf" @@ -236,7 +184,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.8.16" }, "vscode": { "interpreter": { diff --git a/srai/loaders/osm_tag_loader/osm_tag_loader.py b/srai/loaders/osm_tag_loader/osm_tag_loader.py index 6e462680..4d4d9e0f 100644 --- a/srai/loaders/osm_tag_loader/osm_tag_loader.py +++ b/srai/loaders/osm_tag_loader/osm_tag_loader.py @@ -7,11 +7,11 @@ from typing import Dict, List, Tuple, Union import 
geopandas as gpd -import osmnx as ox import pandas as pd from functional import seq from tqdm import tqdm +from srai.utils._optional import import_optional_dependencies from srai.utils.constants import FEATURES_INDEX, WGS84_CRS @@ -35,6 +35,10 @@ class OSMTagLoader: _OSMID_INDEX_NAME = "osmid" _RESULT_INDEX_NAMES = [_ELEMENT_TYPE_INDEX_NAME, _OSMID_INDEX_NAME] + def __init__(self) -> None: + """Initialize OSMTagLoader.""" + import_optional_dependencies(dependency_group="osm", modules=["osmnx"]) + def load( self, area: gpd.GeoDataFrame, @@ -64,6 +68,8 @@ def load( Returns: gpd.GeoDataFrame: Downloaded objects as a GeoDataFrame. """ + import osmnx as ox + area_wgs84 = area.to_crs(crs=WGS84_CRS) _tags = self._flatten_tags(tags)