Skip to content

Commit

Permalink
fix: osmnx as optional dependency in OSMTagLoader, chore: improve Cou…
Browse files Browse the repository at this point in the history
…ntEmbedder and OSMTagLoader examples

fix: properly use osmnx as an optional dependency in OSMTagLoader, chore: utilize other library parts in CountEmbedder example, chore: simplify OSMTagLoader example, fix: wrong print in OSMTagLoader example
  • Loading branch information
simonusher authored Feb 23, 2023
1 parent 39b75e3 commit 350b843
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 235 deletions.
237 changes: 75 additions & 162 deletions examples/embedders/count_embedder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
"source": [
"from shapely import geometry\n",
"import geopandas as gpd\n",
"import pandas as pd"
"from srai.utils.constants import WGS84_CRS\n",
"from srai.loaders.osm_tag_loader import OSMTagLoader\n",
"from srai.regionizers import H3Regionizer\n",
"from srai.joiners import IntersectionJoiner\n",
"from srai.embedders import CountEmbedder"
]
},
{
Expand All @@ -19,62 +23,21 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to use ```CountEmbedder``` we need to prepare some example data. \n",
"In order to use ```CountEmbedder``` we need to prepare some data. \n",
"Namely we need: ```regions_gdf```, ```features_gdf```, and ```joint_gdf```. \n",
"NOTE: These are normally output by the previous pipeline steps."
"These are the outputs of Regionizers, Loaders and Joiners respectively."
]
},
{
"cell_type": "code",
"execution_count": null,
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"regions_gdf = gpd.GeoDataFrame(\n",
" geometry=[\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.02710946531851, 51.110065389823305),\n",
" (17.029634931698617, 51.1092989279356),\n",
" (17.03212452567607, 51.11021450606774),\n",
" (17.032088692873092, 51.11189657169522),\n",
" (17.029563145936592, 51.11266305206119),\n",
" (17.02707351236059, 51.11174744831988),\n",
" (17.02710946531851, 51.110065389823305),\n",
" ],\n",
" ),\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.03212452567607, 51.11021450606774),\n",
" (17.034649970341516, 51.109447934020366),\n",
" (17.037139662738255, 51.11036340911803),\n",
" (17.037103950094387, 51.11204548186887),\n",
" (17.03457842489355, 51.11281207240022),\n",
" (17.032088692873092, 51.11189657169522),\n",
" (17.03212452567607, 51.11021450606774),\n",
" ],\n",
" ),\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.02952725046974, 51.114345051613405),\n",
" (17.029563145936592, 51.11266305206119),\n",
" (17.032088692873092, 51.11189657169522),\n",
" (17.03457842489355, 51.11281207240022),\n",
" (17.03454264959235, 51.11449407907883),\n",
" (17.03201702210393, 51.115260577927586),\n",
" (17.02952725046974, 51.114345051613405),\n",
" ],\n",
" ),\n",
" ],\n",
" index=pd.Index(\n",
" data=[\"891e2040897ffff\", \"891e2040d4bffff\", \"891e2040d5bffff\"], name=\"region_id\"\n",
" ),\n",
" crs=\"epsg:4326\",\n",
")\n",
"regions_gdf"
"### Define the bounding box polygon"
]
},
{
Expand All @@ -83,54 +46,25 @@
"metadata": {},
"outputs": [],
"source": [
"features_gdf = gpd.GeoDataFrame(\n",
" {\n",
" \"leisure\": [\"playground\", None, \"adult_gaming_centre\", None],\n",
" \"amenity\": [None, \"pub\", \"pub\", None],\n",
" },\n",
" geometry=[\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.0360858, 51.1103927),\n",
" (17.0358804, 51.1104389),\n",
" (17.0357855, 51.1105503),\n",
" (17.0359451, 51.1105907),\n",
" (17.0361589, 51.1105402),\n",
" (17.0360858, 51.1103927),\n",
" ]\n",
" ),\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.0317168, 51.1114868),\n",
" (17.0320, 51.1114868),\n",
" (17.0320, 51.1117503),\n",
" (17.0317168, 51.1117503),\n",
" ]\n",
" ),\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.0317168, 51.1124868),\n",
" (17.0320, 51.1124868),\n",
" (17.0320, 51.1127503),\n",
" (17.0317168, 51.1127503),\n",
" ]\n",
" ),\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.0307168, 51.1104868),\n",
" (17.0310, 51.1104868),\n",
" (17.0310, 51.1107503),\n",
" (17.0307168, 51.1107503),\n",
" ]\n",
" ),\n",
" ],\n",
" index=pd.Index(\n",
" data=[\"way/312457804\", \"way/1533817161\", \"way/312457812\", \"way/312457834\"],\n",
" name=\"feature_id\",\n",
" ),\n",
" crs=\"epsg:4326\",\n",
"bbox_polygon = geometry.Polygon(\n",
" [\n",
" [17.0198822, 51.1191217],\n",
" [17.017436, 51.105004],\n",
" [17.0485067, 51.1027944],\n",
" [17.0511246, 51.1175054],\n",
" [17.0198822, 51.1191217],\n",
" ]\n",
")\n",
"features_gdf"
"bbox_gdf = gpd.GeoDataFrame(geometry=[bbox_polygon], crs=WGS84_CRS)\n",
"bbox_gdf"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regionize the area using a H3Regionizer"
]
},
{
Expand All @@ -139,8 +73,18 @@
"metadata": {},
"outputs": [],
"source": [
"ax = regions_gdf.plot()\n",
"features_gdf.plot(ax=ax, color=\"red\")"
"regionizer = H3Regionizer(resolution=8, buffer=True)\n",
"regions_gdf = regionizer.transform(bbox_gdf)\n",
"ax = bbox_gdf.plot()\n",
"regions_gdf.plot(ax=ax, color=\"red\", alpha=0.5)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download some objects from OpenStreetMap"
]
},
{
Expand All @@ -149,56 +93,13 @@
"metadata": {},
"outputs": [],
"source": [
"joint_gdf = gpd.GeoDataFrame(\n",
" geometry=[\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.0358804, 51.1104389),\n",
" (17.0357855, 51.1105503),\n",
" (17.0359451, 51.1105907),\n",
" (17.0361589, 51.1105402),\n",
" (17.0360858, 51.1103927),\n",
" (17.0358804, 51.1104389),\n",
" ]\n",
" ),\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.0317168, 51.1117503),\n",
" (17.032, 51.1117503),\n",
" (17.032, 51.1114868),\n",
" (17.0317168, 51.1114868),\n",
" (17.0317168, 51.1117503),\n",
" ]\n",
" ),\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.0307168, 51.1107503),\n",
" (17.031, 51.1107503),\n",
" (17.031, 51.1104868),\n",
" (17.0307168, 51.1104868),\n",
" (17.0307168, 51.1107503),\n",
" ]\n",
" ),\n",
" geometry.Polygon(\n",
" shell=[\n",
" (17.0317168, 51.1127503),\n",
" (17.032, 51.1127503),\n",
" (17.032, 51.1124868),\n",
" (17.0317168, 51.1124868),\n",
" (17.0317168, 51.1127503),\n",
" ]\n",
" ),\n",
" ],\n",
" index=pd.MultiIndex.from_arrays(\n",
" arrays=[\n",
" [\"891e2040d4bffff\", \"891e2040897ffff\", \"891e2040897ffff\", \"891e2040d5bffff\"],\n",
" [\"way/312457804\", \"way/1533817161\", \"way/312457834\", \"way/312457812\"],\n",
" ],\n",
" names=[\"region_id\", \"feature_id\"],\n",
" ),\n",
" crs=\"epsg:4326\",\n",
")\n",
"joint_gdf"
"loader = OSMTagLoader()\n",
"tags = {\n",
" \"leisure\": [\"playground\", \"adult_gaming_centre\"],\n",
" \"amenity\": \"pub\",\n",
"}\n",
"features_gdf = loader.load(bbox_gdf, tags=tags)\n",
"features_gdf"
]
},
{
Expand All @@ -207,14 +108,19 @@
"metadata": {},
"outputs": [],
"source": [
"joint_gdf.plot()"
"ax = regions_gdf.plot()\n",
"features_gdf.plot(\n",
" ax=ax,\n",
" color=\"red\",\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Embed using features existing in data"
"### Join the objects with the regions they belong to"
]
},
{
Expand All @@ -223,9 +129,9 @@
"metadata": {},
"outputs": [],
"source": [
"from srai.embedders import CountEmbedder\n",
"\n",
"embedder = CountEmbedder()"
"joiner = IntersectionJoiner()\n",
"joint_gdf = joiner.transform(regions_gdf, features_gdf)\n",
"joint_gdf"
]
},
{
Expand All @@ -234,15 +140,14 @@
"metadata": {},
"outputs": [],
"source": [
"embedding = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n",
"embedding"
"joint_gdf.plot()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Embed with specifying expected output features"
"## Embed using features existing in data"
]
},
{
Expand All @@ -251,11 +156,16 @@
"metadata": {},
"outputs": [],
"source": [
"from srai.embedders import CountEmbedder\n",
"\n",
"embedder = CountEmbedder(\n",
" expected_output_features=[\"amenity_parking\", \"leisure_park\", \"amenity_pub\"]\n",
")"
"embedder = CountEmbedder()\n",
"embedding = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n",
"embedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Embed with specifying expected output features"
]
},
{
Expand All @@ -264,6 +174,9 @@
"metadata": {},
"outputs": [],
"source": [
"embedder = CountEmbedder(\n",
" expected_output_features=[\"amenity_parking\", \"leisure_park\", \"amenity_pub\"]\n",
")\n",
"embedding_expected_features = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n",
"embedding_expected_features"
]
Expand Down Expand Up @@ -295,11 +208,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.14"
"version": "3.8.16"
},
"vscode": {
"interpreter": {
"hash": "f39c7279c85c8be5d827e53eddb5011e966102d239fe8b81ca4bd9f0123eda8f"
"hash": "cdb8aaadc0decb944250d2ac9b06f485d1fc395bd22d4875475c731b86175a8b"
}
}
},
Expand Down
Loading

0 comments on commit 350b843

Please sign in to comment.