Skip to content

Commit

Permalink
[SEDONA-475] Add RS_NormalizeAll (#1221)
Browse files Browse the repository at this point in the history
* Implement RS_NormalizeAll

* add IllegalArgumentException

* Fix override issue

* Handle same band values

* Refactor NormalizeAll; Add tests

* Update Documentation

* Add flag argument: normalizeAcrossBands

* fix lint

* Optimize normalizeAll: remove redundant min/max calculations
  • Loading branch information
prantogg authored Feb 3, 2024
1 parent 4562a7b commit adebdac
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 0 deletions.
122 changes: 122 additions & 0 deletions common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,128 @@ public static double[] normalize(double[] band) {
return result;
}

/**
 * Normalizes all bands of a raster using the default output range [0, 255].
 *
 * <p>Delegates to the seven-argument overload with no explicit NoData value, no explicit
 * input min/max, and normalization across all bands enabled.
 *
 * @param rasterGeom Raster to be normalized
 * @return the normalized raster
 */
public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom) {
    final double defaultMinLim = 0d;
    final double defaultMaxLim = 255d;
    final boolean acrossBands = true;
    return normalizeAll(rasterGeom, defaultMinLim, defaultMaxLim, null, null, null, acrossBands);
}

/**
 * Normalizes all bands of a raster into the range [minLim, maxLim].
 *
 * <p>Delegates to the seven-argument overload with no explicit NoData value, no explicit
 * input min/max, and normalization across all bands enabled.
 *
 * @param rasterGeom Raster to be normalized
 * @param minLim Lower limit of normalization range
 * @param maxLim Upper limit of normalization range
 * @return the normalized raster
 */
public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim) {
    final Double unspecified = null;
    final boolean acrossBands = true;
    return normalizeAll(rasterGeom, minLim, maxLim, unspecified, unspecified, unspecified, acrossBands);
}

/**
 * Normalizes all bands of a raster into the range [minLim, maxLim] using an explicit
 * NoData value.
 *
 * <p>Delegates to the seven-argument overload with no explicit input min/max and
 * normalization across all bands enabled.
 *
 * @param rasterGeom Raster to be normalized
 * @param minLim Lower limit of normalization range
 * @param maxLim Upper limit of normalization range
 * @param noDataValue NoData value passed on to the seven-argument overload
 * @return the normalized raster
 */
public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim, double noDataValue) {
    final Double boxedNoData = noDataValue;
    final Double unspecified = null;
    return normalizeAll(rasterGeom, minLim, maxLim, boxedNoData, unspecified, unspecified, true);
}

/**
 * Normalizes all bands of a raster into the range [minLim, maxLim], letting the caller
 * choose between cross-band and per-band normalization.
 *
 * <p>Delegates to the seven-argument overload with no explicit input min/max.
 *
 * @param rasterGeom Raster to be normalized
 * @param minLim Lower limit of normalization range
 * @param maxLim Upper limit of normalization range
 * @param noDataValue NoData value passed on to the seven-argument overload
 * @param normalizeAcrossBands flag to determine the normalization method
 * @return the normalized raster
 */
public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim, Double noDataValue, boolean normalizeAcrossBands) {
    final Double unspecified = null;
    return normalizeAll(rasterGeom, minLim, maxLim, noDataValue, unspecified, unspecified, normalizeAcrossBands);
}

/**
 * Normalizes all bands of a raster into the range [minLim, maxLim] using caller-supplied
 * input minimum and maximum values.
 *
 * <p>Delegates to the seven-argument overload with normalization across all bands enabled.
 *
 * @param rasterGeom Raster to be normalized
 * @param minLim Lower limit of normalization range
 * @param maxLim Upper limit of normalization range
 * @param noDataValue NoData value passed on to the seven-argument overload
 * @param minValue Minimum value in raster
 * @param maxValue Maximum value in raster
 * @return the normalized raster
 */
public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim, Double noDataValue, Double minValue, Double maxValue) {
    final boolean acrossBands = true;
    return normalizeAll(rasterGeom, minLim, maxLim, noDataValue, minValue, maxValue, acrossBands);
}

