diff --git a/common/src/main/java/org/apache/sedona/common/raster/Functions.java b/common/src/main/java/org/apache/sedona/common/raster/Functions.java index 489bd1d74f..a7761bb458 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/Functions.java +++ b/common/src/main/java/org/apache/sedona/common/raster/Functions.java @@ -18,30 +18,48 @@ import org.geotools.coverage.grid.GridGeometry2D; import org.geotools.geometry.DirectPosition2D; import org.geotools.geometry.Envelope2D; +import org.geotools.referencing.CRS; +import org.geotools.referencing.crs.DefaultEngineeringCRS; import org.locationtech.jts.geom.*; +import org.opengis.referencing.FactoryException; +import org.opengis.referencing.crs.CoordinateReferenceSystem; import org.opengis.referencing.operation.TransformException; import java.awt.image.Raster; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.function.DoublePredicate; import java.util.stream.Collectors; import java.util.stream.DoubleStream; public class Functions { - public static Geometry envelope(GridCoverage2D raster) { + /** Returns the 2D envelope of the raster as a JTS geometry built by a GeometryFactory that carries the SRID from srid(raster); declares FactoryException because srid(raster) can throw it. */ public static Geometry envelope(GridCoverage2D raster) throws FactoryException { Envelope2D envelope2D = raster.getEnvelope2D(); Envelope envelope = new Envelope(envelope2D.getMinX(), envelope2D.getMaxX(), envelope2D.getMinY(), envelope2D.getMaxY()); - return new GeometryFactory().toGeometry(envelope); + int srid = srid(raster); + return new GeometryFactory(new PrecisionModel(), srid).toGeometry(envelope); } public static int numBands(GridCoverage2D raster) { return raster.getNumSampleDimensions(); } + /** Returns the EPSG code of the CRS of the raster, or 0 when the CRS is a wildcard DefaultEngineeringCRS or when CRS.lookupEpsgCode returns null. */ public static int srid(GridCoverage2D raster) throws FactoryException { + CoordinateReferenceSystem crs = raster.getCoordinateReferenceSystem(); + if (crs instanceof DefaultEngineeringCRS) { + // GeoTools defaults to internal non-standard epsg codes, like 404000, if crs is missing. + // We need to check for this case and return 0 instead. 
+ if (((DefaultEngineeringCRS) crs).isWildcard()) { + return 0; + } + } + /* Optional.ofNullable guards against a null lookup result, which is reported as SRID 0. */ return Optional.ofNullable(CRS.lookupEpsgCode(crs, true)).orElse(0); + } + public static Double value(GridCoverage2D rasterGeom, Geometry geometry, int band) throws TransformException { return values(rasterGeom, Collections.singletonList(geometry), band).get(0); } diff --git a/common/src/test/java/org/apache/sedona/common/raster/ConstructorsTest.java b/common/src/test/java/org/apache/sedona/common/raster/ConstructorsTest.java index 69aa085b25..f2ec505b77 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/ConstructorsTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/ConstructorsTest.java @@ -16,6 +16,7 @@ import org.geotools.coverage.grid.GridCoverage2D; import org.junit.Test; import org.locationtech.jts.geom.Geometry; +import org.opengis.referencing.FactoryException; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -25,7 +26,7 @@ public class ConstructorsTest extends RasterTestBase { @Test - public void fromArcInfoAsciiGrid() throws IOException { + public void fromArcInfoAsciiGrid() throws IOException, FactoryException { GridCoverage2D gridCoverage2D = Constructors.fromArcInfoAsciiGrid(arc.getBytes(StandardCharsets.UTF_8)); Geometry envelope = Functions.envelope(gridCoverage2D); @@ -39,7 +40,7 @@ public void fromArcInfoAsciiGrid() throws IOException { } @Test - public void fromGeoTiff() throws IOException { + public void fromGeoTiff() throws IOException, FactoryException { GridCoverage2D gridCoverage2D = Constructors.fromGeoTiff(geoTiff); Geometry envelope = Functions.envelope(gridCoverage2D); diff --git a/common/src/test/java/org/apache/sedona/common/raster/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/raster/FunctionsTest.java index 387d144885..f9b3a40bcd 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/FunctionsTest.java +++ 
b/common/src/test/java/org/apache/sedona/common/raster/FunctionsTest.java @@ -18,6 +18,7 @@ import org.locationtech.jts.geom.Geometry; import org.locationtech.jts.geom.GeometryFactory; import org.locationtech.jts.geom.Point; +import org.opengis.referencing.FactoryException; import org.opengis.referencing.operation.TransformException; import java.util.Arrays; @@ -29,11 +30,13 @@ public class FunctionsTest extends RasterTestBase { @Test - public void envelope() { + public void envelope() throws FactoryException { Geometry envelope = Functions.envelope(oneBandRaster); assertEquals(3600.0d, envelope.getArea(), 0.1d); assertEquals(378922.0d + 30.0d, envelope.getCentroid().getX(), 0.1d); assertEquals(4072345.0d + 30.0d, envelope.getCentroid().getY(), 0.1d); + + assertEquals(4326, Functions.envelope(multiBandRaster).getSRID()); } @Test @@ -42,6 +45,12 @@ public void testNumBands() { assertEquals(4, Functions.numBands(multiBandRaster)); } + @Test + public void testSrid() throws FactoryException { + assertEquals(0, Functions.srid(oneBandRaster)); + assertEquals(4326, Functions.srid(multiBandRaster)); + } + @Test public void value() throws TransformException { assertNull("Points outside of the envelope should return null.", Functions.value(oneBandRaster, point(1, 1), 1)); diff --git a/docs/api/sql/Raster-operators.md b/docs/api/sql/Raster-operators.md index 1011c303b7..0388d35a0a 100644 --- a/docs/api/sql/Raster-operators.md +++ b/docs/api/sql/Raster-operators.md @@ -35,6 +35,24 @@ Output: 4 ``` +### RS_SRID + +Introduction: Returns the spatial reference system identifier (SRID) of the raster geometry. Returns 0 if the raster has no coordinate reference system defined or no EPSG code can be determined for it. + +Format: `RS_SRID (raster: Raster)` + +Since: `v1.4.1` + +Spark SQL example: +```sql +SELECT RS_SRID(raster) FROM raster_table +``` + +Output: +``` +3857 +``` + ### RS_Value Introduction: Returns the value at the given point in the raster. 
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 3af49b6720..109d095da4 100644 --- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -173,6 +173,7 @@ object Catalog { function[RS_FromGeoTiff](), function[RS_Envelope](), function[RS_NumBands](), + function[RS_SRID](), function[RS_Value](1), function[RS_Values](1) ) diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Functions.scala index af2edcdf74..62c43a3f58 100644 --- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Functions.scala +++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Functions.scala @@ -855,6 +855,29 @@ case class RS_NumBands(inputExpressions: Seq[Expression]) extends Expression wit override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT) } +/** Spark SQL expression backing RS_SRID: converts its first input to a raster and returns Functions.srid(raster), or null when that raster is null. */ case class RS_SRID(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes { + /* Nullable because eval returns null for a null raster. */ override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + val raster = inputExpressions(0).toRaster(input) + if (raster == null) { + null + } else { + Functions.srid(raster) + } + } + + override def dataType: DataType = IntegerType + + override def children: Seq[Expression] = inputExpressions + + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } + + override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT) +} + case class RS_Value(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes { override def nullable: Boolean = true diff --git a/sql/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala 
b/sql/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala index 1bde6913f2..09ec4d4483 100644 --- a/sql/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala +++ b/sql/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala @@ -290,6 +290,17 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen assert(result == 1) } + /* RS_SRID(null) is expected to yield a null result rather than fail. */ it("Passed RS_SRID should handle null values") { + val result = sparkSession.sql("select RS_SRID(null)").first().get(0) + assert(result == null) + } + + /* Reads test1.tiff as raw bytes, decodes it with RS_FromGeoTiff and expects SRID 3857. */ it("Passed RS_SRID with raster") { + val df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff") + val result = df.selectExpr("RS_SRID(RS_FromGeoTiff(content))").first().getInt(0) + assert(result == 3857) + } + it("Passed RS_Value should handle null values") { val result = sparkSession.sql("select RS_Value(null, null)").first().get(0) assert(result == null)