[SEDONA-714] Add geopandas to spark arrow conversion. #1825

Merged · 9 commits · Feb 26, 2025
4 changes: 4 additions & 0 deletions Makefile
@@ -65,3 +65,7 @@ clean:
	rm -rf __pycache__
	rm -rf .mypy_cache
	rm -rf .pytest_cache

run-docs:
	docker build -f docker/docs/Dockerfile -t mkdocs-sedona .
	docker run --rm -it -p 8000:8000 -v ${PWD}:/docs mkdocs-sedona
8 changes: 8 additions & 0 deletions docker/docs/Dockerfile
@@ -0,0 +1,8 @@
FROM squidfunk/mkdocs-material:9.6

RUN apk update
RUN apk add gcc musl-dev linux-headers
RUN pip install mkdocs-macros-plugin \
    mkdocs-git-revision-date-localized-plugin \
    mkdocs-jupyter \
    mike
19 changes: 19 additions & 0 deletions docs/tutorial/geopandas-shapely.md
@@ -67,6 +67,25 @@ This query will show the following outputs:

```

To leverage Arrow optimization and speed up the conversion, you can use the `create_spatial_dataframe`
function, which takes a SparkSession and a GeoDataFrame as parameters and returns a Sedona DataFrame.

```python
def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame
```

- `spark`: the SparkSession
- `gdf`: the `gpd.GeoDataFrame` to convert
- returns: a Sedona `DataFrame`

Example:

```python
from sedona.utils.geoarrow import create_spatial_dataframe

create_spatial_dataframe(spark, gdf)
```
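
For instance, starting from a small GeoDataFrame (a minimal sketch; the column names and coordinates are illustrative, and `spark` is assumed to be an active Sedona-enabled SparkSession):

```python
import geopandas as gpd

from sedona.utils.geoarrow import create_spatial_dataframe

# Two rows: one attribute column plus a point geometry column
gdf = gpd.GeoDataFrame(
    {
        "name": ["Sedona", "Apache"],
        "geometry": gpd.points_from_xy([0, 1], [0, 1]),
    }
)

df = create_spatial_dataframe(spark, gdf)
df.show()
```

The geometry column of the returned DataFrame is typed as Sedona's `GeometryType`, so spatial SQL functions can be applied to it directly.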

### From Sedona DataFrame to GeoPandas

Reading data with Spark and converting to GeoPandas
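
A minimal sketch of this direction (assuming `df` is a Sedona DataFrame whose geometry column is named `geometry`):

```python
import geopandas as gpd

# toPandas() materializes the Sedona DataFrame with Shapely geometries,
# which GeoDataFrame can wrap directly
gdf = gpd.GeoDataFrame(df.toPandas(), geometry="geometry")
```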
139 changes: 138 additions & 1 deletion python/sedona/utils/geoarrow.py
@@ -14,13 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import itertools
from typing import List, Callable

# We may be able to achieve streaming rather than complete materialization by
# using the ArrowStreamSerializer (instead of the ArrowCollectSerializer)


from sedona.sql.types import GeometryType
from sedona.sql.st_functions import ST_AsEWKB
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType, StructField, DataType, ArrayType, MapType

from sedona.sql.types import GeometryType
import geopandas as gpd
from pyspark.sql.pandas.types import (
    from_arrow_type,
)
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer


def dataframe_to_arrow(df, crs=None):
@@ -186,3 +197,129 @@ def unique_srid_from_ewkb(obj):
    import pyproj

    return pyproj.CRS(f"EPSG:{epsg_code}")


def _dedup_names(names: List[str]) -> List[str]:
    if len(set(names)) == len(names):
        return names
    else:

        def _gen_dedup(_name: str) -> Callable[[], str]:
            _i = itertools.count()
            return lambda: f"{_name}_{next(_i)}"

        def _gen_identity(_name: str) -> Callable[[], str]:
            return lambda: _name

        gen_new_name = {
            name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name)
            for name, group in itertools.groupby(sorted(names))
        }
        return [gen_new_name[name]() for name in names]


# Backport from Spark 4.0
# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/types.py#L1385
def _deduplicate_field_names(dt: DataType) -> DataType:
    if isinstance(dt, StructType):
        dedup_field_names = _dedup_names(dt.names)

        return StructType(
            [
                StructField(
                    dedup_field_names[i],
                    _deduplicate_field_names(field.dataType),
                    nullable=field.nullable,
                )
                for i, field in enumerate(dt.fields)
            ]
        )
    elif isinstance(dt, ArrayType):
        return ArrayType(
            _deduplicate_field_names(dt.elementType), containsNull=dt.containsNull
        )
    elif isinstance(dt, MapType):
        return MapType(
            _deduplicate_field_names(dt.keyType),
            _deduplicate_field_names(dt.valueType),
            valueContainsNull=dt.valueContainsNull,
        )
    else:
        return dt


def infer_schema(gdf: gpd.GeoDataFrame) -> StructType:
    import pyarrow as pa

    fields = gdf.dtypes.reset_index().values.tolist()
    geom_fields = []
    index = 0
    for name, dtype in fields:
        if dtype == "geometry":
            geom_fields.append((index, name))
            continue

        index += 1

    if not geom_fields:
        raise ValueError("No geometry field found in the GeoDataFrame")

    pa_schema = pa.Schema.from_pandas(
        gdf.drop([name for _, name in geom_fields], axis=1)
    )

    spark_schema = []

    for field in pa_schema:
        field_type = field.type
        spark_type = from_arrow_type(field_type)
        spark_schema.append(StructField(field.name, spark_type, True))

    for index, geom_field in geom_fields:
        spark_schema.insert(index, StructField(geom_field, GeometryType(), True))

    return StructType(spark_schema)


# Modified backport from Spark 4.0
# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/conversion.py#L632
def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame:
    from pyspark.sql.pandas.types import (
        to_arrow_type,
    )

    def reader_func(temp_filename):
        return spark._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)

    def create_iter_server():
        return spark._jvm.ArrowIteratorServer()

    schema = infer_schema(gdf)
    timezone = spark._jconf.sessionLocalTimeZone()
    # Slice the GeoDataFrame into chunks of at most arrowMaxRecordsPerBatch rows
    step = spark._jconf.arrowMaxRecordsPerBatch()
    step = step if step > 0 else len(gdf)
    pdf_slices = (gdf.iloc[start : start + step] for start in range(0, len(gdf), step))
    spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]

    # Pair every column of every slice with its Arrow type and its Spark type
    arrow_data = [
        [
            (c, to_arrow_type(t) if t is not None else None, t)
            for (_, c), t in zip(pdf_slice.items(), spark_types)
        ]
        for pdf_slice in pdf_slices
    ]

    # Serialize the slices with Arrow, stream them to the JVM, and build a
    # DataFrame from the stream using the inferred schema
    safecheck = spark._jconf.arrowSafeTypeConversion()
    ser = ArrowStreamPandasSerializer(timezone, safecheck)
    jiter = spark._sc._serialize_to_jvm(
        arrow_data, ser, reader_func, create_iter_server
    )

    jsparkSession = spark._jsparkSession
    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)

    df = DataFrame(jdf, spark)

    df._schema = schema

    return df
94 changes: 94 additions & 0 deletions python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
@@ -0,0 +1,94 @@
import pytest

from sedona.sql.types import GeometryType
from sedona.utils.geoarrow import create_spatial_dataframe
from tests.test_base import TestBase
import geopandas as gpd
import pyspark


class TestGeopandasToSedonaWithArrow(TestBase):

    @pytest.mark.skipif(
        not pyspark.__version__.startswith("3.5"),
        reason="It's only working with Spark 3.5",
    )
    def test_conversion_dataframe(self):
        gdf = gpd.GeoDataFrame(
            {
                "name": ["Sedona", "Apache"],
                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
            }
        )

        df = create_spatial_dataframe(self.spark, gdf)

        assert df.count() == 2
        assert df.columns == ["name", "geometry"]
        assert df.schema["geometry"].dataType == GeometryType()

    @pytest.mark.skipif(
        not pyspark.__version__.startswith("3.5"),
        reason="It's only working with Spark 3.5",
    )
    def test_different_geometry_positions(self):
        gdf = gpd.GeoDataFrame(
            {
                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
                "name": ["Sedona", "Apache"],
            }
        )

        gdf2 = gpd.GeoDataFrame(
            {
                "name": ["Sedona", "Apache"],
                "name1": ["Sedona", "Apache"],
                "name2": ["Sedona", "Apache"],
                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
            }
        )

        df1 = create_spatial_dataframe(self.spark, gdf)
        df2 = create_spatial_dataframe(self.spark, gdf2)

        assert df1.count() == 2
        assert df1.columns == ["geometry", "name"]
        assert df1.schema["geometry"].dataType == GeometryType()

        assert df2.count() == 2
        assert df2.columns == ["name", "name1", "name2", "geometry"]
        assert df2.schema["geometry"].dataType == GeometryType()

    @pytest.mark.skipif(
        not pyspark.__version__.startswith("3.5"),
        reason="It's only working with Spark 3.5",
    )
    def test_multiple_geometry_columns(self):
        gdf = gpd.GeoDataFrame(
            {
                "name": ["Sedona", "Apache"],
                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
                "geometry2": gpd.points_from_xy([0, 1], [0, 1]),
            }
        )

        df = create_spatial_dataframe(self.spark, gdf)

        assert df.count() == 2
        assert df.columns == ["name", "geometry2", "geometry"]
        assert df.schema["geometry"].dataType == GeometryType()
        assert df.schema["geometry2"].dataType == GeometryType()

    @pytest.mark.skipif(
        not pyspark.__version__.startswith("3.5"),
        reason="It's only working with Spark 3.5",
    )
    def test_missing_geometry_column(self):
        gdf = gpd.GeoDataFrame(
            {
                "name": ["Sedona", "Apache"],
            },
        )

        with pytest.raises(ValueError):
            create_spatial_dataframe(self.spark, gdf)