/**
 * Normalizes the values of every band of a raster into the range [minLim, maxLim].
 *
 * <p>Each pixel that differs from its band's current NoData value is rescaled linearly as
 * {@code minLim + (value - min) * (maxLim - minLim) / (max - min)} and then cast back to the
 * raster's sample data type. Pixels equal to the band's current NoData value are replaced by
 * {@code noDataValue}, and {@code noDataValue} is set as the NoData value of every output
 * band. A band whose min and max are equal is filled entirely with {@code minLim}.
 *
 * @param rasterGeom Raster to be normalized
 * @param minLim Lower limit of normalization range
 * @param maxLim Upper limit of normalization range
 * @param noDataValue Value written to NoData pixels and set as the NoData value of each
 *     output band; when {@code null} it defaults to {@code maxLim}
 * @param minValue Minimum value in raster; when this or {@code maxValue} is {@code null} the
 *     raster is scanned to determine the missing bound
 * @param maxValue Maximum value in raster; when this or {@code minValue} is {@code null} the
 *     raster is scanned to determine the missing bound
 * @param normalizeAcrossBands if {@code true}, a single min/max pair is used for all bands;
 *     if {@code false}, each band is normalized with its own min/max
 * @return a raster with all values in all bands normalized between minLim and maxLim
 * @throws IllegalArgumentException if {@code minLim} is greater than {@code maxLim}
 */
public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim, Double noDataValue, Double minValue, Double maxValue, boolean normalizeAcrossBands) {
    if (minLim > maxLim) {
        throw new IllegalArgumentException("minLim cannot be greater than maxLim");
    }

    // Resolve the output NoData value before any branching. It must be defined even when the
    // min/max scan below is skipped (both minValue and maxValue supplied); otherwise a null
    // noDataValue would be unboxed while writing NoData pixels and cause a NullPointerException.
    final double outputNoData = noDataValue != null ? noDataValue : maxLim;

    int numBands = rasterGeom.getNumSampleDimensions();
    RenderedImage renderedImage = rasterGeom.getRenderedImage();
    int rasterDataType = renderedImage.getSampleModel().getDataType();

    // Caller-supplied bounds seed the global min/max; otherwise start from sentinels.
    double globalMin = minValue != null ? minValue : Double.MAX_VALUE;
    double globalMax = maxValue != null ? maxValue : -Double.MAX_VALUE;

    // Initialize arrays to store band-wise min and max values
    double[] minValues = new double[numBands];
    double[] maxValues = new double[numBands];
    Arrays.fill(minValues, Double.MAX_VALUE);
    Arrays.fill(maxValues, -Double.MAX_VALUE);

    // Scan the raster only when at least one bound was not provided by the caller.
    if (minValue == null || maxValue == null) {
        for (int bandIndex = 0; bandIndex < numBands; bandIndex++) {
            double[] bandValues = bandAsArray(rasterGeom, bandIndex + 1);
            double bandNoDataValue = RasterUtils.getNoDataValue(rasterGeom.getSampleDimension(bandIndex));

            for (double val : bandValues) {
                // NoData pixels do not take part in the min/max computation.
                if (val != bandNoDataValue) {
                    if (normalizeAcrossBands) {
                        globalMin = Math.min(globalMin, val);
                        globalMax = Math.max(globalMax, val);
                    } else {
                        minValues[bandIndex] = Math.min(minValues[bandIndex], val);
                        maxValues[bandIndex] = Math.max(maxValues[bandIndex], val);
                    }
                }
            }
        }
    }

    // Normalize each band
    GridCoverage2D result = rasterGeom;
    for (int bandIndex = 0; bandIndex < numBands; bandIndex++) {
        double[] bandValues = bandAsArray(result, bandIndex + 1);
        double bandNoDataValue = RasterUtils.getNoDataValue(result.getSampleDimension(bandIndex));
        double currentMin = normalizeAcrossBands ? globalMin : (minValue != null ? minValue : minValues[bandIndex]);
        double currentMax = normalizeAcrossBands ? globalMax : (maxValue != null ? maxValue : maxValues[bandIndex]);

        if (Double.compare(currentMax, currentMin) == 0) {
            // Zero-width input range: the linear formula would divide by zero.
            Arrays.fill(bandValues, minLim);
        } else {
            double scale = (maxLim - minLim);
            double inputRange = (currentMax - currentMin);
            for (int i = 0; i < bandValues.length; i++) {
                if (bandValues[i] != bandNoDataValue) {
                    double normalizedValue = minLim + ((bandValues[i] - currentMin) * scale) / inputRange;
                    bandValues[i] = castRasterDataType(normalizedValue, rasterDataType);
                } else {
                    bandValues[i] = outputNoData;
                }
            }
        }

        // Update the raster with the normalized band and noDataValue
        result = addBandFromArray(result, bandValues, bandIndex + 1);
        result = RasterBandEditors.setBandNoDataValue(result, bandIndex + 1, outputNoData);
    }

    return result;
}

