diff --git a/autotest/ogr/ogr_mem.py b/autotest/ogr/ogr_mem.py index 2875a8fce028..6eaa34776080 100755 --- a/autotest/ogr/ogr_mem.py +++ b/autotest/ogr/ogr_mem.py @@ -757,6 +757,66 @@ def test_ogr_mem_arrow_stream_pycapsule_interface(): del stream +############################################################################### +# Test consuming __arrow_c_stream__() interface. +# Cf https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + + +@gdaltest.enable_exceptions() +def test_ogr_mem_consume_arrow_stream_pycapsule_interface(): + + ds = ogr.GetDriverByName("Memory").CreateDataSource("") + lyr = ds.CreateLayer("foo", geom_type=ogr.wkbNone) + lyr.CreateGeomField(ogr.GeomFieldDefn("my_geometry")) + lyr.CreateField(ogr.FieldDefn("foo")) + f = ogr.Feature(lyr.GetLayerDefn()) + f["foo"] = "bar" + f.SetGeometry(ogr.CreateGeometryFromWkt("POINT (1 2)")) + lyr.CreateFeature(f) + + lyr2 = ds.CreateLayer("foo2") + lyr2.WriteArrow(lyr) + + f = lyr2.GetNextFeature() + assert f["foo"] == "bar" + assert f.GetGeometryRef().ExportToIsoWkt() == "POINT (1 2)" + + +############################################################################### +# Test consuming __arrow_c_array__() interface. +# Cf https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + + +@gdaltest.enable_exceptions() +def test_ogr_mem_consume_arrow_arrow_pycapsule_interface(): + pyarrow = pytest.importorskip("pyarrow") + if int(pyarrow.__version__.split(".")[0]) < 14: + pytest.skip("pyarrow >= 14 needed") + + ds = ogr.GetDriverByName("Memory").CreateDataSource("") + lyr = ds.CreateLayer("foo") + lyr.CreateField(ogr.FieldDefn("foo")) + f = ogr.Feature(lyr.GetLayerDefn()) + f["foo"] = "bar" + f.SetGeometry(ogr.CreateGeometryFromWkt("POINT (1 2)")) + lyr.CreateFeature(f) + + table = pyarrow.table(lyr) + + lyr2 = ds.CreateLayer("foo2") + batches = table.to_batches() + for batch in batches: + array = batch.to_struct_array() + if not hasattr(array, "__arrow_c_array__"): + pytest.skip("table does not declare __arrow_c_array__") + + lyr2.WriteArrow(array) + + f = lyr2.GetNextFeature() + assert f["foo"] == "bar" + assert f.GetGeometryRef().ExportToIsoWkt() == "POINT (1 2)" + + ############################################################################### diff --git a/swig/include/ogr.i b/swig/include/ogr.i index 19b4de6e2667..f35c906a5688 100644 --- a/swig/include/ogr.i +++ b/swig/include/ogr.i @@ -1155,6 +1155,73 @@ static void ReleaseArrowArrayStreamPyCapsule(PyObject* capsule) { } CPLFree(stream); } + +static char** ParseArrowMetadata(const char *pabyMetadata) +{ + char** ret = NULL; + int32_t nKVP; + memcpy(&nKVP, pabyMetadata, sizeof(int32_t)); + pabyMetadata += sizeof(int32_t); + for (int i = 0; i < nKVP; ++i) + { + int32_t nSizeKey; + memcpy(&nSizeKey, pabyMetadata, sizeof(int32_t)); + pabyMetadata += sizeof(int32_t); + std::string osKey; + osKey.assign(pabyMetadata, nSizeKey); + pabyMetadata += nSizeKey; + + int32_t nSizeValue; + memcpy(&nSizeValue, pabyMetadata, sizeof(int32_t)); + pabyMetadata += sizeof(int32_t); + std::string osValue; + osValue.assign(pabyMetadata, nSizeValue); + pabyMetadata += nSizeValue; + + ret = CSLSetNameValue(ret, osKey.c_str(), osValue.c_str()); + } + + return ret; +} + +// Create output fields using CreateFieldFromArrowSchema() +static bool CreateFieldsFromArrowSchema(OGRLayerH hDstLayer, + const struct ArrowSchema* schemaSrc, + char** options) +{ + for (int i = 0; i < schemaSrc->n_children; ++i) + { + const char *metadata = + schemaSrc->children[i]->metadata; + if( metadata ) + { + char** keyValues = ParseArrowMetadata(metadata); + const char *ARROW_EXTENSION_NAME_KEY = "ARROW:extension:name"; + const char *EXTENSION_NAME_OGC_WKB = "ogc.wkb"; + const char *EXTENSION_NAME_GEOARROW_WKB = "geoarrow.wkb"; + const char* value = CSLFetchNameValue(keyValues, ARROW_EXTENSION_NAME_KEY); + const bool bSkip = ( value && (EQUAL(value, EXTENSION_NAME_OGC_WKB) || EQUAL(value, EXTENSION_NAME_GEOARROW_WKB)) ); + CSLDestroy(keyValues); + if( bSkip ) + continue; + } + + const char *pszFieldName = + schemaSrc->children[i]->name; + if (!EQUAL(pszFieldName, "OGC_FID") && + !EQUAL(pszFieldName, "wkb_geometry") && + !OGR_L_CreateFieldFromArrowSchema( + hDstLayer, schemaSrc->children[i], options)) + { + CPLError(CE_Failure, CPLE_AppDefined, + "Cannot create field %s", + pszFieldName); + return false; + } + } + return true; +} + %} #endif @@ -1580,6 +1647,120 @@ public: { return OGR_L_WriteArrowBatch(self, schema, array, options) ? OGRERR_NONE : OGRERR_FAILURE; } + + OGRErr WriteArrowStreamCapsule(PyObject* capsule, int createFieldsFromSchema, char** options = NULL) + { + ArrowArrayStream* stream = (ArrowArrayStream*)PyCapsule_GetPointer(capsule, "arrow_array_stream"); + if( !stream ) + { + CPLError(CE_Failure, CPLE_AppDefined, "PyCapsule_GetPointer(capsule, \"arrow_array_stream\") failed"); + return OGRERR_FAILURE; + } + if( stream->release == NULL ) + { + CPLError(CE_Failure, CPLE_AppDefined, "stream->release == NULL"); + return OGRERR_FAILURE; + } + + ArrowSchema schema; + if( stream->get_schema(stream, &schema) != 0 ) + { + stream->release(stream); + return OGRERR_FAILURE; + } + + if( createFieldsFromSchema == TRUE || + (createFieldsFromSchema == -1 && OGR_FD_GetFieldCount(OGR_L_GetLayerDefn(self)) == 0) ) + { + if( !CreateFieldsFromArrowSchema(self, &schema, options) ) + { + schema.release(&schema); + stream->release(stream); + return OGRERR_FAILURE; + } + } + + while( true ) + { + ArrowArray array; + if( stream->get_next(stream, &array) == 0 ) + { + if( array.release == NULL ) + break; + if( !OGR_L_WriteArrowBatch(self, &schema, &array, options) ) + { + if( array.release ) + array.release(&array); + schema.release(&schema); + stream->release(stream); + return OGRERR_FAILURE; + } + if( array.release ) + array.release(&array); + } + else + { + CPLError(CE_Failure, CPLE_AppDefined, "stream->get_next(stream, &array) failed"); + schema.release(&schema); + stream->release(stream); + return OGRERR_FAILURE; + } + } + schema.release(&schema); + stream->release(stream); + return OGRERR_NONE; + } + + OGRErr WriteArrowSchemaAndArrowArrayCapsule(PyObject* schemaCapsule, PyObject* arrayCapsule, int createFieldsFromSchema, char** options = NULL) + { + ArrowSchema* schema = (ArrowSchema*)PyCapsule_GetPointer(schemaCapsule, "arrow_schema"); + if( !schema ) + { + CPLError(CE_Failure, CPLE_AppDefined, "PyCapsule_GetPointer(schemaCapsule, \"arrow_schema\") failed"); + return OGRERR_FAILURE; + } + if( schema->release == NULL ) + { + CPLError(CE_Failure, CPLE_AppDefined, "schema->release == NULL"); + return OGRERR_FAILURE; + } + + if( createFieldsFromSchema == TRUE || + (createFieldsFromSchema == -1 && OGR_FD_GetFieldCount(OGR_L_GetLayerDefn(self)) == 0) ) + { + if( !CreateFieldsFromArrowSchema(self, schema, options) ) + { + schema->release(schema); + return OGRERR_FAILURE; + } + } + + ArrowArray* array = (ArrowArray*)PyCapsule_GetPointer(arrayCapsule, "arrow_array"); + if( !array ) + { + CPLError(CE_Failure, CPLE_AppDefined, "PyCapsule_GetPointer(arrayCapsule, \"arrow_array\") failed"); + schema->release(schema); + return OGRERR_FAILURE; + } + if( array->release == NULL ) + { + CPLError(CE_Failure, CPLE_AppDefined, "array->release == NULL"); + schema->release(schema); + return OGRERR_FAILURE; + } + + OGRErr eErr = OGRERR_NONE; + if( !OGR_L_WriteArrowBatch(self, schema, array, options) ) + { + eErr = OGRERR_FAILURE; + } + + if( schema->release ) + schema->release(schema); + if( array->release ) + array->release(array); + return eErr; + } #endif #ifdef SWIGPYTHON diff --git a/swig/include/python/ogr_python.i b/swig/include/python/ogr_python.i index bf27eaba9526..38c2253cb62d 100644 --- a/swig/include/python/ogr_python.i +++ b/swig/include/python/ogr_python.i @@ -643,8 +643,52 @@ def ReleaseResultSet(self, sql_lyr): return self.CreateFieldFromArrowSchema(schema, options) + def WriteArrow(self, obj, requested_schema=None, createFieldsFromSchema=None, options=[]): + """Write the content of the passed object, which must implement the + __arrow_c_stream__ or __arrow_c_array__ interface, into the layer. + + Parameters + ---------- + obj: + Object implementing the __arrow_c_stream__ or __arrow_c_array__ interface + + requested_schema: PyCapsule, default None + The schema to which the stream should be casted, passed as a + PyCapsule containing a C ArrowSchema representation of the + requested schema. + + createFieldsFromSchema: boolean or None. Default to None + Whether OGRLayer::CreateFieldFromArrowSchema() should be called. If None + specified, it is called if no fields have been created yet + + options: list of strings + Options to pass to OGRLayer::CreateFieldFromArrowSchema() and OGRLayer::WriteArrowBatch() + + """ + + if createFieldsFromSchema is None: + createFieldsFromSchema = -1 + elif createFieldsFromSchema is True: + createFieldsFromSchema = 1 + else: + createFieldsFromSchema = 0 + + if hasattr(obj, "__arrow_c_stream__"): + stream_capsule = obj.__arrow_c_stream__(requested_schema=requested_schema) + return self.WriteArrowStreamCapsule(stream_capsule, createFieldsFromSchema, options) + + if hasattr(obj, "__arrow_c_array__"): + schema_capsule, array_capsule = obj.__arrow_c_array__(requested_schema=requested_schema) + return self.WriteArrowSchemaAndArrowArrayCapsule(schema_capsule, array_capsule, createFieldsFromSchema, options) + + raise Exception("Passed object does not implement the __arrow_c_stream__ or __arrow_c_array__ interface.") + + def WritePyArrow(self, pa_batch, options=[]): - """Write the content of the passed PyArrow batch (either a pyarrow.Table, a pyarrow.RecordBatch or a pyarrow.StructArray) into the layer.""" + """Write the content of the passed PyArrow batch (either a pyarrow.Table, a pyarrow.RecordBatch or a pyarrow.StructArray) into the layer. + + See also the WriteArrow() method to be independent of PyArrow + """ import pyarrow as pa