From edfbfb9c0b8dadd4a99e0db2171ccd70fe5ee06a Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 2 Sep 2021 13:03:03 +0200 Subject: [PATCH] ARROW-13855: [C++][Python] Implement C data interface support for extension types --- cpp/src/arrow/c/bridge.cc | 128 +++++-- cpp/src/arrow/c/bridge_test.cc | 349 +++++++++++++----- cpp/src/arrow/extension_type_test.cc | 2 + cpp/src/arrow/ipc/read_write_test.cc | 10 +- cpp/src/arrow/ipc/test_common.cc | 17 + cpp/src/arrow/ipc/test_common.h | 3 + cpp/src/arrow/testing/extension_type.h | 41 +- cpp/src/arrow/testing/gtest_util.cc | 84 +++-- .../arrow/testing/json_integration_test.cc | 7 +- cpp/src/arrow/util/key_value_metadata.cc | 8 +- cpp/src/arrow/util/key_value_metadata.h | 5 +- python/pyarrow/includes/libarrow.pxd | 9 + python/pyarrow/tests/test_cffi.py | 64 +++- python/pyarrow/tests/test_extension_type.py | 28 ++ python/pyarrow/types.pxi | 39 ++ 15 files changed, 626 insertions(+), 168 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 8b8153465ee4b..9484b44590ab9 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -28,6 +28,7 @@ #include "arrow/buffer.h" #include "arrow/c/helpers.h" #include "arrow/c/util_internal.h" +#include "arrow/extension_type.h" #include "arrow/memory_pool.h" #include "arrow/record_batch.h" #include "arrow/result.h" @@ -56,8 +57,6 @@ using internal::ArrayExportTraits; using internal::SchemaExportGuard; using internal::SchemaExportTraits; -// TODO export / import Extension types and arrays - namespace { Status ExportingNotImplemented(const DataType& type) { @@ -171,23 +170,26 @@ struct SchemaExporter { export_.name_ = field.name(); flags_ = field.nullable() ? ARROW_FLAG_NULLABLE : 0; - const DataType& type = *field.type(); - RETURN_NOT_OK(ExportFormat(type)); - RETURN_NOT_OK(ExportChildren(type.fields())); + const DataType* type = UnwrapExtension(field.type().get()); + RETURN_NOT_OK(ExportFormat(*type)); + RETURN_NOT_OK(ExportChildren(type->fields())); RETURN_NOT_OK(ExportMetadata(field.metadata().get())); return Status::OK(); } - Status ExportType(const DataType& type) { + Status ExportType(const DataType& orig_type) { flags_ = ARROW_FLAG_NULLABLE; - RETURN_NOT_OK(ExportFormat(type)); - RETURN_NOT_OK(ExportChildren(type.fields())); + const DataType* type = UnwrapExtension(&orig_type); + RETURN_NOT_OK(ExportFormat(*type)); + RETURN_NOT_OK(ExportChildren(type->fields())); + // There may be additional metadata to export + RETURN_NOT_OK(ExportMetadata(nullptr)); return Status::OK(); } Status ExportSchema(const Schema& schema) { - static StructType dummy_struct_type({}); + static const StructType dummy_struct_type({}); flags_ = 0; RETURN_NOT_OK(ExportFormat(dummy_struct_type)); @@ -232,6 +234,17 @@ struct SchemaExporter { c_struct->release = ReleaseExportedSchema; } + const DataType* UnwrapExtension(const DataType* type) { + if (type->id() == Type::EXTENSION) { + const auto& ext_type = checked_cast(*type); + additional_metadata_.reserve(2); + additional_metadata_.emplace_back(kExtensionTypeKeyName, ext_type.extension_name()); + additional_metadata_.emplace_back(kExtensionMetadataKeyName, ext_type.Serialize()); + return ext_type.storage_type().get(); + } + return type; + } + Status ExportFormat(const DataType& type) { if (type.id() == Type::DICTIONARY) { const auto& dict_type = checked_cast(type); @@ -259,10 +272,29 @@ struct SchemaExporter { return Status::OK(); } - Status ExportMetadata(const KeyValueMetadata* metadata) { - if (metadata != nullptr && metadata->size() >= 0) { - ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(*metadata)); + Status ExportMetadata(const KeyValueMetadata* orig_metadata) { + static const KeyValueMetadata empty_metadata; + + if (orig_metadata == nullptr) { + orig_metadata = &empty_metadata; } + if (additional_metadata_.empty()) { + if (orig_metadata->size() > 0) { + ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(*orig_metadata)); + } + return Status::OK(); + } + // Additional metadata needs to be appended to the existing + // (for extension types) + KeyValueMetadata metadata(orig_metadata->keys(), orig_metadata->values()); + for (const auto& kv : additional_metadata_) { + // The metadata may already be there => ignore + if (metadata.Contains(kv.first)) { + continue; + } + metadata.Append(kv.first, kv.second); + } + ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(metadata)); return Status::OK(); } @@ -442,6 +474,7 @@ struct SchemaExporter { ExportedSchemaPrivateData export_; int64_t flags_ = 0; + std::vector> additional_metadata_; std::unique_ptr dict_exporter_; std::vector child_exporters_; }; @@ -721,7 +754,13 @@ class FormatStringParser { size_t index_; }; -Result> DecodeMetadata(const char* metadata) { +struct DecodedMetadata { + std::shared_ptr metadata; + std::string extension_name; + std::string extension_serialized; +}; + +Result DecodeMetadata(const char* metadata) { auto read_int32 = [&](int32_t* out) -> Status { int32_t v; memcpy(&v, metadata, 4); @@ -744,21 +783,29 @@ Result> DecodeMetadata(const char* metadata) { return Status::OK(); }; + DecodedMetadata decoded; + if (metadata == nullptr) { - return nullptr; + return decoded; } int32_t npairs; RETURN_NOT_OK(read_int32(&npairs)); if (npairs == 0) { - return nullptr; + return decoded; } std::vector keys(npairs); std::vector values(npairs); for (int32_t i = 0; i < npairs; ++i) { RETURN_NOT_OK(read_string(&keys[i])); RETURN_NOT_OK(read_string(&values[i])); + if (keys[i] == kExtensionTypeKeyName) { + decoded.extension_name = values[i]; + } else if (keys[i] == kExtensionMetadataKeyName) { + decoded.extension_serialized = values[i]; + } } - return key_value_metadata(std::move(keys), std::move(values)); + decoded.metadata = key_value_metadata(std::move(keys), std::move(values)); + return decoded; } struct SchemaImporter { @@ -775,10 +822,9 @@ struct SchemaImporter { } Result> MakeField() const { - ARROW_ASSIGN_OR_RAISE(auto metadata, DecodeMetadata(c_struct_->metadata)); const char* name = c_struct_->name ? c_struct_->name : ""; bool nullable = (c_struct_->flags & ARROW_FLAG_NULLABLE) != 0; - return field(name, type_, nullable, std::move(metadata)); + return field(name, type_, nullable, std::move(metadata_.metadata)); } Result> MakeSchema() const { @@ -787,8 +833,7 @@ struct SchemaImporter { "Cannot import schema: ArrowSchema describes non-struct type ", type_->ToString()); } - ARROW_ASSIGN_OR_RAISE(auto metadata, DecodeMetadata(c_struct_->metadata)); - return schema(type_->fields(), std::move(metadata)); + return schema(type_->fields(), std::move(metadata_.metadata)); } Result> MakeType() const { return type_; } @@ -836,6 +881,20 @@ struct SchemaImporter { bool ordered = (c_struct_->flags & ARROW_FLAG_DICTIONARY_ORDERED) != 0; type_ = dictionary(type_, dict_importer.type_, ordered); } + + // Import metadata + ARROW_ASSIGN_OR_RAISE(metadata_, DecodeMetadata(c_struct_->metadata)); + + // Detect extension type + if (!metadata_.extension_name.empty()) { + const auto registered_ext_type = GetExtensionType(metadata_.extension_name); + if (registered_ext_type) { + ARROW_ASSIGN_OR_RAISE( + type_, registered_ext_type->Deserialize(std::move(type_), + metadata_.extension_serialized)); + } + } + return Status::OK(); } @@ -1130,6 +1189,7 @@ struct SchemaImporter { int64_t recursion_level_; std::vector child_importers_; std::shared_ptr type_; + DecodedMetadata metadata_; }; } // namespace @@ -1255,8 +1315,15 @@ struct ArrayImporter { } Status DoImport() { + // Unwrap extension type + const DataType* storage_type = type_.get(); + if (storage_type->id() == Type::EXTENSION) { + storage_type = + checked_cast(*storage_type).storage_type().get(); + } + // First import children (required for reconstituting parent array data) - const auto& fields = type_->fields(); + const auto& fields = storage_type->fields(); if (c_struct_->n_children != static_cast(fields.size())) { return Status::Invalid("ArrowArray struct has ", c_struct_->n_children, " children, expected ", fields.size(), " for type ", @@ -1270,15 +1337,15 @@ struct ArrayImporter { } // Import main data - RETURN_NOT_OK(ImportMainData()); + RETURN_NOT_OK(VisitTypeInline(*storage_type, this)); - bool is_dict_type = (type_->id() == Type::DICTIONARY); + bool is_dict_type = (storage_type->id() == Type::DICTIONARY); if (c_struct_->dictionary != nullptr) { if (!is_dict_type) { return Status::Invalid("Import type is ", type_->ToString(), " but dictionary field in ArrowArray struct is not null"); } - const auto& dict_type = checked_cast(*type_); + const auto& dict_type = checked_cast(*storage_type); // Import dictionary values ArrayImporter dict_importer(dict_type.value_type()); RETURN_NOT_OK(dict_importer.ImportDict(this, c_struct_->dictionary)); @@ -1292,13 +1359,11 @@ struct ArrayImporter { return Status::OK(); } - Status ImportMainData() { return VisitTypeInline(*type_, this); } - Status Visit(const DataType& type) { return Status::NotImplemented("Cannot import array of type ", type_->ToString()); } - Status Visit(const FixedWidthType& type) { return ImportFixedSizePrimitive(); } + Status Visit(const FixedWidthType& type) { return ImportFixedSizePrimitive(type); } Status Visit(const NullType& type) { RETURN_NOT_OK(CheckNoChildren()); @@ -1352,16 +1417,15 @@ struct ArrayImporter { return Status::OK(); } - Status ImportFixedSizePrimitive() { - const auto& fw_type = checked_cast(*type_); + Status ImportFixedSizePrimitive(const FixedWidthType& type) { RETURN_NOT_OK(CheckNoChildren()); RETURN_NOT_OK(CheckNumBuffers(2)); RETURN_NOT_OK(AllocateArrayData()); RETURN_NOT_OK(ImportNullBitmap()); - if (BitUtil::IsMultipleOf8(fw_type.bit_width())) { - RETURN_NOT_OK(ImportFixedSizeBuffer(1, fw_type.bit_width() / 8)); + if (BitUtil::IsMultipleOf8(type.bit_width())) { + RETURN_NOT_OK(ImportFixedSizeBuffer(1, type.bit_width() / 8)); } else { - DCHECK_EQ(fw_type.bit_width(), 1); + DCHECK_EQ(type.bit_width(), 1); RETURN_NOT_OK(ImportBitsBuffer(1)); } return Status::OK(); diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index 54ce0efcf9d3d..c51cb66c03b98 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -31,8 +31,10 @@ #include "arrow/c/util_internal.h" #include "arrow/ipc/json_simple.h" #include "arrow/memory_pool.h" +#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/endian.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" @@ -45,6 +47,7 @@ using internal::ArrayExportGuard; using internal::ArrayExportTraits; using internal::ArrayStreamExportGuard; using internal::ArrayStreamExportTraits; +using internal::checked_cast; using internal::SchemaExportGuard; using internal::SchemaExportTraits; @@ -122,6 +125,10 @@ using ArrayReleaseCallback = ReleaseCallback; static const std::vector kMetadataKeys1{"key1", "key2"}; static const std::vector kMetadataValues1{"", "bar"}; + +static const std::vector kMetadataKeys2{"key"}; +static const std::vector kMetadataValues2{"abcde"}; + // clang-format off static const std::string kEncodedMetadata1{ // NOLINT: runtime/string #if ARROW_LITTLE_ENDIAN @@ -133,11 +140,7 @@ static const std::string kEncodedMetadata1{ // NOLINT: runtime/string 0, 0, 0, 4, 'k', 'e', 'y', '1', 0, 0, 0, 0, 0, 0, 0, 4, 'k', 'e', 'y', '2', 0, 0, 0, 3, 'b', 'a', 'r'}; #endif -// clang-format on -static const std::vector kMetadataKeys2{"key"}; -static const std::vector kMetadataValues2{"abcde"}; -// clang-format off static const std::string kEncodedMetadata2{ // NOLINT: runtime/string #if ARROW_LITTLE_ENDIAN 1, 0, 0, 0, @@ -146,6 +149,51 @@ static const std::string kEncodedMetadata2{ // NOLINT: runtime/string 0, 0, 0, 1, 0, 0, 0, 3, 'k', 'e', 'y', 0, 0, 0, 5, 'a', 'b', 'c', 'd', 'e'}; #endif + +static const std::string kEncodedUuidMetadata = // NOLINT: runtime/string +#if ARROW_LITTLE_ENDIAN + std::string {2, 0, 0, 0} + + std::string {20, 0, 0, 0} + kExtensionTypeKeyName + + std::string {4, 0, 0, 0} + "uuid" + + std::string {24, 0, 0, 0} + kExtensionMetadataKeyName + + std::string {15, 0, 0, 0} + "uuid-serialized"; +#else + std::string {0, 0, 0, 2} + + std::string {0, 0, 0, 20} + kExtensionTypeKeyName + + std::string {0, 0, 0, 4} + "uuid" + + std::string {0, 0, 0, 24} + kExtensionMetadataKeyName + + std::string {0, 0, 0, 15} + "uuid-serialized"; +#endif + +static const std::string kEncodedDictExtensionMetadata = // NOLINT: runtime/string +#if ARROW_LITTLE_ENDIAN + std::string {2, 0, 0, 0} + + std::string {20, 0, 0, 0} + kExtensionTypeKeyName + + std::string {14, 0, 0, 0} + "dict-extension" + + std::string {24, 0, 0, 0} + kExtensionMetadataKeyName + + std::string {25, 0, 0, 0} + "dict-extension-serialized"; +#else + std::string {0, 0, 0, 2} + + std::string {0, 0, 0, 20} + kExtensionTypeKeyName + + std::string {0, 0, 0, 14} + "dict-extension" + + std::string {0, 0, 0, 24} + kExtensionMetadataKeyName + + std::string {0, 0, 0, 25} + "dict-extension-serialized"; +#endif + +static const std::string kEncodedComplex128Metadata = // NOLINT: runtime/string +#if ARROW_LITTLE_ENDIAN + std::string {2, 0, 0, 0} + + std::string {20, 0, 0, 0} + kExtensionTypeKeyName + + std::string {10, 0, 0, 0} + "complex128" + + std::string {24, 0, 0, 0} + kExtensionMetadataKeyName + + std::string {21, 0, 0, 0} + "complex128-serialized"; +#else + std::string {0, 0, 0, 2} + + std::string {0, 0, 0, 20} + kExtensionTypeKeyName + + std::string {0, 0, 0, 10} + "complex128" + + std::string {0, 0, 0, 24} + kExtensionMetadataKeyName + + std::string {0, 0, 0, 21} + "complex128-serialized"; +#endif // clang-format on static constexpr int64_t kDefaultFlags = ARROW_FLAG_NULLABLE; @@ -404,6 +452,16 @@ TEST_F(TestSchemaExport, Dictionary) { } } +TEST_F(TestSchemaExport, Extension) { + TestPrimitive(uuid(), "w:16", "", kDefaultFlags, kEncodedUuidMetadata); + + TestNested(dict_extension_type(), {"c", "u"}, {"", ""}, {kDefaultFlags, kDefaultFlags}, + {kEncodedDictExtensionMetadata, ""}); + + TestNested(complex128(), {"+s", "g", "g"}, {"", "real", "imag"}, + {ARROW_FLAG_NULLABLE, 0, 0}, {kEncodedComplex128Metadata, "", ""}); +} + TEST_F(TestSchemaExport, ExportField) { TestPrimitive(field("thing", null()), "n", "thing", ARROW_FLAG_NULLABLE); // With nullable = false @@ -507,11 +565,9 @@ class TestArrayExport : public ::testing::Test { public: void SetUp() override { pool_ = default_memory_pool(); } - static std::function*)> JSONArrayFactory( + static std::function>()> JSONArrayFactory( std::shared_ptr type, const char* json) { - return [=](std::shared_ptr* out) -> Status { - return ::arrow::ipc::internal::json::ArrayFromJSON(type, json, out); - }; + return [=]() { return ArrayFromJSON(type, json); }; } template @@ -519,7 +575,7 @@ class TestArrayExport : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); std::shared_ptr arr; - ASSERT_OK(factory(&arr)); + ASSERT_OK_AND_ASSIGN(arr, ToResult(factory())); const ArrayData& data = *arr->data(); // non-owning reference struct ArrowArray c_export; ASSERT_OK(ExportArray(*arr, &c_export)); @@ -562,7 +618,7 @@ class TestArrayExport : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); std::shared_ptr arr; - ASSERT_OK(factory(&arr)); + ASSERT_OK_AND_ASSIGN(arr, ToResult(factory())); const ArrayData& data = *arr->data(); // non-owning reference struct ArrowArray c_export_temp, c_export_final; ASSERT_OK(ExportArray(*arr, &c_export_temp)); @@ -607,7 +663,7 @@ class TestArrayExport : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); std::shared_ptr arr; - ASSERT_OK(factory(&arr)); + ASSERT_OK_AND_ASSIGN(arr, ToResult(factory())); struct ArrowArray c_export_parent, c_export_child; ASSERT_OK(ExportArray(*arr, &c_export_parent)); @@ -661,7 +717,7 @@ class TestArrayExport : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); std::shared_ptr arr; - ASSERT_OK(factory(&arr)); + ASSERT_OK_AND_ASSIGN(arr, ToResult(factory())); struct ArrowArray c_export_parent; ASSERT_OK(ExportArray(*arr, &c_export_parent)); @@ -752,10 +808,7 @@ TEST_F(TestArrayExport, Primitive) { } TEST_F(TestArrayExport, PrimitiveSliced) { - auto factory = [](std::shared_ptr* out) -> Status { - *out = ArrayFromJSON(int16(), "[1, 2, null, -3]")->Slice(1, 2); - return Status::OK(); - }; + auto factory = []() { return ArrayFromJSON(int16(), "[1, 2, null, -3]")->Slice(1, 2); }; TestPrimitive(factory); } @@ -802,18 +855,17 @@ TEST_F(TestArrayExport, List) { TEST_F(TestArrayExport, ListSliced) { { - auto factory = [](std::shared_ptr* out) -> Status { - *out = ArrayFromJSON(list(int8()), "[[1, 2], [3, null], [4, 5, 6], null]") - ->Slice(1, 2); - return Status::OK(); + auto factory = []() { + return ArrayFromJSON(list(int8()), "[[1, 2], [3, null], [4, 5, 6], null]") + ->Slice(1, 2); }; TestNested(factory); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(int16(), "[1, 2, 3, 4, null, 5, 6, 7, 8]")->Slice(1, 6); auto offsets = ArrayFromJSON(int32(), "[0, 2, 3, 5, 6]")->Slice(2, 4); - return ListArray::FromArrays(*offsets, *values).Value(out); + return ListArray::FromArrays(*offsets, *values); }; TestNested(factory); } @@ -847,28 +899,25 @@ TEST_F(TestArrayExport, Union) { TEST_F(TestArrayExport, Dictionary) { { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])"); auto indices = ArrayFromJSON(uint16(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), - indices, values) - .Value(out); + indices, values); }; TestNested(factory); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays( - dictionary(indices->type(), values->type(), /*ordered=*/true), indices, - values) - .Value(out); + dictionary(indices->type(), values->type(), /*ordered=*/true), indices, values); }; TestNested(factory); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() -> Result> { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); ARROW_ASSIGN_OR_RAISE( @@ -876,13 +925,20 @@ TEST_F(TestArrayExport, Dictionary) { DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), indices, values)); auto offsets = ArrayFromJSON(int64(), "[0, 2, 5]"); - RETURN_NOT_OK(LargeListArray::FromArrays(*offsets, *dict_array).Value(out)); - return (*out)->ValidateFull(); + ARROW_ASSIGN_OR_RAISE(auto arr, LargeListArray::FromArrays(*offsets, *dict_array)); + RETURN_NOT_OK(arr->ValidateFull()); + return arr; }; TestNested(factory); } } +TEST_F(TestArrayExport, Extension) { + TestPrimitive(ExampleUuid); + TestPrimitive(ExampleSmallint); + TestPrimitive(ExampleComplex128); +} + TEST_F(TestArrayExport, MovePrimitive) { TestMovePrimitive(int8(), "[1, 2, null, -3]"); TestMovePrimitive(fixed_size_binary(3), R"(["foo", "bar", null])"); @@ -898,17 +954,16 @@ TEST_F(TestArrayExport, MoveNested) { TEST_F(TestArrayExport, MoveDictionary) { { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), - indices, values) - .Value(out); + indices, values); }; TestMoveNested(factory); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() -> Result> { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); ARROW_ASSIGN_OR_RAISE( @@ -916,8 +971,9 @@ TEST_F(TestArrayExport, MoveDictionary) { DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), indices, values)); auto offsets = ArrayFromJSON(int64(), "[0, 2, 5]"); - RETURN_NOT_OK(LargeListArray::FromArrays(*offsets, *dict_array).Value(out)); - return (*out)->ValidateFull(); + ARROW_ASSIGN_OR_RAISE(auto arr, LargeListArray::FromArrays(*offsets, *dict_array)); + RETURN_NOT_OK(arr->ValidateFull()); + return arr; }; TestMoveNested(factory); } @@ -934,7 +990,7 @@ TEST_F(TestArrayExport, MoveChild) { R"([[1, "foo"], [2, null]])", /*child_id=*/1); { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() -> Result> { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); ARROW_ASSIGN_OR_RAISE( @@ -942,8 +998,9 @@ TEST_F(TestArrayExport, MoveChild) { DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), indices, values)); auto offsets = ArrayFromJSON(int64(), "[0, 2, 5]"); - RETURN_NOT_OK(LargeListArray::FromArrays(*offsets, *dict_array).Value(out)); - return (*out)->ValidateFull(); + ARROW_ASSIGN_OR_RAISE(auto arr, LargeListArray::FromArrays(*offsets, *dict_array)); + RETURN_NOT_OK(arr->ValidateFull()); + return arr; }; TestMoveChild(factory, /*child_id=*/0); } @@ -1400,6 +1457,32 @@ TEST_F(TestSchemaImport, Dictionary) { CheckImport(expected); } +TEST_F(TestSchemaImport, UnregisteredExtension) { + FillPrimitive("w:16"); + c_struct_.metadata = kEncodedUuidMetadata.c_str(); + auto expected = fixed_size_binary(16); + CheckImport(expected); +} + +TEST_F(TestSchemaImport, RegisteredExtension) { + { + ExtensionTypeGuard guard(uuid()); + FillPrimitive("w:16"); + c_struct_.metadata = kEncodedUuidMetadata.c_str(); + auto expected = uuid(); + CheckImport(expected); + } + { + ExtensionTypeGuard guard(dict_extension_type()); + FillPrimitive(AddChild(), "u"); + FillPrimitive("c"); + FillDictionary(); + c_struct_.metadata = kEncodedDictExtensionMetadata.c_str(); + auto expected = dict_extension_type(); + CheckImport(expected); + } +} + TEST_F(TestSchemaImport, FormatStringError) { FillPrimitive(""); CheckImportError(); @@ -1481,6 +1564,22 @@ TEST_F(TestSchemaImport, DictionaryError) { CheckImportError(); } +TEST_F(TestSchemaImport, ExtensionError) { + ExtensionTypeGuard guard(uuid()); + + // Storage type doesn't match + FillPrimitive("w:15"); + c_struct_.metadata = kEncodedUuidMetadata.c_str(); + CheckImportError(); + + // Invalid serialization + std::string bogus_metadata = kEncodedUuidMetadata; + bogus_metadata[bogus_metadata.size() - 5] += 1; + FillPrimitive("w:16"); + c_struct_.metadata = bogus_metadata.c_str(); + CheckImportError(); +} + TEST_F(TestSchemaImport, RecursionError) { FillPrimitive(AddChild(), "c", "unused"); auto c = AddChild(); @@ -2163,21 +2262,44 @@ TEST_F(TestArrayImport, DictionaryWithOffset) { FillPrimitive(3, 0, 0, primitive_buffers_no_nulls4); FillDictionary(); - auto dict_values = ArrayFromJSON(utf8(), R"(["", "bar", "quux"])"); - auto indices = ArrayFromJSON(int8(), "[1, 2, 0]"); - ASSERT_OK_AND_ASSIGN( - auto expected, - DictionaryArray::FromArrays(dictionary(int8(), utf8()), indices, dict_values)); + auto expected = DictArrayFromJSON(dictionary(int8(), utf8()), "[1, 2, 0]", + R"(["", "bar", "quux"])"); CheckImport(expected); FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1); FillPrimitive(4, 0, 2, primitive_buffers_no_nulls4); FillDictionary(); - dict_values = ArrayFromJSON(utf8(), R"(["foo", "", "bar", "quux"])"); - indices = ArrayFromJSON(int8(), "[0, 1, 3, 0]"); - ASSERT_OK_AND_ASSIGN(expected, DictionaryArray::FromArrays(dictionary(int8(), utf8()), - indices, dict_values)); + expected = DictArrayFromJSON(dictionary(int8(), utf8()), "[0, 1, 3, 0]", + R"(["foo", "", "bar", "quux"])"); + CheckImport(expected); +} + +TEST_F(TestArrayImport, RegisteredExtension) { + ExtensionTypeGuard guard({smallint(), dict_extension_type(), complex128()}); + + // smallint + FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_16); + auto expected = + ExtensionType::WrapArray(smallint(), ArrayFromJSON(int16(), "[513, 1027, 1541]")); + CheckImport(expected); + + // dict_extension_type + FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1); + FillPrimitive(6, 0, 0, primitive_buffers_no_nulls4); + FillDictionary(); + + auto storage = DictArrayFromJSON(dictionary(int8(), utf8()), "[1, 2, 0, 1, 3, 0]", + R"(["foo", "", "bar", "quux"])"); + expected = ExtensionType::WrapArray(dict_extension_type(), storage); + CheckImport(expected); + + // complex128 + FillPrimitive(AddChild(), 3, 0, /*offset=*/0, primitive_buffers_no_nulls6); + FillPrimitive(AddChild(), 3, 0, /*offset=*/3, primitive_buffers_no_nulls6); + FillStructLike(3, 0, 0, 2, buffers_no_nulls_no_data); + expected = MakeComplex128(ArrayFromJSON(float64(), "[0.0, 1.5, -2.0]"), + ArrayFromJSON(float64(), "[3.0, 4.0, 5.0]")); CheckImport(expected); } @@ -2341,8 +2463,9 @@ class TestSchemaRoundtrip : public ::testing::Test { public: void SetUp() override { pool_ = default_memory_pool(); } - template - void TestWithTypeFactory(TypeFactory&& factory) { + template + void TestWithTypeFactory(TypeFactory&& factory, + ExpectedTypeFactory&& factory_expected) { std::shared_ptr type, actual; struct ArrowSchema c_schema {}; // zeroed SchemaExportGuard schema_guard(&c_schema); @@ -2359,7 +2482,7 @@ class TestSchemaRoundtrip : public ::testing::Test { // Recreate the type ASSERT_OK_AND_ASSIGN(actual, ImportType(&c_schema)); - type = factory(); + type = factory_expected(); AssertTypeEqual(*type, *actual); type.reset(); actual.reset(); @@ -2367,6 +2490,11 @@ class TestSchemaRoundtrip : public ::testing::Test { ASSERT_EQ(pool_->bytes_allocated(), orig_bytes); } + template + void TestWithTypeFactory(TypeFactory&& factory) { + TestWithTypeFactory(factory, factory); + } + template void TestWithSchemaFactory(SchemaFactory&& factory) { std::shared_ptr schema, actual; @@ -2459,6 +2587,27 @@ TEST_F(TestSchemaRoundtrip, Dictionary) { } } +TEST_F(TestSchemaRoundtrip, UnregisteredExtension) { + TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); }); + TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), utf8()); }); + + // Inside nested type + TestWithTypeFactory([]() { return list(dict_extension_type()); }, + []() { return list(dictionary(int8(), utf8())); }); +} + +TEST_F(TestSchemaRoundtrip, RegisteredExtension) { + ExtensionTypeGuard guard({uuid(), dict_extension_type(), complex128()}); + TestWithTypeFactory(uuid); + TestWithTypeFactory(dict_extension_type); + TestWithTypeFactory(complex128); + + // Inside nested type + TestWithTypeFactory([]() { return list(uuid()); }); + TestWithTypeFactory([]() { return list(dict_extension_type()); }); + TestWithTypeFactory([]() { return list(complex128()); }); +} + TEST_F(TestSchemaRoundtrip, Map) { TestWithTypeFactory([&]() { return map(utf8(), int32()); }); TestWithTypeFactory([&]() { return map(list(utf8()), int32()); }); @@ -2482,28 +2631,30 @@ TEST_F(TestSchemaRoundtrip, Schema) { class TestArrayRoundtrip : public ::testing::Test { public: - using ArrayFactory = std::function*)>; + using ArrayFactory = std::function>()>; void SetUp() override { pool_ = default_memory_pool(); } static ArrayFactory JSONArrayFactory(std::shared_ptr type, const char* json) { - return [=](std::shared_ptr* out) -> Status { - return ::arrow::ipc::internal::json::ArrayFromJSON(type, json, out); - }; + return [=]() { return ArrayFromJSON(type, json); }; } static ArrayFactory SlicedArrayFactory(ArrayFactory factory) { - return [=](std::shared_ptr* out) -> Status { - std::shared_ptr arr; - RETURN_NOT_OK(factory(&arr)); + return [=]() -> Result> { + ARROW_ASSIGN_OR_RAISE(auto arr, factory()); DCHECK_GE(arr->length(), 2); - *out = arr->Slice(1, arr->length() - 2); - return Status::OK(); + return arr->Slice(1, arr->length() - 2); }; } template void TestWithArrayFactory(ArrayFactory&& factory) { + TestWithArrayFactory(factory, factory); + } + + template + void TestWithArrayFactory(ArrayFactory&& factory, + ExpectedArrayFactory&& factory_expected) { std::shared_ptr array; struct ArrowArray c_array {}; struct ArrowSchema c_schema {}; @@ -2512,7 +2663,7 @@ class TestArrayRoundtrip : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); - ASSERT_OK(factory(&array)); + ASSERT_OK_AND_ASSIGN(array, ToResult(factory())); ASSERT_OK(ExportType(*array->type(), &c_schema)); ASSERT_OK(ExportArray(*array, &c_array)); @@ -2539,7 +2690,7 @@ class TestArrayRoundtrip : public ::testing::Test { // Check value of imported array { std::shared_ptr expected; - ASSERT_OK(factory(&expected)); + ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected())); AssertTypeEqual(*expected->type(), *array->type()); AssertArraysEqual(*expected, *array, true); } @@ -2556,7 +2707,7 @@ class TestArrayRoundtrip : public ::testing::Test { SchemaExportGuard schema_guard(&c_schema); auto orig_bytes = pool_->bytes_allocated(); - ASSERT_OK(factory(&batch)); + ASSERT_OK_AND_ASSIGN(batch, ToResult(factory())); ASSERT_OK(ExportSchema(*batch->schema(), &c_schema)); ASSERT_OK(ExportRecordBatch(*batch, &c_array)); @@ -2579,7 +2730,7 @@ class TestArrayRoundtrip : public ::testing::Test { // Check value of imported record batch { std::shared_ptr expected; - ASSERT_OK(factory(&expected)); + ASSERT_OK_AND_ASSIGN(expected, ToResult(factory())); AssertSchemaEqual(*expected->schema(), *batch->schema()); AssertBatchesEqual(*expected, *batch); } @@ -2621,15 +2772,15 @@ TEST_F(TestArrayRoundtrip, Primitive) { } TEST_F(TestArrayRoundtrip, UnknownNullCount) { - TestWithArrayFactory([](std::shared_ptr* arr) -> Status { - *arr = ArrayFromJSON(int32(), "[0, 1, 2]"); - if ((*arr)->null_bitmap()) { + TestWithArrayFactory([]() -> Result> { + auto arr = ArrayFromJSON(int32(), "[0, 1, 2]"); + if (arr->null_bitmap()) { return Status::Invalid( "Failed precondition: " "the array shouldn't have a null bitmap."); } - (*arr)->data()->SetNullCount(kUnknownNullCount); - return Status::OK(); + arr->data()->SetNullCount(kUnknownNullCount); + return arr; }); } @@ -2670,30 +2821,62 @@ TEST_F(TestArrayRoundtrip, Nested) { TEST_F(TestArrayRoundtrip, Dictionary) { { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), - indices, values) - .Value(out); + indices, values); }; TestWithArrayFactory(factory); TestWithArrayFactory(SlicedArrayFactory(factory)); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays( - dictionary(indices->type(), values->type(), /*ordered=*/true), indices, - values) - .Value(out); + dictionary(indices->type(), values->type(), /*ordered=*/true), indices, values); }; TestWithArrayFactory(factory); TestWithArrayFactory(SlicedArrayFactory(factory)); } } +TEST_F(TestArrayRoundtrip, RegisteredExtension) { + ExtensionTypeGuard guard({smallint(), complex128(), dict_extension_type(), uuid()}); + + TestWithArrayFactory(ExampleSmallint); + TestWithArrayFactory(ExampleUuid); + TestWithArrayFactory(ExampleComplex128); + TestWithArrayFactory(ExampleDictExtension); + + // Nested inside outer array + auto NestedFactory = [](ArrayFactory factory) { + return [factory]() -> Result> { + ARROW_ASSIGN_OR_RAISE(auto arr, ToResult(factory())); + return FixedSizeListArray::FromArrays(arr, /*list_size=*/1); + }; + }; + TestWithArrayFactory(NestedFactory(ExampleSmallint)); + TestWithArrayFactory(NestedFactory(ExampleUuid)); + TestWithArrayFactory(NestedFactory(ExampleComplex128)); + TestWithArrayFactory(NestedFactory(ExampleDictExtension)); +} + +TEST_F(TestArrayRoundtrip, UnregisteredExtension) { + auto StorageExtractor = [](ArrayFactory factory) { + return [factory]() -> Result> { + ARROW_ASSIGN_OR_RAISE(auto arr, ToResult(factory())); + return checked_cast(*arr).storage(); + }; + }; + + TestWithArrayFactory(ExampleSmallint, StorageExtractor(ExampleSmallint)); + TestWithArrayFactory(ExampleUuid, StorageExtractor(ExampleUuid)); + TestWithArrayFactory(ExampleComplex128, StorageExtractor(ExampleComplex128)); + TestWithArrayFactory(ExampleDictExtension, StorageExtractor(ExampleDictExtension)); +} + TEST_F(TestArrayRoundtrip, RecordBatch) { auto schema = ::arrow::schema( {field("ints", int16()), field("bools", boolean(), /*nullable=*/false)}); @@ -2701,22 +2884,18 @@ TEST_F(TestArrayRoundtrip, RecordBatch) { auto arr1 = ArrayFromJSON(boolean(), "[false, true, false]"); { - auto factory = [&](std::shared_ptr* out) -> Status { - *out = RecordBatch::Make(schema, 3, {arr0, arr1}); - return Status::OK(); - }; + auto factory = [&]() { return RecordBatch::Make(schema, 3, {arr0, arr1}); }; TestWithBatchFactory(factory); } { // With schema and field metadata - auto factory = [&](std::shared_ptr* out) -> Status { + auto factory = [&]() { auto f0 = schema->field(0); auto f1 = schema->field(1); f1 = f1->WithMetadata(key_value_metadata(kMetadataKeys1, kMetadataValues1)); auto schema_with_md = ::arrow::schema({f0, f1}, key_value_metadata(kMetadataKeys2, kMetadataValues2)); - *out = RecordBatch::Make(schema_with_md, 3, {arr0, arr1}); - return Status::OK(); + return RecordBatch::Make(schema_with_md, 3, {arr0, arr1}); }; TestWithBatchFactory(factory); } diff --git a/cpp/src/arrow/extension_type_test.cc b/cpp/src/arrow/extension_type_test.cc index cd1c3b9790e76..31222d74806f4 100644 --- a/cpp/src/arrow/extension_type_test.cc +++ b/cpp/src/arrow/extension_type_test.cc @@ -325,10 +325,12 @@ TEST_F(TestExtensionType, ValidateExtensionArray) { auto p1_type = std::make_shared(6); auto ext_arr2 = ExampleParametric(p1_type, "[null, 1, 2, 3]"); auto ext_arr3 = ExampleStruct(); + auto ext_arr4 = ExampleComplex128(); ASSERT_OK(ext_arr1->ValidateFull()); ASSERT_OK(ext_arr2->ValidateFull()); ASSERT_OK(ext_arr3->ValidateFull()); + ASSERT_OK(ext_arr4->ValidateFull()); } } // namespace arrow diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 245534b1d5c26..d7b7fb54eaffb 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -355,20 +355,18 @@ const std::vector kBatchCases = { &MakeFloatBatch, &MakeIntervals, &MakeUuid, + &MakeComplex128, &MakeDictExtension}; static int g_file_number = 0; class ExtensionTypesMixin { public: - ExtensionTypesMixin() { - // Register the extension types required to ensure roundtripping - ext_guards_.emplace_back(uuid()); - ext_guards_.emplace_back(dict_extension_type()); - } + // Register the extension types required to ensure roundtripping + ExtensionTypesMixin() : ext_guard_({uuid(), dict_extension_type(), complex128()}) {} protected: - std::vector ext_guards_; + ExtensionTypeGuard ext_guard_; }; class IpcTestFixture : public io::MemoryMapFixture, public ExtensionTypesMixin { diff --git a/cpp/src/arrow/ipc/test_common.cc b/cpp/src/arrow/ipc/test_common.cc index c93f1f60e6e5a..5068eca001ac6 100644 --- a/cpp/src/arrow/ipc/test_common.cc +++ b/cpp/src/arrow/ipc/test_common.cc @@ -985,6 +985,23 @@ Status MakeUuid(std::shared_ptr* out) { return Status::OK(); } +Status MakeComplex128(std::shared_ptr* out) { + auto type = complex128(); + auto storage_type = checked_cast(*type).storage_type(); + + auto f0 = field("f0", type); + auto f1 = field("f1", type, /*nullable=*/false); + auto schema = ::arrow::schema({f0, f1}); + + auto a0 = ExtensionType::WrapArray(complex128(), + ArrayFromJSON(storage_type, "[[1.0, -2.5], null]")); + auto a1 = ExtensionType::WrapArray( + complex128(), ArrayFromJSON(storage_type, "[[1.0, -2.5], [3.0, -4.0]]")); + + *out = RecordBatch::Make(schema, a1->length(), {a0, a1}); + return Status::OK(); +} + Status MakeDictExtension(std::shared_ptr* out) { auto type = dict_extension_type(); auto storage_type = checked_cast(*type).storage_type(); diff --git a/cpp/src/arrow/ipc/test_common.h b/cpp/src/arrow/ipc/test_common.h index 2217bae39fcb3..48df28b2d5a28 100644 --- a/cpp/src/arrow/ipc/test_common.h +++ b/cpp/src/arrow/ipc/test_common.h @@ -159,6 +159,9 @@ Status MakeNull(std::shared_ptr* out); ARROW_TESTING_EXPORT Status MakeUuid(std::shared_ptr* out); +ARROW_TESTING_EXPORT +Status MakeComplex128(std::shared_ptr* out); + ARROW_TESTING_EXPORT Status MakeDictExtension(std::shared_ptr* out); diff --git a/cpp/src/arrow/testing/extension_type.h b/cpp/src/arrow/testing/extension_type.h index 4163c9d8358f9..5afe23400767b 100644 --- a/cpp/src/arrow/testing/extension_type.h +++ b/cpp/src/arrow/testing/extension_type.h @@ -19,6 +19,7 @@ #include #include +#include #include "arrow/extension_type.h" #include "arrow/testing/visibility.h" @@ -87,6 +88,30 @@ class ARROW_TESTING_EXPORT DictExtensionType : public ExtensionType { std::string Serialize() const override { return "dict-extension-serialized"; } }; +class ARROW_TESTING_EXPORT Complex128Array : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +class ARROW_TESTING_EXPORT Complex128Type : public ExtensionType { + public: + Complex128Type() + : ExtensionType(struct_({::arrow::field("real", float64(), /*nullable=*/false), + ::arrow::field("imag", float64(), /*nullable=*/false)})) {} + + std::string extension_name() const override { return "complex128"; } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized) const override; + + std::string Serialize() const override { return "complex128-serialized"; } +}; + ARROW_TESTING_EXPORT std::shared_ptr uuid(); @@ -96,24 +121,38 @@ std::shared_ptr smallint(); ARROW_TESTING_EXPORT std::shared_ptr dict_extension_type(); +ARROW_TESTING_EXPORT +std::shared_ptr complex128(); + ARROW_TESTING_EXPORT std::shared_ptr ExampleUuid(); ARROW_TESTING_EXPORT std::shared_ptr ExampleSmallint(); +ARROW_TESTING_EXPORT +std::shared_ptr ExampleDictExtension(); + +ARROW_TESTING_EXPORT +std::shared_ptr ExampleComplex128(); + +ARROW_TESTING_EXPORT +std::shared_ptr MakeComplex128(const std::shared_ptr& real, + const std::shared_ptr& imag); + // A RAII class that registers an extension type on construction // and unregisters it on destruction. class ARROW_TESTING_EXPORT ExtensionTypeGuard { public: explicit ExtensionTypeGuard(const std::shared_ptr& type); + explicit ExtensionTypeGuard(const DataTypeVector& types); ~ExtensionTypeGuard(); ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionTypeGuard); protected: ARROW_DISALLOW_COPY_AND_ASSIGN(ExtensionTypeGuard); - std::string extension_name_; + std::vector extension_names_; }; } // namespace arrow diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 3e7c9b78c6b24..587154c1f3048 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -811,6 +811,28 @@ Result> DictExtensionType::Deserialize( return std::make_shared(); } +bool Complex128Type::ExtensionEquals(const ExtensionType& other) const { + return (other.extension_name() == this->extension_name()); +} + +std::shared_ptr Complex128Type::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK(ExtensionEquals(checked_cast(*data->type))); + return std::make_shared(data); +} + +Result> Complex128Type::Deserialize( + std::shared_ptr storage_type, const std::string& serialized) const { + if (serialized != "complex128-serialized") { + return Status::Invalid("Type identifier did not match: '", serialized, "'"); + } + if (!storage_type->Equals(*storage_type_)) { + return Status::Invalid("Invalid storage type for Complex128Type: ", + storage_type->ToString()); + } + return std::make_shared(); +} + std::shared_ptr uuid() { return std::make_shared(); } std::shared_ptr smallint() { return std::make_shared(); } @@ -819,40 +841,58 @@ std::shared_ptr dict_extension_type() { return std::make_shared(); } -std::shared_ptr ExampleUuid() { - auto storage_type = fixed_size_binary(16); - auto ext_type = uuid(); +std::shared_ptr complex128() { return std::make_shared(); } +std::shared_ptr MakeComplex128(const std::shared_ptr& real, + const std::shared_ptr& imag) { + auto type = complex128(); + std::shared_ptr storage( + new StructArray(checked_cast(*type).storage_type(), + real->length(), {real, imag})); + return ExtensionType::WrapArray(type, storage); +} + +std::shared_ptr ExampleUuid() { auto arr = ArrayFromJSON( - storage_type, + fixed_size_binary(16), "[null, \"abcdefghijklmno0\", \"abcdefghijklmno1\", \"abcdefghijklmno2\"]"); - - auto ext_data = arr->data()->Copy(); - ext_data->type = ext_type; - return MakeArray(ext_data); + return ExtensionType::WrapArray(uuid(), arr); } std::shared_ptr ExampleSmallint() { - auto storage_type = int16(); - auto ext_type = smallint(); - auto arr = ArrayFromJSON(storage_type, "[-32768, null, 1, 2, 3, 4, 32767]"); - auto ext_data = arr->data()->Copy(); - ext_data->type = ext_type; - return MakeArray(ext_data); + auto arr = ArrayFromJSON(int16(), "[-32768, null, 1, 2, 3, 4, 32767]"); + return ExtensionType::WrapArray(smallint(), arr); } -ExtensionTypeGuard::ExtensionTypeGuard(const std::shared_ptr& type) { - ARROW_CHECK_EQ(type->id(), Type::EXTENSION); - auto ext_type = checked_pointer_cast(type); +std::shared_ptr ExampleDictExtension() { + auto arr = DictArrayFromJSON(dictionary(int8(), utf8()), "[0, 1, null, 1]", + R"(["foo", "bar"])"); + return ExtensionType::WrapArray(dict_extension_type(), arr); +} + +std::shared_ptr ExampleComplex128() { + auto arr = ArrayFromJSON(struct_({field("", float64()), field("", float64())}), + "[[1.0, -2.5], null, [3.0, -4.5]]"); + return ExtensionType::WrapArray(complex128(), arr); +} - ARROW_CHECK_OK(RegisterExtensionType(ext_type)); - extension_name_ = ext_type->extension_name(); - DCHECK(!extension_name_.empty()); +ExtensionTypeGuard::ExtensionTypeGuard(const std::shared_ptr& type) + : ExtensionTypeGuard(DataTypeVector{type}) {} + +ExtensionTypeGuard::ExtensionTypeGuard(const DataTypeVector& types) { + for (const auto& type : types) { + ARROW_CHECK_EQ(type->id(), Type::EXTENSION); + auto ext_type = checked_pointer_cast(type); + + ARROW_CHECK_OK(RegisterExtensionType(ext_type)); + extension_names_.push_back(ext_type->extension_name()); + DCHECK(!extension_names_.back().empty()); + } } ExtensionTypeGuard::~ExtensionTypeGuard() { - if (!extension_name_.empty()) { - ARROW_CHECK_OK(UnregisterExtensionType(extension_name_)); + for (const auto& name : extension_names_) { + ARROW_CHECK_OK(UnregisterExtensionType(name)); } } diff --git a/cpp/src/arrow/testing/json_integration_test.cc b/cpp/src/arrow/testing/json_integration_test.cc index 34b871c56c1ab..55620119550e0 100644 --- a/cpp/src/arrow/testing/json_integration_test.cc +++ b/cpp/src/arrow/testing/json_integration_test.cc @@ -197,8 +197,7 @@ Status RunCommand(const std::string& json_path, const std::string& arrow_path, const std::string& command) { // Make sure the required extension types are registered, as they will be // referenced in test data. - ExtensionTypeGuard uuid_ext_guard(uuid()); - ExtensionTypeGuard dict_ext_guard(dict_extension_type()); + ExtensionTypeGuard ext_guard({uuid(), dict_extension_type()}); if (json_path == "") { return Status::Invalid("Must specify json file name"); @@ -1105,8 +1104,7 @@ class TestJsonRoundTrip : public ::testing::TestWithParam { }; void CheckRoundtrip(const RecordBatch& batch) { - ExtensionTypeGuard uuid_ext_guard(uuid()); - ExtensionTypeGuard dict_ext_guard(dict_extension_type()); + ExtensionTypeGuard guard({uuid(), dict_extension_type(), complex128()}); TestSchemaRoundTrip(*batch.schema()); @@ -1160,6 +1158,7 @@ const std::vector kBatchCases = { &MakeFloatBatch, &MakeIntervals, &MakeUuid, + &MakeComplex128, &MakeDictExtension}; INSTANTIATE_TEST_SUITE_P(TestJsonRoundTrip, TestJsonRoundTrip, diff --git a/cpp/src/arrow/util/key_value_metadata.cc b/cpp/src/arrow/util/key_value_metadata.cc index ad3b686a9bdb5..fd179a8bf3877 100644 --- a/cpp/src/arrow/util/key_value_metadata.cc +++ b/cpp/src/arrow/util/key_value_metadata.cc @@ -56,8 +56,6 @@ static std::vector UnorderedMapValues( return values; } -KeyValueMetadata::KeyValueMetadata() : keys_(), values_() {} - KeyValueMetadata::KeyValueMetadata( const std::unordered_map& map) : keys_(UnorderedMapKeys(map)), values_(UnorderedMapValues(map)) { @@ -85,9 +83,9 @@ void KeyValueMetadata::ToUnorderedMap( } } -void KeyValueMetadata::Append(const std::string& key, const std::string& value) { - keys_.push_back(key); - values_.push_back(value); +void KeyValueMetadata::Append(std::string key, std::string value) { + keys_.push_back(std::move(key)); + values_.push_back(std::move(value)); } Result KeyValueMetadata::Get(const std::string& key) const { diff --git a/cpp/src/arrow/util/key_value_metadata.h b/cpp/src/arrow/util/key_value_metadata.h index d42ab78f6671b..2a31bf378b0da 100644 --- a/cpp/src/arrow/util/key_value_metadata.h +++ b/cpp/src/arrow/util/key_value_metadata.h @@ -34,16 +34,15 @@ namespace arrow { /// \brief A container for key-value pair type metadata. Not thread-safe class ARROW_EXPORT KeyValueMetadata { public: - KeyValueMetadata(); + KeyValueMetadata() = default; KeyValueMetadata(std::vector keys, std::vector values); explicit KeyValueMetadata(const std::unordered_map& map); - virtual ~KeyValueMetadata() = default; static std::shared_ptr Make(std::vector keys, std::vector values); void ToUnorderedMap(std::unordered_map* out) const; - void Append(const std::string& key, const std::string& value); + void Append(std::string key, std::string value); Result Get(const std::string& key) const; bool Contains(const std::string& key) const; diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 4f9f4184b2d36..8ed7090ab5fc3 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -146,6 +146,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: Type id() c_bool Equals(const CDataType& other) + c_bool Equals(const shared_ptr[CDataType]& other) shared_ptr[CField] field(int i) const vector[shared_ptr[CField]] fields() @@ -2341,6 +2342,14 @@ cdef extern from 'arrow/extension_type.h' namespace 'arrow': c_string extension_name() shared_ptr[CDataType] storage_type() + @staticmethod + shared_ptr[CArray] WrapArray(shared_ptr[CDataType] ext_type, + shared_ptr[CArray] storage) + + @staticmethod + shared_ptr[CChunkedArray] WrapArray(shared_ptr[CDataType] ext_type, + shared_ptr[CChunkedArray] storage) + cdef cppclass CExtensionArray" arrow::ExtensionArray"(CArray): CExtensionArray(shared_ptr[CDataType], shared_ptr[CArray] storage) diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py index 2ac30fd2cf201..f0ce42909f100 100644 --- a/python/pyarrow/tests/test_cffi.py +++ b/python/pyarrow/tests/test_cffi.py @@ -47,15 +47,41 @@ ValueError, match="Cannot import released ArrowArrayStream") +class ParamExtType(pa.PyExtensionType): + + def __init__(self, width): + self._width = width + pa.PyExtensionType.__init__(self, pa.binary(width)) + + @property + def width(self): + return self._width + + def __reduce__(self): + return ParamExtType, (self.width,) + + def make_schema(): return pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}) +def make_extension_schema(): + return pa.schema([('ext', ParamExtType(3))], + metadata={b'key1': b'value1'}) + + def make_batch(): return pa.record_batch([[[1], [2, 42]]], make_schema()) +def make_extension_batch(): + schema = make_extension_schema() + ext_col = schema[0].type.wrap_array(pa.array([b"foo", b"bar"], + type=pa.binary(3))) + return pa.record_batch([ext_col], schema) + + def make_batches(): schema = make_schema() return [ @@ -174,19 +200,18 @@ def test_export_import_array(): pa.Array._import_from_c(ptr_array, ptr_schema) -@needs_cffi -def test_export_import_schema(): +def check_export_import_schema(schema_factory): c_schema = ffi.new("struct ArrowSchema*") ptr_schema = int(ffi.cast("uintptr_t", c_schema)) gc.collect() # Make sure no Arrow data dangles in a ref cycle old_allocated = pa.total_allocated_bytes() - make_schema()._export_to_c(ptr_schema) + schema_factory()._export_to_c(ptr_schema) assert pa.total_allocated_bytes() > old_allocated # Delete and recreate C++ object from exported pointer schema_new = pa.Schema._import_from_c(ptr_schema) - assert schema_new == make_schema() + assert schema_new == schema_factory() assert pa.total_allocated_bytes() == old_allocated del schema_new assert pa.total_allocated_bytes() == old_allocated @@ -205,7 +230,16 @@ def test_export_import_schema(): @needs_cffi -def test_export_import_batch(): +def test_export_import_schema(): + check_export_import_schema(make_schema) + + +@needs_cffi +def test_export_import_schema_with_extension(): + check_export_import_schema(make_extension_schema) + + +def check_export_import_batch(batch_factory): c_schema = ffi.new("struct ArrowSchema*") ptr_schema = int(ffi.cast("uintptr_t", c_schema)) c_array = ffi.new("struct ArrowArray*") @@ -215,8 +249,8 @@ def test_export_import_batch(): old_allocated = pa.total_allocated_bytes() # Schema is known up front - schema = make_schema() - batch = make_batch() + batch = batch_factory() + schema = batch.schema py_value = batch.to_pydict() batch._export_to_c(ptr_array) assert pa.total_allocated_bytes() > old_allocated @@ -233,14 +267,14 @@ def test_export_import_batch(): pa.RecordBatch._import_from_c(ptr_array, make_schema()) # Type is exported and imported at the same time - batch = make_batch() + batch = batch_factory() py_value = batch.to_pydict() batch._export_to_c(ptr_array, ptr_schema) # Delete and recreate C++ objects from exported pointers del batch batch_new = pa.RecordBatch._import_from_c(ptr_array, ptr_schema) assert batch_new.to_pydict() == py_value - assert batch_new.schema == make_schema() + assert batch_new.schema == batch_factory().schema assert pa.total_allocated_bytes() > old_allocated del batch_new assert pa.total_allocated_bytes() == old_allocated @@ -250,7 +284,7 @@ def test_export_import_batch(): # Not a struct type pa.int32()._export_to_c(ptr_schema) - make_batch()._export_to_c(ptr_array) + batch_factory()._export_to_c(ptr_array) with pytest.raises(ValueError, match="ArrowSchema describes non-struct type"): pa.RecordBatch._import_from_c(ptr_array, ptr_schema) @@ -259,6 +293,16 @@ def test_export_import_batch(): pa.RecordBatch._import_from_c(ptr_array, ptr_schema) +@needs_cffi +def test_export_import_batch(): + check_export_import_batch(make_batch) + + +@needs_cffi +def test_export_import_batch_with_extension(): + check_export_import_batch(make_extension_batch) + + def _export_import_batch_reader(ptr_stream, reader_factory): # Prepare input batches = make_batches() diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 391149772ccf9..d166c2af83e65 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -201,6 +201,34 @@ def test_ext_array_equality(): assert not d.equals(f) +def test_ext_array_wrap_array(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar", None], type=pa.binary(3)) + arr = ty.wrap_array(storage) + arr.validate(full=True) + assert isinstance(arr, pa.ExtensionArray) + assert arr.type == ty + assert arr.storage == storage + + storage = pa.chunked_array([[b"abc", b"def"], [b"ghi"]], + type=pa.binary(3)) + arr = ty.wrap_array(storage) + arr.validate(full=True) + assert isinstance(arr, pa.ChunkedArray) + assert arr.type == ty + assert arr.chunk(0).storage == storage.chunk(0) + assert arr.chunk(1).storage == storage.chunk(1) + + # Wrong storage type + storage = pa.array([b"foo", b"bar", None]) + with pytest.raises(TypeError, match="Incompatible storage type"): + ty.wrap_array(storage) + + # Not an array or chunked array + with pytest.raises(TypeError, match="Expected array or chunked array"): + ty.wrap_array(None) + + def test_ext_scalar_from_array(): data = [b"0123456789abcdef", b"0123456789abcdef", b"zyxwvutsrqponmlk", None] diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 06f753fa18d35..5b478ed7746dd 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -708,6 +708,45 @@ cdef class BaseExtensionType(DataType): """ return pyarrow_wrap_data_type(self.ext_type.storage_type()) + def wrap_array(self, storage): + """ + Wrap the given storage array as an extension array. + + Parameters + ---------- + storage : Array or ChunkedArray + + Returns + ------- + array : Array or ChunkedArray + Extension array wrapping the storage array + """ + cdef: + shared_ptr[CDataType] c_storage_type + + if isinstance(storage, Array): + c_storage_type = ( storage).ap.type() + elif isinstance(storage, ChunkedArray): + c_storage_type = ( storage).chunked_array.type() + else: + raise TypeError( + f"Expected array or chunked array, got {storage.__class__}") + + if not c_storage_type.get().Equals(deref(self.ext_type) + .storage_type()): + raise TypeError( + f"Incompatible storage type for {self}: " + f"expected {self.storage_type}, got {storage.type}") + + if isinstance(storage, Array): + return pyarrow_wrap_array( + self.ext_type.WrapArray( + self.sp_type, ( storage).sp_array)) + else: + return pyarrow_wrap_chunked_array( + self.ext_type.WrapArray( + self.sp_type, ( storage).sp_chunked_array)) + cdef class ExtensionType(BaseExtensionType): """