/**
 * Truncates a value to what a sample of the given {@link DataBuffer} type can hold and
 * returns it as a double.
 *
 * @param value the value to convert
 * @param dataType one of the {@code DataBuffer.TYPE_*} constants
 * @return the value after narrowing to the sample type; unchanged for
 *     {@code TYPE_DOUBLE} and for unrecognized types
 */
private static double castRasterDataType(double value, int dataType) {
    switch (dataType) {
        case DataBuffer.TYPE_BYTE:
            // TYPE_BYTE samples are unsigned (0..255). A plain (byte) cast is signed and
            // would report 255 as -1; masking keeps the same low 8 bits as an unsigned value.
            return ((int) value) & 0xFF;
        case DataBuffer.TYPE_SHORT:
            return (short) value;
        case DataBuffer.TYPE_INT:
            return (int) value;
        case DataBuffer.TYPE_USHORT:
            // char is Java's unsigned 16-bit type, matching unsigned short samples.
            return (char) value;
        case DataBuffer.TYPE_FLOAT:
            return (float) value;
        case DataBuffer.TYPE_DOUBLE:
        default:
            return value;
    }
}

/**
* @param band1 band values
* @param band2 band values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.opengis.referencing.FactoryException;

import java.awt.image.DataBuffer;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Random;

import static org.junit.Assert.*;
Expand Down Expand Up @@ -321,6 +323,122 @@ public void testNormalize() {
assertArrayEquals(expected, actual, 0.1d);
}

// Exercises MapAlgebra.normalizeAll on five two-band rasters (16 pixels per band), covering
// per-band vs. cross-band normalization, a custom output range, NoData handling and
// caller-supplied min/max values.
@Test
public void testNormalizeAll() throws FactoryException {
GridCoverage2D raster1 = RasterConstructors.makeEmptyRaster(2, 4, 4, 0, 0, 1);
GridCoverage2D raster2 = RasterConstructors.makeEmptyRaster(2, 4, 4, 0, 0, 1);
GridCoverage2D raster3 = RasterConstructors.makeEmptyRaster(2, "I", 4, 4, 0, 0, 1);
GridCoverage2D raster4 = RasterConstructors.makeEmptyRaster(2, 4, 4, 0, 0, 1);
GridCoverage2D raster5 = RasterConstructors.makeEmptyRaster(2, 4, 4, 0, 0, 1);

// Populate both bands of every fixture raster:
//   raster1: band b holds i * b           (0..15 and 0..30)
//   raster2: band b holds the constant b-1 (all 0s, then all 1s)
//   raster3: both bands hold 1..16
//   raster4: both bands hold 1..15 followed by 0, with 0 marked as NoData
//   raster5: band b holds i + (b-1)*15    (0..15 and 15..30)
for (int band = 1; band <= 2; band++) {
double[] bandValues1 = new double[4 * 4];
double[] bandValues2 = new double[4 * 4];
double[] bandValues3 = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16};
double[] bandValues4 = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0};
double[] bandValues5 = new double[4 * 4];
for (int i = 0; i < bandValues1.length; i++) {
bandValues1[i] = (i) * band;
bandValues2[i] = (1) * (band-1);
bandValues5[i] = i + ((band-1)*15);
}
raster1 = MapAlgebra.addBandFromArray(raster1, bandValues1, band);
raster2 = MapAlgebra.addBandFromArray(raster2, bandValues2, band);
raster3 = MapAlgebra.addBandFromArray(raster3, bandValues3, band);
raster4 = MapAlgebra.addBandFromArray(raster4, bandValues4, band);
raster4 = RasterBandEditors.setBandNoDataValue(raster4, band, 0.0);
raster5 = MapAlgebra.addBandFromArray(raster5, bandValues5, band);
}
// raster3 uses a different NoData value per band: 16 in band 1 and 1 in band 2.
raster3 = RasterBandEditors.setBandNoDataValue(raster3, 1, 16.0);
raster3 = RasterBandEditors.setBandNoDataValue(raster3, 2, 1.0);

// Note: the normalizedRasterN numbering does not match the rasterN numbering.
// 1: raster1, per-band, range [0, 255]
// 2: raster1, per-band, range [256, 511]
// 3: raster2, all defaults (cross-band, range [0, 255])
// 4: raster3, cross-band, output NoData 95
// 5: raster4, cross-band, default output NoData
// 6: raster5, cross-band with explicit input min 0 and max 30
// 7: raster5, per-band
GridCoverage2D normalizedRaster1 = MapAlgebra.normalizeAll(raster1, 0, 255, -9999.0, false);
GridCoverage2D normalizedRaster2 = MapAlgebra.normalizeAll(raster1, 256d, 511d, -9999.0, false);
GridCoverage2D normalizedRaster3 = MapAlgebra.normalizeAll(raster2);
GridCoverage2D normalizedRaster4 = MapAlgebra.normalizeAll(raster3, 0, 255, 95.0);
GridCoverage2D normalizedRaster5 = MapAlgebra.normalizeAll(raster4, 0, 255);
GridCoverage2D normalizedRaster6 = MapAlgebra.normalizeAll(raster5, 0.0, 255.0, -9999.0, 0.0, 30.0);
GridCoverage2D normalizedRaster7 = MapAlgebra.normalizeAll(raster5, 0, 255, -9999.0, false);

// expected1 / expected2: 16 evenly spaced values (step 17) over [0, 255] and [256, 511].
// expected3: band 1 of raster2 (all zeros) under cross-band normalization.
// expected4 / expected5: bands 1 and 2 of raster3; the NoData pixel (last / first) becomes 95.
// expected6: raster4; values 1..15 spread over [0, 255], and the trailing NoData pixel becomes
// 255 because the output NoData value defaults to maxLim.
double[] expected1 = {0.0, 17.0, 34.0, 51.0, 68.0, 85.0, 102.0, 119.0, 136.0, 153.0, 170.0, 187.0, 204.0, 221.0, 238.0, 255.0};
double[] expected2 = {256.0, 273.0, 290.0, 307.0, 324.0, 341.0, 358.0, 375.0, 392.0, 409.0, 426.0, 443.0, 460.0, 477.0, 494.0, 511.0};
double[] expected3 = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
double[] expected4 = {0.0, 17.0, 34.0, 51.0, 68.0, 85.0, 102.0, 119.0, 136.0, 153.0, 170.0, 187.0, 204.0, 221.0, 238.0, 95.0};
double[] expected5 = {95.0, 17.0, 34.0, 51.0, 68.0, 85.0, 102.0, 119.0, 136.0, 153.0, 170.0, 187.0, 204.0, 221.0, 238.0, 255.0};
double[] expected6 = {0.0, 18.214285714285715, 36.42857142857143, 54.642857142857146, 72.85714285714286, 91.07142857142857, 109.28571428571429, 127.5, 145.71428571428572, 163.92857142857142, 182.14285714285714, 200.35714285714286, 218.57142857142858, 236.78571428571428, 255.0, 255.0};

// Validate the scenarios whose expectation is the same (or follows the same formula) for both bands
for (int band = 1; band <= 2; band++) {
double[] normalizedBand1 = MapAlgebra.bandAsArray(normalizedRaster1, band);
double[] normalizedBand2 = MapAlgebra.bandAsArray(normalizedRaster2, band);
double[] normalizedBand5 = MapAlgebra.bandAsArray(normalizedRaster5, band);
double[] normalizedBand6 = MapAlgebra.bandAsArray(normalizedRaster6, band);
double[] normalizedBand7 = MapAlgebra.bandAsArray(normalizedRaster7, band);
double normalizedMin6 = Arrays.stream(normalizedBand6).min().getAsDouble();
double normalizedMax6 = Arrays.stream(normalizedBand6).max().getAsDouble();

// Arrays are compared through their string form, i.e. element-by-element without tolerance.
assertEquals(Arrays.toString(expected1), Arrays.toString(normalizedBand1));
assertEquals(Arrays.toString(expected2), Arrays.toString(normalizedBand2));
// expected6 is the expectation for normalizedRaster5 (raster4), not normalizedRaster6.
assertEquals(Arrays.toString(expected6), Arrays.toString(normalizedBand5));
assertEquals(Arrays.toString(expected1), Arrays.toString(normalizedBand7));

// With input min 0 and max 30 shared by both bands, band 1 (0..15) maps to [0, 127.5]
// and band 2 (15..30) maps to [127.5, 255].
assertEquals(0+((band-1)*127.5), normalizedMin6, 0.01d);
assertEquals(127.5+((band-1)*127.5), normalizedMax6, 0.01d);
}

// The requested NoData value must be recorded on both output bands.
assertEquals(95.0, RasterUtils.getNoDataValue(normalizedRaster4.getSampleDimension(0)), 0.01d);
assertEquals(95.0, RasterUtils.getNoDataValue(normalizedRaster4.getSampleDimension(1)), 0.01d);

// Band-specific expectations.
assertEquals(Arrays.toString(expected3), Arrays.toString(MapAlgebra.bandAsArray(normalizedRaster3, 1)));
assertEquals(Arrays.toString(expected4), Arrays.toString(MapAlgebra.bandAsArray(normalizedRaster4, 1)));
assertEquals(Arrays.toString(expected5), Arrays.toString(MapAlgebra.bandAsArray(normalizedRaster4, 2)));
}

// Runs the parameterized normalizeAll check on a 10x10 raster once per pixel type.
@Test
public void testNormalizeAll2() throws FactoryException {
    // Byte, Integer, Short, Unsigned Short, Float, Double
    final String[] pixelTypes = {"B", "I", "S", "US", "F", "D"};
    for (int idx = 0; idx < pixelTypes.length; idx++) {
        testNormalizeAll2(10, 10, pixelTypes[idx]);
    }
}

/**
 * Fills a single-band raster of the given size and pixel type with the values
 * 0 .. width*height-1, normalizes it to [0, 255] and checks both the normalized values and
 * that the raster's data type is unchanged.
 *
 * @param width raster width in pixels
 * @param height raster height in pixels
 * @param pixelType pixel type code passed to RasterConstructors.makeEmptyRaster
 */
