Skip to content

Commit

Permalink
VRT: Support Int8, (U)Int64 with Python pixel functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Feb 19, 2024
1 parent e7f14b7 commit 81011ae
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
55 changes: 55 additions & 0 deletions autotest/gdrivers/vrtderived.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,61 @@ def identity(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize
_validate(xml)


###############################################################################


@pytest.mark.parametrize(
"dtype",
(
gdal.GDT_Int8,
gdal.GDT_Byte,
gdal.GDT_Int16,
gdal.GDT_UInt16,
gdal.GDT_Int32,
gdal.GDT_UInt32,
gdal.GDT_Int64,
gdal.GDT_UInt64,
),
)
def test_vrt_derived_dtype(tmp_vsimem, dtype):
input_fname = tmp_vsimem / "input.tif"

nx = 1
ny = 1

with gdal.GetDriverByName("GTiff").Create(
input_fname, nx, ny, 1, eType=gdal.GDT_Int8
) as input_ds:
input_ds.GetRasterBand(1).Fill(1)
gt = input_ds.GetGeoTransform()

vrt_xml = f"""
<VRTDataset rasterXSize="{nx}" rasterYSize="{ny}">
<GeoTransform>{', '.join([str(x) for x in gt])}</GeoTransform>
<VRTRasterBand dataType="{gdal.GetDataTypeName(dtype)}" band="1" subClass="VRTDerivedRasterBand">
<PixelFunctionLanguage>Python</PixelFunctionLanguage>
<PixelFunctionType>identity</PixelFunctionType>
<PixelFunctionCode><![CDATA[
import numpy as np
def identity(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize, r, gt, **kwargs):
out_ar[:] = in_ar
]]>
</PixelFunctionCode>
<SimpleSource>
<SourceFilename relativeToVRT="0">{input_fname}</SourceFilename>
<SourceBand>1</SourceBand>
<SrcRect xOff="0" yOff="0" xSize="{nx}" ySize="{ny}" />
<DstRect xOff="0" yOff="0" xSize="{nx}" ySize="{ny}" />
</SimpleSource>
</VRTRasterBand></VRTDataset>"""

with gdal.config_option("GDAL_VRT_ENABLE_PYTHON", "YES"):
with gdal.Open(vrt_xml) as vrt_ds:
vrt_ds.ReadRaster() # materialize VRT
assert vrt_ds.GetRasterBand(1).DataType == dtype


###############################################################################
# Cleanup.

Expand Down
9 changes: 9 additions & 0 deletions frmts/vrt/vrtderivedrasterband.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ static PyObject *GDALCreateNumpyArray(PyObject *pCreateArray, void *pBuffer,
case GDT_Byte:
pszDataType = "uint8";
break;
case GDT_Int8:
pszDataType = "int8";
break;
case GDT_UInt16:
pszDataType = "uint16";
break;
Expand All @@ -100,6 +103,12 @@ static PyObject *GDALCreateNumpyArray(PyObject *pCreateArray, void *pBuffer,
case GDT_Int32:
pszDataType = "int32";
break;
case GDT_Int64:
pszDataType = "int64";
break;
case GDT_UInt64:
pszDataType = "uint64";
break;
case GDT_Float32:
pszDataType = "float32";
break;
Expand Down

0 comments on commit 81011ae

Please sign in to comment.