From 41fa07b442807c2b3af7b3023004562d1073eff2 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 5 Feb 2024 17:15:44 +0100 Subject: [PATCH] GH-39865: [C++] Strip extension metadata when importing a registered extension (#39866) ### Rationale for this change When importing an extension type from the C Data Interface and the extension type is registered, we would still leave the extension-related metadata on the storage type. ### What changes are included in this PR? Strip extension-related metadata on the storage type if we succeed in recreating the extension type. This matches the behavior of the IPC layer and allows for more exact roundtripping. ### Are these changes tested? Yes. ### Are there any user-facing changes? No, unless people mistakingly rely on the presence of said metadata. * Closes: #39865 Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/c/bridge.cc | 6 +++ cpp/src/arrow/c/bridge_test.cc | 48 ++++++++++++++++-------- cpp/src/arrow/util/key_value_metadata.cc | 18 ++++----- cpp/src/arrow/util/key_value_metadata.h | 11 +++--- 4 files changed, 52 insertions(+), 31 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 172ed8962ce77..9b165a10a61e7 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -914,6 +914,8 @@ struct DecodedMetadata { std::shared_ptr metadata; std::string extension_name; std::string extension_serialized; + int extension_name_index = -1; // index of extension_name in metadata + int extension_serialized_index = -1; // index of extension_serialized in metadata }; Result DecodeMetadata(const char* metadata) { @@ -956,8 +958,10 @@ Result DecodeMetadata(const char* metadata) { RETURN_NOT_OK(read_string(&values[i])); if (keys[i] == kExtensionTypeKeyName) { decoded.extension_name = values[i]; + decoded.extension_name_index = i; } else if (keys[i] == kExtensionMetadataKeyName) { decoded.extension_serialized = values[i]; + decoded.extension_serialized_index = i; } } decoded.metadata = key_value_metadata(std::move(keys), std::move(values)); @@ -1046,6 +1050,8 @@ struct SchemaImporter { ARROW_ASSIGN_OR_RAISE( type_, registered_ext_type->Deserialize(std::move(type_), metadata_.extension_serialized)); + RETURN_NOT_OK(metadata_.metadata->DeleteMany( + {metadata_.extension_name_index, metadata_.extension_serialized_index})); } } diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index 321ec36c38d8c..8b67027454c55 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -1872,7 +1872,7 @@ class TestSchemaImport : public ::testing::Test, public SchemaStructBuilder { ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_)); Reset(); // for further tests cb.AssertCalled(); // was released - AssertTypeEqual(*expected, *type); + AssertTypeEqual(*expected, *type, /*check_metadata=*/true); } void CheckImport(const std::shared_ptr& expected) { @@ -1892,7 +1892,7 @@ class TestSchemaImport : public ::testing::Test, public SchemaStructBuilder { ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_)); Reset(); // for further tests cb.AssertCalled(); // was released - AssertSchemaEqual(*expected, *schema); + AssertSchemaEqual(*expected, *schema, /*check_metadata=*/true); } void CheckImportError() { @@ -3571,7 +3571,7 @@ class TestSchemaRoundtrip : public ::testing::Test { // Recreate the type ASSERT_OK_AND_ASSIGN(actual, ImportType(&c_schema)); type = factory_expected(); - AssertTypeEqual(*type, *actual); + AssertTypeEqual(*type, *actual, /*check_metadata=*/true); type.reset(); actual.reset(); @@ -3602,7 +3602,7 @@ class TestSchemaRoundtrip : public ::testing::Test { // Recreate the schema ASSERT_OK_AND_ASSIGN(actual, ImportSchema(&c_schema)); schema = factory(); - AssertSchemaEqual(*schema, *actual); + AssertSchemaEqual(*schema, *actual, /*check_metadata=*/true); schema.reset(); actual.reset(); @@ -3695,13 +3695,27 @@ TEST_F(TestSchemaRoundtrip, Dictionary) { } } +// Given an extension type, return a field of its storage type + the +// serialized extension metadata. +std::shared_ptr GetStorageWithMetadata(const std::string& field_name, + const std::shared_ptr& type) { + const auto& ext_type = checked_cast(*type); + auto storage_type = ext_type.storage_type(); + auto md = KeyValueMetadata::Make({kExtensionTypeKeyName, kExtensionMetadataKeyName}, + {ext_type.extension_name(), ext_type.Serialize()}); + return field(field_name, storage_type, /*nullable=*/true, md); +} + TEST_F(TestSchemaRoundtrip, UnregisteredExtension) { TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); }); TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), utf8()); }); - // Inside nested type - TestWithTypeFactory([]() { return list(dict_extension_type()); }, - []() { return list(dictionary(int8(), utf8())); }); + // Inside nested type. + // When an extension type is not known by the importer, it is imported + // as its storage type and the extension metadata is preserved on the field. + TestWithTypeFactory( + []() { return list(dict_extension_type()); }, + []() { return list(GetStorageWithMetadata("item", dict_extension_type())); }); } TEST_F(TestSchemaRoundtrip, RegisteredExtension) { @@ -3710,7 +3724,9 @@ TEST_F(TestSchemaRoundtrip, RegisteredExtension) { TestWithTypeFactory(dict_extension_type); TestWithTypeFactory(complex128); - // Inside nested type + // Inside nested type. + // When the extension type is registered, the extension metadata is removed + // from the storage type's field to ensure roundtripping (GH-39865). TestWithTypeFactory([]() { return list(uuid()); }); TestWithTypeFactory([]() { return list(dict_extension_type()); }); TestWithTypeFactory([]() { return list(complex128()); }); @@ -3810,7 +3826,7 @@ class TestArrayRoundtrip : public ::testing::Test { { std::shared_ptr expected; ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected())); - AssertTypeEqual(*expected->type(), *array->type()); + AssertTypeEqual(*expected->type(), *array->type(), /*check_metadata=*/true); AssertArraysEqual(*expected, *array, true); } array.reset(); @@ -3850,7 +3866,7 @@ class TestArrayRoundtrip : public ::testing::Test { { std::shared_ptr expected; ASSERT_OK_AND_ASSIGN(expected, ToResult(factory())); - AssertSchemaEqual(*expected->schema(), *batch->schema()); + AssertSchemaEqual(*expected->schema(), *batch->schema(), /*check_metadata=*/true); AssertBatchesEqual(*expected, *batch); } batch.reset(); @@ -4230,7 +4246,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test { { std::shared_ptr expected; ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected())); - AssertTypeEqual(*expected->type(), *array->type()); + AssertTypeEqual(*expected->type(), *array->type(), /*check_metadata=*/true); AssertArraysEqual(*expected, *array, true); } array.reset(); @@ -4276,7 +4292,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test { { std::shared_ptr expected; ASSERT_OK_AND_ASSIGN(expected, ToResult(factory())); - AssertSchemaEqual(*expected->schema(), *batch->schema()); + AssertSchemaEqual(*expected->schema(), *batch->schema(), /*check_metadata=*/true); AssertBatchesEqual(*expected, *batch); } batch.reset(); @@ -4353,7 +4369,7 @@ class TestArrayStreamExport : public BaseArrayStreamTest { SchemaExportGuard schema_guard(&c_schema); ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema)); ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema)); - AssertSchemaEqual(expected, *schema); + AssertSchemaEqual(expected, *schema, /*check_metadata=*/true); } void AssertStreamEnd(struct ArrowArrayStream* c_stream) { @@ -4437,7 +4453,7 @@ TEST_F(TestArrayStreamExport, ArrayLifetime) { { SchemaExportGuard schema_guard(&c_schema); ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema)); - AssertSchemaEqual(*schema, *got_schema); + AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true); } ASSERT_GT(pool_->bytes_allocated(), orig_allocated_); @@ -4462,7 +4478,7 @@ TEST_F(TestArrayStreamExport, Errors) { { SchemaExportGuard schema_guard(&c_schema); ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema)); - AssertSchemaEqual(schema, arrow::schema({})); + AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true); } struct ArrowArray c_array; @@ -4539,7 +4555,7 @@ TEST_F(TestArrayStreamRoundtrip, Simple) { ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches, orig_schema)); Roundtrip(std::move(reader), [&](const std::shared_ptr& reader) { - AssertSchemaEqual(*orig_schema, *reader->schema()); + AssertSchemaEqual(*orig_schema, *reader->schema(), /*check_metadata=*/true); AssertReaderNext(reader, *batches[0]); AssertReaderNext(reader, *batches[1]); AssertReaderEnd(reader); diff --git a/cpp/src/arrow/util/key_value_metadata.cc b/cpp/src/arrow/util/key_value_metadata.cc index bc48ae76c2a2f..002e8b0975094 100644 --- a/cpp/src/arrow/util/key_value_metadata.cc +++ b/cpp/src/arrow/util/key_value_metadata.cc @@ -90,7 +90,7 @@ void KeyValueMetadata::Append(std::string key, std::string value) { values_.push_back(std::move(value)); } -Result KeyValueMetadata::Get(const std::string& key) const { +Result KeyValueMetadata::Get(std::string_view key) const { auto index = FindKey(key); if (index < 0) { return Status::KeyError(key); @@ -129,7 +129,7 @@ Status KeyValueMetadata::DeleteMany(std::vector indices) { return Status::OK(); } -Status KeyValueMetadata::Delete(const std::string& key) { +Status KeyValueMetadata::Delete(std::string_view key) { auto index = FindKey(key); if (index < 0) { return Status::KeyError(key); @@ -138,20 +138,18 @@ Status KeyValueMetadata::Delete(const std::string& key) { } } -Status KeyValueMetadata::Set(const std::string& key, const std::string& value) { +Status KeyValueMetadata::Set(std::string key, std::string value) { auto index = FindKey(key); if (index < 0) { - Append(key, value); + Append(std::move(key), std::move(value)); } else { - keys_[index] = key; - values_[index] = value; + keys_[index] = std::move(key); + values_[index] = std::move(value); } return Status::OK(); } -bool KeyValueMetadata::Contains(const std::string& key) const { - return FindKey(key) >= 0; -} +bool KeyValueMetadata::Contains(std::string_view key) const { return FindKey(key) >= 0; } void KeyValueMetadata::reserve(int64_t n) { DCHECK_GE(n, 0); @@ -188,7 +186,7 @@ std::vector> KeyValueMetadata::sorted_pairs( return pairs; } -int KeyValueMetadata::FindKey(const std::string& key) const { +int KeyValueMetadata::FindKey(std::string_view key) const { for (size_t i = 0; i < keys_.size(); ++i) { if (keys_[i] == key) { return static_cast(i); diff --git a/cpp/src/arrow/util/key_value_metadata.h b/cpp/src/arrow/util/key_value_metadata.h index 8702ce73a639a..57ade11e75868 100644 --- a/cpp/src/arrow/util/key_value_metadata.h +++ b/cpp/src/arrow/util/key_value_metadata.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -44,13 +45,13 @@ class ARROW_EXPORT KeyValueMetadata { void ToUnorderedMap(std::unordered_map* out) const; void Append(std::string key, std::string value); - Result Get(const std::string& key) const; - bool Contains(const std::string& key) const; + Result Get(std::string_view key) const; + bool Contains(std::string_view key) const; // Note that deleting may invalidate known indices - Status Delete(const std::string& key); + Status Delete(std::string_view key); Status Delete(int64_t index); Status DeleteMany(std::vector indices); - Status Set(const std::string& key, const std::string& value); + Status Set(std::string key, std::string value); void reserve(int64_t n); @@ -63,7 +64,7 @@ class ARROW_EXPORT KeyValueMetadata { std::vector> sorted_pairs() const; /// \brief Perform linear search for key, returning -1 if not found - int FindKey(const std::string& key) const; + int FindKey(std::string_view key) const; std::shared_ptr Copy() const;