Skip to content

Commit

Permalink
[SEDONA-271] Add raster function RS_SRID (apache#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
umartin authored and jiayuasu committed Apr 19, 2023
1 parent 7b874c1 commit 9be6756
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,48 @@
import org.geotools.coverage.grid.GridGeometry2D;
import org.geotools.geometry.DirectPosition2D;
import org.geotools.geometry.Envelope2D;
import org.geotools.referencing.CRS;
import org.geotools.referencing.crs.DefaultEngineeringCRS;
import org.locationtech.jts.geom.*;
import org.opengis.referencing.FactoryException;
import org.opengis.referencing.crs.CoordinateReferenceSystem;
import org.opengis.referencing.operation.TransformException;

import java.awt.image.Raster;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.DoublePredicate;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;

public class Functions {

public static Geometry envelope(GridCoverage2D raster) {
public static Geometry envelope(GridCoverage2D raster) throws FactoryException {
Envelope2D envelope2D = raster.getEnvelope2D();

Envelope envelope = new Envelope(envelope2D.getMinX(), envelope2D.getMaxX(), envelope2D.getMinY(), envelope2D.getMaxY());
return new GeometryFactory().toGeometry(envelope);
int srid = srid(raster);
return new GeometryFactory(new PrecisionModel(), srid).toGeometry(envelope);
}

public static int numBands(GridCoverage2D raster) {
return raster.getNumSampleDimensions();
}

public static int srid(GridCoverage2D raster) throws FactoryException {
CoordinateReferenceSystem crs = raster.getCoordinateReferenceSystem();
if (crs instanceof DefaultEngineeringCRS) {
// GeoTools defaults to internal non-standard epsg codes, like 404000, if crs is missing.
// We need to check for this case and return 0 instead.
if (((DefaultEngineeringCRS) crs).isWildcard()) {
return 0;
}
}
return Optional.ofNullable(CRS.lookupEpsgCode(crs, true)).orElse(0);
}

public static Double value(GridCoverage2D rasterGeom, Geometry geometry, int band) throws TransformException {
return values(rasterGeom, Collections.singletonList(geometry), band).get(0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.geotools.coverage.grid.GridCoverage2D;
import org.junit.Test;
import org.locationtech.jts.geom.Geometry;
import org.opengis.referencing.FactoryException;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand All @@ -25,7 +26,7 @@
public class ConstructorsTest extends RasterTestBase {

@Test
public void fromArcInfoAsciiGrid() throws IOException {
public void fromArcInfoAsciiGrid() throws IOException, FactoryException {
GridCoverage2D gridCoverage2D = Constructors.fromArcInfoAsciiGrid(arc.getBytes(StandardCharsets.UTF_8));

Geometry envelope = Functions.envelope(gridCoverage2D);
Expand All @@ -39,7 +40,7 @@ public void fromArcInfoAsciiGrid() throws IOException {
}

@Test
public void fromGeoTiff() throws IOException {
public void fromGeoTiff() throws IOException, FactoryException {
GridCoverage2D gridCoverage2D = Constructors.fromGeoTiff(geoTiff);

Geometry envelope = Functions.envelope(gridCoverage2D);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.GeometryFactory;
import org.locationtech.jts.geom.Point;
import org.opengis.referencing.FactoryException;
import org.opengis.referencing.operation.TransformException;

import java.util.Arrays;
Expand All @@ -29,11 +30,13 @@
public class FunctionsTest extends RasterTestBase {

@Test
public void envelope() {
public void envelope() throws FactoryException {
Geometry envelope = Functions.envelope(oneBandRaster);
assertEquals(3600.0d, envelope.getArea(), 0.1d);
assertEquals(378922.0d + 30.0d, envelope.getCentroid().getX(), 0.1d);
assertEquals(4072345.0d + 30.0d, envelope.getCentroid().getY(), 0.1d);

assertEquals(4326, Functions.envelope(multiBandRaster).getSRID());
}

@Test
Expand All @@ -42,6 +45,12 @@ public void testNumBands() {
assertEquals(4, Functions.numBands(multiBandRaster));
}

@Test
public void testSrid() throws FactoryException {
assertEquals(0, Functions.srid(oneBandRaster));
assertEquals(4326, Functions.srid(multiBandRaster));
}

@Test
public void value() throws TransformException {
assertNull("Points outside of the envelope should return null.", Functions.value(oneBandRaster, point(1, 1), 1));
Expand Down
18 changes: 18 additions & 0 deletions docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,24 @@ Output:
4
```

### RS_SRID

Introduction: Returns the spatial reference system identifier (SRID) of the raster geometry.

Format: `RS_SRID (raster: Raster)`

Since: `v1.4.1`

Spark SQL example:
```sql
SELECT RS_SRID(raster) FROM raster_table
```

Output:
```
3857
```

### RS_Value

Introduction: Returns the value at the given point in the raster.
Expand Down
1 change: 1 addition & 0 deletions sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ object Catalog {
function[RS_FromGeoTiff](),
function[RS_Envelope](),
function[RS_NumBands](),
function[RS_SRID](),
function[RS_Value](1),
function[RS_Values](1)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,29 @@ case class RS_NumBands(inputExpressions: Seq[Expression]) extends Expression wit
override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT)
}

case class RS_SRID(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

override def eval(input: InternalRow): Any = {
val raster = inputExpressions(0).toRaster(input)
if (raster == null) {
null
} else {
Functions.srid(raster)
}
}

override def dataType: DataType = IntegerType

override def children: Seq[Expression] = inputExpressions

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}

override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT)
}

case class RS_Value(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes {

override def nullable: Boolean = true
Expand Down
11 changes: 11 additions & 0 deletions sql/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,17 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assert(result == 1)
}

it("Passed RS_SRID should handle null values") {
val result = sparkSession.sql("select RS_SRID(null)").first().get(0)
assert(result == null)
}

it("Passed RS_SRID with raster") {
val df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")
val result = df.selectExpr("RS_SRID(RS_FromGeoTiff(content))").first().getInt(0)
assert(result == 3857)
}

it("Passed RS_Value should handle null values") {
val result = sparkSession.sql("select RS_Value(null, null)").first().get(0)
assert(result == null)
Expand Down

0 comments on commit 9be6756

Please sign in to comment.