private void testNormalizeAll2(int width, int height, String pixelType) throws FactoryException {
    // Create an empty raster with the specified pixel type
    GridCoverage2D raster = RasterConstructors.makeEmptyRaster(1, pixelType, width, height, 10, 20, 1);

    // Fill raster
    double[] bandValues = new double[width * height];
    for (int i = 0; i < bandValues.length; i++) {
        bandValues[i] = i;
    }
    raster = MapAlgebra.addBandFromArray(raster, bandValues, 1);

    GridCoverage2D normalizedRaster = MapAlgebra.normalizeAll(raster, 0, 255);

    // The data type does not change per pixel, so look it up once.
    int resultDataType = normalizedRaster.getRenderedImage().getSampleModel().getDataType();
    boolean integralType = resultDataType == DataBuffer.TYPE_BYTE
            || resultDataType == DataBuffer.TYPE_SHORT
            || resultDataType == DataBuffer.TYPE_USHORT
            || resultDataType == DataBuffer.TYPE_INT;

    // The input minimum is 0 and the input maximum is the last fill value. Deriving it from
    // the data keeps the expectation correct for any width/height, not only 10x10.
    double inputMax = bandValues.length - 1;

    // Check the normalized values and data type
    double[] normalizedBandValues = MapAlgebra.bandAsArray(normalizedRaster, 1);
    for (int i = 0; i < bandValues.length; i++) {
        double expected = bandValues[i] * 255 / inputMax;
        double actual = normalizedBandValues[i];
        if (integralType) {
            // Integral rasters store the truncated value.
            assertEquals((int) expected, (int) actual);
        } else {
            assertEquals(expected, actual, 0.01);
        }
    }

    // Assert the data type remains as expected
    int expectedDataType = RasterUtils.getDataTypeCode(pixelType);
    assertEquals(expectedDataType, resultDataType);
}


