From 7044c0c1dec4df41e7a8c4b0a8ebbf73087bc84e Mon Sep 17 00:00:00 2001 From: "zongsi.zhang" Date: Wed, 19 Apr 2023 13:36:40 +0800 Subject: [PATCH] [SEDONA-265] Migrate all ST functions to Sedona Inferred Expressions (#820) --- .../apache/sedona/common/Constructors.java | 66 ++++++ .../org/apache/sedona/common/Functions.java | 105 ++++++++- .../org/apache/sedona/common/Predicates.java | 52 +++++ .../sedona/common/utils/GeoHashDecoder.java | 84 +++++++ .../org/apache/sedona/sql/UDF/Catalog.scala | 2 +- .../sedona_sql/expressions/Constructors.scala | 194 +--------------- .../sedona_sql/expressions/Functions.scala | 207 ++---------------- .../expressions/NullSafeExpressions.scala | 17 +- .../sedona_sql/expressions/Predicates.scala | 32 ++- .../sedona_sql/expressions/implicits.scala | 19 ++ .../expressions/st_constructors.scala | 4 + .../sedona_sql/expressions/st_functions.scala | 5 +- 12 files changed, 390 insertions(+), 397 deletions(-) create mode 100644 common/src/main/java/org/apache/sedona/common/Predicates.java create mode 100644 common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java diff --git a/common/src/main/java/org/apache/sedona/common/Constructors.java b/common/src/main/java/org/apache/sedona/common/Constructors.java index 4cc5c48e20..793738deb2 100644 --- a/common/src/main/java/org/apache/sedona/common/Constructors.java +++ b/common/src/main/java/org/apache/sedona/common/Constructors.java @@ -16,15 +16,25 @@ import org.apache.sedona.common.enums.FileDataSplitter; import org.apache.sedona.common.enums.GeometryType; import org.apache.sedona.common.utils.FormatUtils; +import org.apache.sedona.common.utils.GeoHashDecoder; import org.locationtech.jts.geom.Coordinate; import org.locationtech.jts.geom.Geometry; import org.locationtech.jts.geom.GeometryFactory; import org.locationtech.jts.geom.PrecisionModel; import org.locationtech.jts.io.ParseException; +import org.locationtech.jts.io.WKBReader; import org.locationtech.jts.io.WKTReader; +import org.locationtech.jts.io.gml2.GMLReader; +import org.locationtech.jts.io.kml.KMLReader; +import org.xml.sax.SAXException; + +import javax.xml.parsers.ParserConfigurationException; +import java.io.IOException; public class Constructors { + private static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory(); + public static Geometry geomFromWKT(String wkt, int srid) throws ParseException { if (wkt == null) { return null; @@ -33,6 +43,10 @@ public static Geometry geomFromWKT(String wkt, int srid) throws ParseException { return new WKTReader(geometryFactory).read(wkt); } + public static Geometry geomFromWKB(byte[] wkb) throws ParseException { + return new WKBReader().read(wkb); + } + public static Geometry mLineFromText(String wkt, int srid) throws ParseException { if (wkt == null || !wkt.startsWith("MULTILINESTRING")) { return null; @@ -100,4 +114,56 @@ public static Geometry geomFromText(String geomString, FileDataSplitter fileData throw new RuntimeException(e); } } + + public static Geometry pointFromText(String geomString, String geomFormat) { + return geomFromText(geomString, geomFormat, GeometryType.POINT); + } + + public static Geometry polygonFromText(String geomString, String geomFormat) { + return geomFromText(geomString, geomFormat, GeometryType.POLYGON); + } + + public static Geometry lineStringFromText(String geomString, String geomFormat) { + return geomFromText(geomString, geomFormat, GeometryType.LINESTRING); + } + + public static Geometry lineFromText(String geomString) { + FileDataSplitter fileDataSplitter = FileDataSplitter.WKT; + Geometry geometry = Constructors.geomFromText(geomString, fileDataSplitter); + if(geometry.getGeometryType().contains("LineString")) { + return geometry; + } else { + return null; + } + } + + public static Geometry polygonFromEnvelope(double minX, double minY, double maxX, double maxY) { + Coordinate[] coordinates = new Coordinate[5]; + coordinates[0] = new Coordinate(minX, minY); + coordinates[1] = new Coordinate(minX, maxY); + coordinates[2] = new Coordinate(maxX, maxY); + coordinates[3] = new Coordinate(maxX, minY); + coordinates[4] = coordinates[0]; + return GEOMETRY_FACTORY.createPolygon(coordinates); + } + + public static Geometry geomFromGeoHash(String geoHash, Integer precision) { + System.out.println(geoHash); + System.out.println(precision); + try { + return GeoHashDecoder.decode(geoHash, precision); + } catch (GeoHashDecoder.InvalidGeoHashException e) { + return null; + } + } + + public static Geometry geomFromGML(String gml) throws IOException, ParserConfigurationException, SAXException { + return new GMLReader().read(gml, GEOMETRY_FACTORY); + } + + public static Geometry geomFromKML(String kml) throws ParseException { + return new KMLReader().read(kml); + } + + } diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java index 2c90d06b7e..c429c4c8c3 100644 --- a/common/src/main/java/org/apache/sedona/common/Functions.java +++ b/common/src/main/java/org/apache/sedona/common/Functions.java @@ -14,11 +14,8 @@ package org.apache.sedona.common; import com.google.common.geometry.S2CellId; -import com.google.common.geometry.S2Point; -import com.google.common.geometry.S2Region; -import com.google.common.geometry.S2RegionCoverer; -import org.apache.commons.lang3.ArrayUtils; import org.apache.sedona.common.geometryObjects.Circle; +import org.apache.sedona.common.subDivide.GeometrySubDivider; import org.apache.sedona.common.utils.GeomUtils; import org.apache.sedona.common.utils.GeometryGeoHashEncoder; import org.apache.sedona.common.utils.GeometrySplitter; @@ -37,6 +34,7 @@ import org.locationtech.jts.operation.valid.IsSimpleOp; import org.locationtech.jts.operation.valid.IsValidOp; import org.locationtech.jts.precision.GeometryPrecisionReducer; +import org.locationtech.jts.simplify.TopologyPreservingSimplifier; import org.opengis.referencing.FactoryException; import org.opengis.referencing.NoSuchAuthorityCodeException; import org.opengis.referencing.crs.CoordinateReferenceSystem; @@ -48,7 +46,6 @@ import java.util.Arrays; import java.util.HashSet; import java.util.List; -import java.util.function.Function; import java.util.stream.Collectors; @@ -575,4 +572,102 @@ public static Long[] s2CellIDs(Geometry input, int level) { } return S2Utils.roundCellsToSameLevel(new ArrayList<>(cellIds), level).stream().map(S2CellId::id).collect(Collectors.toList()).toArray(new Long[cellIds.size()]); } + + + // create static function named simplifyPreserveTopology + public static Geometry simplifyPreserveTopology(Geometry geometry, double distanceTolerance) { + return TopologyPreservingSimplifier.simplify(geometry, distanceTolerance); + } + + public static String geometryType(Geometry geometry) { + return "ST_" + geometry.getGeometryType(); + } + + public static Geometry startPoint(Geometry geometry) { + if (geometry instanceof LineString) { + LineString line = (LineString) geometry; + return line.getStartPoint(); + } + return null; + } + + public static Geometry endPoint(Geometry geometry) { + if (geometry instanceof LineString) { + LineString line = (LineString) geometry; + return line.getEndPoint(); + } + return null; + } + + public static Geometry[] dump(Geometry geometry) { + int numGeom = geometry.getNumGeometries(); + if (geometry instanceof GeometryCollection) { + Geometry[] geoms = new Geometry[geometry.getNumGeometries()]; + for (int i = 0; i < numGeom; i++) { + geoms[i] = geometry.getGeometryN(i); + } + return geoms; + } else { + return new Geometry[] {geometry}; + } + } + + public static Geometry[] dumpPoints(Geometry geometry) { + return Arrays.stream(geometry.getCoordinates()).map(GEOMETRY_FACTORY::createPoint).toArray(Point[]::new); + } + + public static Geometry symDifference(Geometry leftGeom, Geometry rightGeom) { + return leftGeom.symDifference(rightGeom); + } + + public static Geometry union(Geometry leftGeom, Geometry rightGeom) { + return leftGeom.union(rightGeom); + } + + public static Geometry createMultiGeometryFromOneElement(Geometry geometry) { + if (geometry instanceof Circle) { + return GEOMETRY_FACTORY.createGeometryCollection(new Circle[] {(Circle) geometry}); + } else if (geometry instanceof GeometryCollection) { + return geometry; + } else if (geometry instanceof LineString) { + return GEOMETRY_FACTORY.createMultiLineString(new LineString[]{(LineString) geometry}); + } else if (geometry instanceof Point) { + return GEOMETRY_FACTORY.createMultiPoint(new Point[] {(Point) geometry}); + } else if (geometry instanceof Polygon) { + return GEOMETRY_FACTORY.createMultiPolygon(new Polygon[] {(Polygon) geometry}); + } else { + return GEOMETRY_FACTORY.createGeometryCollection(); + } + } + + public static Geometry[] subDivide(Geometry geometry, int maxVertices) { + return GeometrySubDivider.subDivide(geometry, maxVertices); + } + + public static Geometry makePolygon(Geometry shell, Geometry[] holes) { + try { + if (holes != null) { + LinearRing[] interiorRings = Arrays.stream(holes).filter( + h -> h != null && !h.isEmpty() && h instanceof LineString && ((LineString) h).isClosed() + ).map( + h -> GEOMETRY_FACTORY.createLinearRing(h.getCoordinates()) + ).toArray(LinearRing[]::new); + if (interiorRings.length != 0) { + return GEOMETRY_FACTORY.createPolygon( + GEOMETRY_FACTORY.createLinearRing(shell.getCoordinates()), + Arrays.stream(holes).filter( + h -> h != null && !h.isEmpty() && h instanceof LineString && ((LineString) h).isClosed() + ).map( + h -> GEOMETRY_FACTORY.createLinearRing(h.getCoordinates()) + ).toArray(LinearRing[]::new) + ); + } + } + return GEOMETRY_FACTORY.createPolygon( + GEOMETRY_FACTORY.createLinearRing(shell.getCoordinates()) + ); + } catch (IllegalArgumentException e) { + return null; + } + } } diff --git a/common/src/main/java/org/apache/sedona/common/Predicates.java b/common/src/main/java/org/apache/sedona/common/Predicates.java new file mode 100644 index 0000000000..cb37b5c18b --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/Predicates.java @@ -0,0 +1,52 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common; + +import org.locationtech.jts.geom.Geometry; + +public class Predicates { + public static boolean contains(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.contains(rightGeometry); + } + public static boolean intersects(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.intersects(rightGeometry); + } + public static boolean within(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.within(rightGeometry); + } + public static boolean covers(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.covers(rightGeometry); + } + public static boolean coveredBy(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.coveredBy(rightGeometry); + } + public static boolean crosses(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.crosses(rightGeometry); + } + public static boolean overlaps(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.overlaps(rightGeometry); + } + public static boolean touches(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.touches(rightGeometry); + } + public static boolean equals(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.symDifference(rightGeometry).isEmpty(); + } + public static boolean disjoint(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.disjoint(rightGeometry); + } + public static boolean orderingEquals(Geometry leftGeometry, Geometry rightGeometry) { + return leftGeometry.equalsExact(rightGeometry); + } +} diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java b/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java new file mode 100644 index 0000000000..dd1bba1c9e --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java @@ -0,0 +1,84 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.utils; + +import org.locationtech.jts.geom.Geometry; + +public class GeoHashDecoder { + private static final int[] bits = new int[] {16, 8, 4, 2, 1}; + private static final String base32 = "0123456789bcdefghjkmnpqrstuvwxyz"; + + public static class InvalidGeoHashException extends Exception { + public InvalidGeoHashException(String message) { + super(message); + } + } + + public static Geometry decode(String geohash, Integer precision) throws InvalidGeoHashException { + return decodeGeoHashBBox(geohash, precision).getBbox().toPolygon(); + } + + private static class LatLon { + public Double[] lons; + + public Double[] lats; + + public LatLon(Double[] lons, Double[] lats) { + this.lons = lons; + this.lats = lats; + } + + BBox getBbox() { + return new BBox( + lons[0], + lons[1], + lats[0], + lats[1] + ); + } + } + + private static LatLon decodeGeoHashBBox(String geohash, Integer precision) throws InvalidGeoHashException { + LatLon latLon = new LatLon(new Double[] {-180.0, 180.0}, new Double[] {-90.0, 90.0}); + String geoHashLowered = geohash.toLowerCase(); + int geoHashLength = geohash.length(); + int targetPrecision = geoHashLength; + if (precision != null) { + if (precision < 0) throw new InvalidGeoHashException("Precision can not be negative"); + else targetPrecision = Math.min(geoHashLength, precision); + } + boolean isEven = true; + + for (int i = 0; i < targetPrecision ; i++){ + char c = geoHashLowered.charAt(i); + byte cd = (byte) base32.indexOf(c); + if (cd == -1){ + throw new InvalidGeoHashException(String.format("Invalid character '%s' found at index %d", c, i)); + } + for (int j = 0;j < 5; j++){ + byte mask = (byte) bits[j]; + int index = (mask & cd) == 0 ? 1 : 0; + if (isEven){ + latLon.lons[index] = (latLon.lons[0] + latLon.lons[1]) / 2; + } + else { + latLon.lats[index] = (latLon.lats[0] + latLon.lats[1]) / 2; + } + isEven = !isEven; + } + } + return latLon; + } + +} diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 03dbbab5ef..fa3493abf1 100644 --- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -117,7 +117,7 @@ object Catalog { function[ST_LineInterpolatePoint](), function[ST_SubDivideExplode](), function[ST_SubDivide](), - function[ST_MakePolygon](), + function[ST_MakePolygon](null), function[ST_GeoHash](), function[ST_GeomFromGeoHash](null), function[ST_Collect](), diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala index 188ab90667..fff0822859 100644 --- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala +++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala @@ -19,20 +19,15 @@ package org.apache.spark.sql.sedona_sql.expressions import org.apache.sedona.common.Constructors -import org.apache.sedona.common.enums.{FileDataSplitter, GeometryType} +import org.apache.sedona.common.enums.FileDataSplitter import org.apache.sedona.sql.utils.GeometrySerializer import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.sedona_sql.expressions.geohash.GeoHashDecoder import org.apache.spark.sql.sedona_sql.expressions.implicits.GeometryEnhancer import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.locationtech.jts.geom.{Coordinate, GeometryFactory} -import org.locationtech.jts.io.WKBReader -import org.locationtech.jts.io.gml2.GMLReader -import org.locationtech.jts.io.kml.KMLReader /** * Return a point from a string. The string must be plain string and each coordinate must be separated by a delimiter. @@ -41,25 +36,7 @@ import org.locationtech.jts.io.kml.KMLReader * string, the second parameter is the delimiter. String format should be similar to CSV/TSV */ case class ST_PointFromText(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator { - // This is an expression which takes two input expressions. - assert(inputExpressions.length == 2) - - override def nullable: Boolean = false - - override def eval(inputRow: InternalRow): Any = { - val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString - val geomFormat = inputExpressions(1).eval(inputRow).asInstanceOf[UTF8String].toString - val geometry = Constructors.geomFromText(geomString, geomFormat, GeometryType.POINT) - GeometrySerializer.serialize(geometry) - } - - override def dataType: DataType = GeometryUDT - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - - override def children: Seq[Expression] = inputExpressions - + extends InferredBinaryExpression(Constructors.pointFromText) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } @@ -71,26 +48,7 @@ case class ST_PointFromText(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_PolygonFromText(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator { - // This is an expression which takes two input expressions. - assert(inputExpressions.length == 2) - - override def nullable: Boolean = false - - override def eval(inputRow: InternalRow): Any = { - val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString - val geomFormat = inputExpressions(1).eval(inputRow).asInstanceOf[UTF8String].toString - - var geometry = Constructors.geomFromText(geomString, geomFormat, GeometryType.POLYGON) - GeometrySerializer.serialize(geometry) - } - - override def dataType: DataType = GeometryUDT - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - - override def children: Seq[Expression] = inputExpressions - + extends InferredBinaryExpression(Constructors.polygonFromText) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } @@ -102,68 +60,23 @@ case class ST_PolygonFromText(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_LineFromText(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator { - // This is an expression which takes one input expressions. - assert(inputExpressions.length == 1) - - override def nullable: Boolean = true - - override def eval(inputRow: InternalRow): Any = { - val lineString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString - - val fileDataSplitter = FileDataSplitter.WKT - val geometry = Constructors.geomFromText(lineString, fileDataSplitter) - if(geometry.getGeometryType.contains("LineString")) { - GeometrySerializer.serialize(geometry) - } else { - null - } - } - - override def dataType: DataType = GeometryUDT - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def children: Seq[Expression] = inputExpressions - + extends InferredUnaryExpression(Constructors.lineFromText) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } } - /** * Return a linestring from a string. The string must be plain string and each coordinate must be separated by a delimiter. * * @param inputExpressions */ case class ST_LineStringFromText(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator { - // This is an expression which takes two input expressions. - assert(inputExpressions.length == 2) - - override def nullable: Boolean = false - - override def eval(inputRow: InternalRow): Any = { - val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString - val geomFormat = inputExpressions(1).eval(inputRow).asInstanceOf[UTF8String].toString - - val geometry = Constructors.geomFromText(geomString, geomFormat, GeometryType.LINESTRING) - - GeometrySerializer.serialize(geometry) - } - - override def dataType: DataType = GeometryUDT - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - - override def children: Seq[Expression] = inputExpressions - + extends InferredBinaryExpression(Constructors.lineStringFromText) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } } - /** * Return a Geometry from a WKT string * @@ -212,7 +125,7 @@ case class ST_GeomFromWKB(inputExpressions: Seq[Expression]) } case (wkb: Array[Byte]) => { // convert raw wkb byte array to geometry - new WKBReader().read(wkb).toGenericArrayData + Constructors.geomFromWKB(wkb).toGenericArrayData } case null => null } @@ -294,32 +207,7 @@ case class ST_PointZ(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_PolygonFromEnvelope(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator { - assert(inputExpressions.length == 4) - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - val minX = inputExpressions(0).eval(input).asInstanceOf[Double] - val minY = inputExpressions(1).eval(input).asInstanceOf[Double] - val maxX = inputExpressions(2).eval(input).asInstanceOf[Double] - val maxY = inputExpressions(3).eval(input).asInstanceOf[Double] - var coordinates = new Array[Coordinate](5) - coordinates(0) = new Coordinate(minX, minY) - coordinates(1) = new Coordinate(minX, maxY) - coordinates(2) = new Coordinate(maxX, maxY) - coordinates(3) = new Coordinate(maxX, minY) - coordinates(4) = coordinates(0) - val geometryFactory = new GeometryFactory() - val polygon = geometryFactory.createPolygon(coordinates) - GeometrySerializer.serialize(polygon) - } - - override def dataType: DataType = GeometryUDT - - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, DoubleType) - - override def children: Seq[Expression] = inputExpressions + extends InferredQuarternaryExpression(Constructors.polygonFromEnvelope) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -337,81 +225,23 @@ trait UserDataGeneratator { } } - case class ST_GeomFromGeoHash(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback { - override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = { - val geoHash = Option(inputExpressions.head.eval(input)) - .map(_.asInstanceOf[UTF8String].toString) - val precision = Option(inputExpressions(1).eval(input)).map(_.asInstanceOf[Int]) - - try { - geoHash match { - case Some(value) => GeoHashDecoder.decode(value, precision).toGenericArrayData - case None => null - } - } - catch { - case e: Exception => null - } - } - - override def dataType: DataType = GeometryUDT - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType) - - override def children: Seq[Expression] = inputExpressions - + extends InferredBinaryExpression(Constructors.geomFromGeoHash) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } + override def allowRightNull: Boolean = true } case class ST_GeomFromGML(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback { - assert(inputExpressions.length == 1) - override def nullable: Boolean = true - - override def eval(inputRow: InternalRow): Any = { - (inputExpressions(0).eval(inputRow)) match { - case geomString: UTF8String => - new GMLReader().read(geomString.toString, new GeometryFactory()).toGenericArrayData - case _ => null - } - } - - override def dataType: DataType = GeometryUDT - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def children: Seq[Expression] = inputExpressions - + extends InferredUnaryExpression(Constructors.geomFromGML) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } } case class ST_GeomFromKML(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback { - assert(inputExpressions.length == 1) - override def nullable: Boolean = true - - override def eval(inputRow: InternalRow): Any = { - inputExpressions(0).eval(inputRow) match { - case geomString: UTF8String => - new KMLReader().read(geomString.toString).toGenericArrayData - case _ => null - } - } - - override def dataType: DataType = GeometryUDT - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def children: Seq[Expression] = inputExpressions - + extends InferredUnaryExpression(Constructors.geomFromKML) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index a75b044f78..2c56845ba6 100644 --- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -19,19 +19,15 @@ package org.apache.spark.sql.sedona_sql.expressions import org.apache.sedona.common.Functions -import org.apache.sedona.common.subDivide.GeometrySubDivider import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, ImplicitCastInputTypes} +import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.sedona_sql.expressions.collect.Collect import org.apache.spark.sql.sedona_sql.expressions.implicits._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.locationtech.jts.algorithm.MinimumBoundingCircle -import org.locationtech.jts.geom.{Geometry, _} -import org.locationtech.jts.simplify.TopologyPreservingSimplifier +import org.locationtech.jts.geom._ /** * Return the distance between two geometries. @@ -221,26 +217,7 @@ case class ST_Centroid(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Transform(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback { - - override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = { - val geometry = inputExpressions(0).toGeometry(input) - val sourceCRSString = inputExpressions(1).asString(input) - val targetCRSString = inputExpressions(2).asString(input) - val lenient = inputExpressions(3).eval(input).asInstanceOf[Boolean] - (geometry,sourceCRSString,targetCRSString,lenient) match { - case (null,_,_,_) => null - case _ => Functions.transform(geometry, sourceCRSString, targetCRSString, lenient).toGenericArrayData - } - } - - override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, StringType, StringType, BooleanType) - - override def dataType: DataType = GeometryUDT - - override def children: Seq[Expression] = inputExpressions + extends InferredQuarternaryExpression(Functions.transform) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -309,7 +286,7 @@ case class ST_IsSimple(inputExpressions: Seq[Expression]) * second arg is distance tolerance for the simplification(all vertices in the simplified geometry will be within this distance of the original geometry) */ case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression]) - extends InferredBinaryExpression(TopologyPreservingSimplifier.simplify) with FoldableExpression { + extends InferredBinaryExpression(Functions.simplifyPreserveTopology) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -379,15 +356,7 @@ case class ST_SetSRID(inputExpressions: Seq[Expression]) } case class ST_GeometryType(inputExpressions: Seq[Expression]) - extends UnaryGeometryExpression with FoldableExpression with CodegenFallback { - - override protected def nullSafeEval(geometry: Geometry): Any = { - UTF8String.fromString("ST_" + geometry.getGeometryType) - } - - override def dataType: DataType = StringType - - override def children: Seq[Expression] = inputExpressions + extends InferredUnaryExpression(Functions.geometryType) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -443,27 +412,13 @@ case class ST_Z(inputExpressions: Seq[Expression]) } case class ST_StartPoint(inputExpressions: Seq[Expression]) - extends UnaryGeometryExpression with FoldableExpression with CodegenFallback { - - override protected def nullSafeEval(geometry: Geometry): Any = { - geometry match { - case line: LineString => { - line.getPointN(0) - } - case _ => null - } - } - - override def dataType: DataType = GeometryUDT - - override def children: Seq[Expression] = inputExpressions + extends InferredUnaryExpression(Functions.startPoint) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } } - case class ST_Boundary(inputExpressions: Seq[Expression]) extends InferredUnaryExpression(Functions.boundary) with FoldableExpression { @@ -552,20 +507,8 @@ case class ST_LineInterpolatePoint(inputExpressions: Seq[Expression]) } } - case class ST_EndPoint(inputExpressions: Seq[Expression]) - extends UnaryGeometryExpression with FoldableExpression with CodegenFallback { - - override protected def nullSafeEval(geometry: Geometry): Any = { - geometry match { - case string: LineString => string.getEndPoint - case _ => null - } - } - - override def dataType: DataType = GeometryUDT - - override def children: Seq[Expression] = inputExpressions + extends InferredUnaryExpression(Functions.endPoint) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -598,32 +541,7 @@ case class ST_InteriorRingN(inputExpressions: Seq[Expression]) } case class ST_Dump(inputExpressions: Seq[Expression]) - extends UnaryGeometryExpression with FoldableExpression with CodegenFallback { - - override protected def nullSafeEval(geometry: Geometry): Any = { - geometry match { - case collection: GeometryCollection => { - val numberOfGeometries = collection.getNumGeometries - (0 until numberOfGeometries).map( - index => collection.getGeometryN(index) - ).toArray - } - case geom: Geometry => Array(geom) - } - } - - override protected def serializeResult(result: Any): Any = { - result match { - case array: Array[Geometry] => ArrayData.toArrayData( - array.map(_.toGenericArrayData) - ) - case _ => null - } - } - - override def dataType: DataType = ArrayType(GeometryUDT) - - override def children: Seq[Expression] = inputExpressions + extends InferredUnaryExpression(Functions.dump) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -631,25 +549,7 @@ case class ST_Dump(inputExpressions: Seq[Expression]) } case class ST_DumpPoints(inputExpressions: Seq[Expression]) - extends UnaryGeometryExpression with FoldableExpression with CodegenFallback { - - override protected def nullSafeEval(geometry: Geometry): Any = { - geometry.getPoints.map(geom => geom).toArray - } - - override protected def serializeResult(result: Any): Any = { - result match { - case array: Array[Geometry] => ArrayData.toArrayData( - array.map(geom => geom.toGenericArrayData) - ) - case _ => null - } - - } - - override def dataType: DataType = ArrayType(GeometryUDT) - - override def children: Seq[Expression] = inputExpressions + extends InferredUnaryExpression(Functions.dumpPoints) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -746,24 +646,7 @@ case class ST_FlipCoordinates(inputExpressions: Seq[Expression]) } case class ST_SubDivide(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback { - - override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = { - inputExpressions(0).toGeometry(input) match { - case geom: Geometry => ArrayData.toArrayData( - GeometrySubDivider.subDivide(geom, inputExpressions(1).toInt(input)).map(_.toGenericArrayData) - ) - case null => null - } - } - - override def dataType: DataType = ArrayType(GeometryUDT) - - override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, IntegerType) - - override def children: Seq[Expression] = inputExpressions + extends InferredBinaryExpression(Functions.subDivide) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -779,9 +662,9 @@ case class ST_SubDivideExplode(children: Seq[Expression]) val maxVerticesRaw = children(1) geometryRaw.toGeometry(input) match { case geom: Geometry => ArrayData.toArrayData( - GeometrySubDivider.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData) + Functions.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData) ) - GeometrySubDivider.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData).map(InternalRow(_)) + Functions.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData).map(InternalRow(_)) case _ => new Array[InternalRow](0) } } @@ -795,52 +678,14 @@ case class ST_SubDivideExplode(children: Seq[Expression]) } } - case class ST_MakePolygon(inputExpressions: Seq[Expression]) - extends Expression with FoldableExpression with CodegenFallback { - inputExpressions.betweenLength(1, 2) - - override def nullable: Boolean = true - private val geometryFactory = new GeometryFactory() - - override def eval(input: InternalRow): Any = { - val exteriorRing = inputExpressions.head - val possibleHolesRaw = inputExpressions.tail.headOption.map(_.eval(input).asInstanceOf[ArrayData]) - val numOfElements = possibleHolesRaw.map(_.numElements()).getOrElse(0) - - val holes = (0 until numOfElements).map(el => possibleHolesRaw match { - case Some(value) => Some(value.getBinary(el)) - case None => None - }).filter(_.nonEmpty) - .map(el => el.map(_.toGeometry)) - .flatMap{ - case maybeLine: Option[LineString] => - maybeLine.map(line => geometryFactory.createLinearRing(line.getCoordinates)) - case _ => None - } - - exteriorRing.toGeometry(input) match { - case geom: LineString => - try { - val poly = new Polygon(geometryFactory.createLinearRing(geom.getCoordinates), holes.toArray, geometryFactory) - poly.toGenericArrayData - } - catch { - case e: Exception => null - } - - case _ => null - } - - } - - override def dataType: DataType = GeometryUDT - - override def children: Seq[Expression] = inputExpressions + extends InferredBinaryExpression(Functions.makePolygon) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) } + + override def allowRightNull: Boolean = true } case class ST_GeoHash(inputExpressions: Seq[Expression]) @@ -870,15 +715,7 @@ case class ST_Difference(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_SymDifference(inputExpressions: Seq[Expression]) - extends BinaryGeometryExpression with FoldableExpression with CodegenFallback { - - override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = { - leftGeometry.symDifference(rightGeometry) - } - - override def dataType: DataType = GeometryUDT - - override def children: Seq[Expression] = inputExpressions + extends InferredBinaryExpression(Functions.symDifference) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -891,15 +728,7 @@ case class ST_SymDifference(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Union(inputExpressions: Seq[Expression]) - extends BinaryGeometryExpression with FoldableExpression with CodegenFallback { - - override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = { - leftGeometry.union(rightGeometry) - } - - override def dataType: DataType = GeometryUDT - - override def children: Seq[Expression] = inputExpressions + extends InferredBinaryExpression(Functions.union) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) @@ -907,7 +736,7 @@ case class ST_Union(inputExpressions: Seq[Expression]) } case class ST_Multi(inputExpressions: Seq[Expression]) - extends InferredUnaryExpression(Collect.createMultiGeometryFromOneElement) with FoldableExpression { + extends InferredUnaryExpression(Functions.createMultiGeometryFromOneElement) with FoldableExpression { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala index 054347718a..fc4fbb6eeb 100644 --- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala +++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala @@ -123,6 +123,8 @@ sealed class InferrableType[T: TypeTag] object InferrableType { implicit val geometryInstance: InferrableType[Geometry] = new InferrableType[Geometry] {} + implicit val geometryArrayInstance: InferrableType[Array[Geometry]] = + new InferrableType[Array[Geometry]] {} implicit val javaDoubleInstance: InferrableType[java.lang.Double] = new InferrableType[java.lang.Double] {} implicit val javaIntegerInstance: InferrableType[java.lang.Integer] = @@ -145,6 +147,8 @@ object InferredTypes { def buildExtractor[T: TypeTag](expr: Expression): InternalRow => T = { if (typeOf[T] =:= typeOf[Geometry]) { input: InternalRow => expr.toGeometry(input).asInstanceOf[T] + } else if (typeOf[T] =:= typeOf[Array[Geometry]]) { + input: InternalRow => expr.toGeometryArray(input).asInstanceOf[T] } else if (typeOf[T] =:= typeOf[String]) { input: InternalRow => expr.asString(input).asInstanceOf[T] } else { @@ -172,6 +176,13 @@ object InferredTypes { } else { null } + } else if (typeOf[T] =:= typeOf[Array[Geometry]]) { + output: T => + if (output != null) { + ArrayData.toArrayData(output.asInstanceOf[Array[Geometry]].map(_.toGenericArrayData)) + } else { + null + } } else { output: T => output } @@ -180,6 +191,8 @@ object InferredTypes { def inferSparkType[T: TypeTag]: DataType = { if (typeOf[T] =:= typeOf[Geometry]) { GeometryUDT + } else if (typeOf[T] =:= typeOf[Array[Geometry]]) { + DataTypes.createArrayType(GeometryUDT) } else if (typeOf[T] =:= typeOf[java.lang.Double]) { DoubleType } else if (typeOf[T] =:= typeOf[java.lang.Integer]) { @@ -254,6 +267,8 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType, override def nullable: Boolean = true + def allowRightNull: Boolean = false + override def dataType = inferSparkType[R] lazy val extractLeft = buildExtractor[A1](inputExpressions(0)) @@ -266,7 +281,7 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType, override def evalWithoutSerialization(input: InternalRow): Any = { val left = extractLeft(input) val right = extractRight(input) - if (left != null && right != null) { + if (left != null && (right != null || allowRightNull)) { f(left, right) } else { null diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala index 39f68d4c38..00923fd225 100644 --- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala +++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala @@ -18,15 +18,14 @@ */ package org.apache.spark.sql.sedona_sql.expressions +import org.apache.sedona.common.Predicates import org.apache.sedona.sql.utils.GeometrySerializer import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{BooleanType, DataType} -import org.locationtech.jts.geom.Geometry -import org.apache.spark.sql.types.AbstractDataType +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, NullIntolerant} import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType} +import org.locationtech.jts.geom.Geometry abstract class ST_Predicate extends Expression with FoldableExpression @@ -73,7 +72,7 @@ case class ST_Contains(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.contains(rightGeometry) + Predicates.contains(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -90,7 +89,7 @@ case class ST_Intersects(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.intersects(rightGeometry) + Predicates.intersects(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -107,7 +106,7 @@ case class ST_Within(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.within(rightGeometry) + Predicates.within(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -124,7 +123,7 @@ case class ST_Covers(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.covers(rightGeometry) + Predicates.covers(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -141,7 +140,7 @@ case class ST_CoveredBy(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.coveredBy(rightGeometry) + Predicates.coveredBy(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -158,7 +157,7 @@ case class ST_Crosses(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.crosses(rightGeometry) + Predicates.crosses(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -176,7 +175,7 @@ case class ST_Overlaps(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.overlaps(rightGeometry) + Predicates.overlaps(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -193,7 +192,7 @@ case class ST_Touches(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.touches(rightGeometry) + Predicates.touches(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -211,8 +210,7 @@ case class ST_Equals(inputExpressions: Seq[Expression]) override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { // Returns GeometryCollection object - val symDifference = leftGeometry.symDifference(rightGeometry) - symDifference.isEmpty + Predicates.equals(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -229,7 +227,7 @@ case class ST_Disjoint(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.disjoint(rightGeometry) + Predicates.disjoint(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -246,7 +244,7 @@ case class ST_OrderingEquals(inputExpressions: Seq[Expression]) extends ST_Predicate with CodegenFallback { override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = { - leftGeometry.equalsExact(rightGeometry) + Predicates.orderingEquals(leftGeometry, rightGeometry) } protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala index 2bacc76653..5b5f02026b 100644 --- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala +++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala @@ -23,6 +23,7 @@ import org.apache.sedona.sql.utils.GeometrySerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.types.{ByteType, DataTypes} import org.apache.spark.unsafe.types.UTF8String import org.locationtech.jts.geom.{Geometry, GeometryFactory, Point} @@ -40,6 +41,24 @@ object implicits { } } + def toGeometryArray(input: InternalRow): Array[Geometry] = { + inputExpression match { + case aware: SerdeAware => + aware.evalWithoutSerialization(input).asInstanceOf[Array[Geometry]] + case _ => + inputExpression.eval(input).asInstanceOf[ArrayData] match { + case arrayData: ArrayData => + val length = arrayData.numElements() + val geometries = new Array[Geometry](length) + for (i <- 0 until length) { + geometries(i) = arrayData.getBinary(i).toGeometry + } + geometries + case _ => null + } + } + } + def toInt(input: InternalRow): Int = { inputExpression.eval(input).asInstanceOf[Int] } diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala index 005d112f42..aa9eada87b 100644 --- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala +++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala @@ -25,6 +25,10 @@ object st_constructors extends DataFrameAPI { def ST_GeomFromGeoHash(geohash: Column, precision: Column): Column = wrapExpression[ST_GeomFromGeoHash](geohash, precision) def ST_GeomFromGeoHash(geohash: String, precision: Int): Column = wrapExpression[ST_GeomFromGeoHash](geohash, precision) + def ST_GeomFromGeoHash(geohash: Column): Column = wrapExpression[ST_GeomFromGeoHash](geohash, null) + + def ST_GeomFromGeoHash(geohash: String): Column = wrapExpression[ST_GeomFromGeoHash](geohash, null) + def ST_GeomFromGeoJSON(geojsonString: Column): Column = wrapExpression[ST_GeomFromGeoJSON](geojsonString) def ST_GeomFromGeoJSON(geojsonString: String): Column = wrapExpression[ST_GeomFromGeoJSON](geojsonString) diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala index 2ef94936c9..c8c7ac7310 100644 --- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala +++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sedona_sql.expressions import org.apache.spark.sql.Column import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect, ST_CollectionExtract} +import org.locationtech.jts.geom.Geometry import org.locationtech.jts.operation.buffer.BufferParameters object st_functions extends DataFrameAPI { @@ -159,8 +160,8 @@ object st_functions extends DataFrameAPI { def ST_LineSubstring(lineString: Column, startFraction: Column, endFraction: Column): Column = wrapExpression[ST_LineSubstring](lineString, startFraction, endFraction) def ST_LineSubstring(lineString: String, startFraction: Double, endFraction: Double): Column = wrapExpression[ST_LineSubstring](lineString, startFraction, endFraction) - def ST_MakePolygon(lineString: Column): Column = wrapExpression[ST_MakePolygon](lineString) - def ST_MakePolygon(lineString: String): Column = wrapExpression[ST_MakePolygon](lineString) + def ST_MakePolygon(lineString: Column): Column = wrapExpression[ST_MakePolygon](lineString, null) + def ST_MakePolygon(lineString: String): Column = wrapExpression[ST_MakePolygon](lineString, null) def ST_MakePolygon(lineString: Column, holes: Column): Column = wrapExpression[ST_MakePolygon](lineString, holes) def ST_MakePolygon(lineString: String, holes: String): Column = wrapExpression[ST_MakePolygon](lineString, holes)