@Test
public void testNormalizedDifference() {
double[] band1 = new double[] {960, 1067, 107, 20, 1868};
Expand Down
41 changes: 41 additions & 0 deletions docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2480,6 +2480,47 @@ Spark SQL Example:
SELECT RS_Normalize(band)
```

### RS_NormalizeAll

Introduction: Normalizes values in all bands of a raster between a given normalization range. The function maintains the data type of the raster values by ensuring that the normalized values are cast back to the original data type of each band in the raster. By default, the values are normalized to the range [0, 255]. RS_NormalizeAll can take up to 7 of the following arguments.

- `raster`: The raster to be normalized.
- `minLim` and `maxLim` (Optional): The lower and upper limits of the normalization range. By default, normalization range is set to [0, 255].
- `noDataValue` (Optional): Defines the value to be used for missing or invalid data in raster bands. By default, noDataValue is set to `maxLim`.
- `minValue` and `maxValue` (Optional): Optionally, specific minimum and maximum values of the input raster can be provided. If not provided, these values are computed from the raster data.
- `normalizeAcrossBands` (Optional): A boolean flag to determine the normalization method. If set to true (default), normalization is performed across all bands based on global min and max values. If false, each band is normalized individually based on its own min and max values.

!!! Warning
Using a noDataValue that falls within the normalization range can lead to loss of valid data. If any data value within a raster band matches the specified noDataValue, it will be replaced and cannot be distinguished or recovered later. Exercise caution in selecting a noDataValue to avoid unintentional data alteration.

Formats:
```
RS_NormalizeAll (raster: Raster)
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double)
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double, noDataValue: Double)
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double, noDataValue: Double, normalizeAcrossBands: Boolean)
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double, noDataValue: Double, minValue: Double, maxValue: Double)
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double, noDataValue: Double, minValue: Double, maxValue: Double, normalizeAcrossBands: Boolean)
```

Since: `v1.6.0`

Spark SQL Example:

```sql
SELECT RS_NormalizeAll(raster, 0, 1)
```

### RS_NormalizedDifference

Introduction: Returns Normalized Difference between two bands(band2 and band1) in a Geotiff image(example: NDVI, NDBI)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ object Catalog {
function[RS_LogicalOver](),
function[RS_Array](),
function[RS_Normalize](),
function[RS_NormalizeAll](),
function[RS_AddBandFromArray](),
function[RS_BandAsArray](),
function[RS_MapAlgebra](null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ case class RS_Normalize(inputExpressions: Seq[Expression]) extends InferredExpre
}
}

/**
 * Spark SQL expression for RS_NormalizeAll.
 *
 * Registers MapAlgebra.normalizeAll through the 1-, 3-, 4-, 5-, 6- and 7-argument
 * inferrable-function wrappers (presumably one per Java overload arity).
 */
case class RS_NormalizeAll(inputExpressions: Seq[Expression])
  extends InferredExpression(
    inferrableFunction1(MapAlgebra.normalizeAll),
    inferrableFunction3(MapAlgebra.normalizeAll),
    inferrableFunction4(MapAlgebra.normalizeAll),
    inferrableFunction5(MapAlgebra.normalizeAll),
    inferrableFunction6(MapAlgebra.normalizeAll),
    inferrableFunction7(MapAlgebra.normalizeAll)
  ) {

  // Rebuilds this expression with the replacement child expressions.
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
    copy(inputExpressions = newChildren)
  }
}

case class RS_AddBandFromArray(inputExpressions: Seq[Expression])
extends InferredExpression(nullTolerantInferrableFunction3(MapAlgebra.addBandFromArray), nullTolerantInferrableFunction4(MapAlgebra.addBandFromArray), inferrableFunction2(MapAlgebra.addBandFromArray)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assert(df.first().getAs[mutable.WrappedArray[Double]](0)(1) == 255)
}

it("should pass RS_NormalizeAll") {
  // Load the sample GeoTiff and decode it into a raster column.
  val rasterDf = sparkSession.read
    .format("binaryFile")
    .load(resourceFolder + "raster/test1.tiff")
    .selectExpr("RS_FromGeoTiff(content) as raster")

  // Three-argument form (default NoData) and four-argument form (explicit NoData of 0).
  val normalizedDefaultNoData =
    rasterDf.selectExpr("RS_NormalizeAll(raster, 0, 255) as normalized").first().get(0)
  val normalizedExplicitNoData =
    rasterDf.selectExpr("RS_NormalizeAll(raster, 0, 255, 0) as normalized").first().get(0)

  // Both calls must yield a raster object.
  assert(normalizedDefaultNoData.isInstanceOf[GridCoverage2D])
  assert(normalizedExplicitNoData.isInstanceOf[GridCoverage2D])
}

it("should pass RS_Array") {
val df = sparkSession.sql("SELECT RS_Array(6, 1e-6) as band")
val result = df.first().getAs[mutable.WrappedArray[Double]](0)
Expand Down

0 comments on commit adebdac

Please sign in to comment.