ARROW-13855: [C++][Python] Implement C data interface support for extension types #11071
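This patch teaches the C data interface to carry extension types: on export, an extension type is unwrapped to its storage type and two metadata entries are attached, ARROW:extension:name and ARROW:extension:metadata (the kExtensionTypeKeyName and kExtensionMetadataKeyName constants used below); on import, the name is looked up in the extension-type registry and, when a match is found, the storage type is re-wrapped via Deserialize(). A minimal round-trip sketch against the public bridge API (the helper name and the assumption that the type is registered are illustrative, not part of this patch):

#include "arrow/api.h"
#include "arrow/c/abi.h"
#include "arrow/c/bridge.h"

// Export an extension type to an ArrowSchema, then import it back.
// If the extension name is registered in this process, the imported type
// equals the original; otherwise it degrades to the storage type, with the
// extension metadata preserved.
arrow::Status RoundTripType(const std::shared_ptr<arrow::DataType>& ext_type) {
  struct ArrowSchema c_schema;
  ARROW_RETURN_NOT_OK(arrow::ExportType(*ext_type, &c_schema));
  ARROW_ASSIGN_OR_RAISE(auto imported, arrow::ImportType(&c_schema));
  return imported->Equals(*ext_type)
             ? arrow::Status::OK()
             : arrow::Status::Invalid("extension type did not round-trip");
}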

cpp/src/arrow/c/bridge.cc (128 changes: 96 additions & 32 deletions)
@@ -28,6 +28,7 @@
 #include "arrow/buffer.h"
 #include "arrow/c/helpers.h"
 #include "arrow/c/util_internal.h"
+#include "arrow/extension_type.h"
 #include "arrow/memory_pool.h"
 #include "arrow/record_batch.h"
 #include "arrow/result.h"
@@ -56,8 +57,6 @@ using internal::ArrayExportTraits;
 using internal::SchemaExportGuard;
 using internal::SchemaExportTraits;
 
-// TODO export / import Extension types and arrays
-
 namespace {
 
 Status ExportingNotImplemented(const DataType& type) {
@@ -171,23 +170,26 @@ struct SchemaExporter {
     export_.name_ = field.name();
     flags_ = field.nullable() ? ARROW_FLAG_NULLABLE : 0;
 
-    const DataType& type = *field.type();
-    RETURN_NOT_OK(ExportFormat(type));
-    RETURN_NOT_OK(ExportChildren(type.fields()));
+    const DataType* type = UnwrapExtension(field.type().get());
+    RETURN_NOT_OK(ExportFormat(*type));
+    RETURN_NOT_OK(ExportChildren(type->fields()));
     RETURN_NOT_OK(ExportMetadata(field.metadata().get()));
     return Status::OK();
   }
 
-  Status ExportType(const DataType& type) {
+  Status ExportType(const DataType& orig_type) {
     flags_ = ARROW_FLAG_NULLABLE;
 
-    RETURN_NOT_OK(ExportFormat(type));
-    RETURN_NOT_OK(ExportChildren(type.fields()));
+    const DataType* type = UnwrapExtension(&orig_type);
+    RETURN_NOT_OK(ExportFormat(*type));
+    RETURN_NOT_OK(ExportChildren(type->fields()));
+    // There may be additional metadata to export
+    RETURN_NOT_OK(ExportMetadata(nullptr));
     return Status::OK();
   }
 
   Status ExportSchema(const Schema& schema) {
-    static StructType dummy_struct_type({});
+    static const StructType dummy_struct_type({});
     flags_ = 0;
 
     RETURN_NOT_OK(ExportFormat(dummy_struct_type));
@@ -232,6 +234,17 @@ struct SchemaExporter {
     c_struct->release = ReleaseExportedSchema;
   }
 
+  const DataType* UnwrapExtension(const DataType* type) {
+    if (type->id() == Type::EXTENSION) {
+      const auto& ext_type = checked_cast<const ExtensionType&>(*type);
+      additional_metadata_.reserve(2);
+      additional_metadata_.emplace_back(kExtensionTypeKeyName, ext_type.extension_name());
+      additional_metadata_.emplace_back(kExtensionMetadataKeyName, ext_type.Serialize());
+      return ext_type.storage_type().get();
+    }
+    return type;
+  }
+
   Status ExportFormat(const DataType& type) {
     if (type.id() == Type::DICTIONARY) {
       const auto& dict_type = checked_cast<const DictionaryType&>(type);
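For context, UnwrapExtension() reads the two exported values from the ExtensionType API. A sketch of a minimal subclass supplying them (the class and the "example.uuid" name are hypothetical, not part of this patch):

#include "arrow/extension_type.h"

// Illustrative extension type over fixed_size_binary(16). extension_name()
// and Serialize() provide the metadata values that UnwrapExtension() exports;
// Deserialize() is the hook used on the import side.
class UuidType : public arrow::ExtensionType {
 public:
  UuidType() : arrow::ExtensionType(arrow::fixed_size_binary(16)) {}
  std::string extension_name() const override { return "example.uuid"; }
  bool ExtensionEquals(const arrow::ExtensionType& other) const override {
    return other.extension_name() == extension_name();
  }
  std::shared_ptr<arrow::Array> MakeArray(
      std::shared_ptr<arrow::ArrayData> data) const override {
    return std::make_shared<arrow::ExtensionArray>(data);
  }
  arrow::Result<std::shared_ptr<arrow::DataType>> Deserialize(
      std::shared_ptr<arrow::DataType> storage_type,
      const std::string& serialized) const override {
    return std::make_shared<UuidType>();  // stateless: nothing to parse
  }
  std::string Serialize() const override { return ""; }  // no parameters
};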
@@ -259,10 +272,29 @@ struct SchemaExporter {
     return Status::OK();
   }
 
-  Status ExportMetadata(const KeyValueMetadata* metadata) {
-    if (metadata != nullptr && metadata->size() >= 0) {
-      ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(*metadata));
+  Status ExportMetadata(const KeyValueMetadata* orig_metadata) {
+    static const KeyValueMetadata empty_metadata;
+
+    if (orig_metadata == nullptr) {
+      orig_metadata = &empty_metadata;
     }
+    if (additional_metadata_.empty()) {
+      if (orig_metadata->size() > 0) {
+        ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(*orig_metadata));
+      }
+      return Status::OK();
+    }
+    // Additional metadata needs to be appended to the existing
+    // (for extension types)
+    KeyValueMetadata metadata(orig_metadata->keys(), orig_metadata->values());
+    for (const auto& kv : additional_metadata_) {
+      // The metadata may already be there => ignore
+      if (metadata.Contains(kv.first)) {
+        continue;
+      }
+      metadata.Append(kv.first, kv.second);
+    }
+    ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(metadata));
     return Status::OK();
   }
 
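The merge above lets pre-existing entries win, so a field whose metadata already carries the extension keys is exported unchanged. A sketch of that precedence using the KeyValueMetadata API (values are illustrative):

// If the key is already present, ExportMetadata() keeps the caller's value
// and skips the one derived from the type.
arrow::KeyValueMetadata metadata({"ARROW:extension:name"}, {"already.there"});
if (!metadata.Contains("ARROW:extension:metadata")) {
  metadata.Append("ARROW:extension:metadata", "");  // appended only if absent
}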
@@ -442,6 +474,7 @@ struct SchemaExporter {
 
   ExportedSchemaPrivateData export_;
   int64_t flags_ = 0;
+  std::vector<std::pair<std::string, std::string>> additional_metadata_;
   std::unique_ptr<SchemaExporter> dict_exporter_;
   std::vector<SchemaExporter> child_exporters_;
 };
@@ -721,7 +754,13 @@ class FormatStringParser {
   size_t index_;
 };
 
-Result<std::shared_ptr<KeyValueMetadata>> DecodeMetadata(const char* metadata) {
+struct DecodedMetadata {
+  std::shared_ptr<KeyValueMetadata> metadata;
+  std::string extension_name;
+  std::string extension_serialized;
+};
+
+Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
   auto read_int32 = [&](int32_t* out) -> Status {
     int32_t v;
     memcpy(&v, metadata, 4);
@@ -744,21 +783,29 @@ Result<std::shared_ptr<KeyValueMetadata>> DecodeMetadata(const char* metadata) {
     return Status::OK();
   };
 
+  DecodedMetadata decoded;
+
   if (metadata == nullptr) {
-    return nullptr;
+    return decoded;
   }
   int32_t npairs;
   RETURN_NOT_OK(read_int32(&npairs));
   if (npairs == 0) {
-    return nullptr;
+    return decoded;
   }
   std::vector<std::string> keys(npairs);
   std::vector<std::string> values(npairs);
   for (int32_t i = 0; i < npairs; ++i) {
     RETURN_NOT_OK(read_string(&keys[i]));
     RETURN_NOT_OK(read_string(&values[i]));
+    if (keys[i] == kExtensionTypeKeyName) {
+      decoded.extension_name = values[i];
+    } else if (keys[i] == kExtensionMetadataKeyName) {
+      decoded.extension_serialized = values[i];
+    }
   }
-  return key_value_metadata(std::move(keys), std::move(values));
+  decoded.metadata = key_value_metadata(std::move(keys), std::move(values));
+  return decoded;
 }
 
 struct SchemaImporter {
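DecodeMetadata() parses the binary metadata layout defined by the C data interface: a native-endian int32 pair count, then each key and value as an int32 length followed by that many bytes, with no NUL terminators. A sketch of an encoder for a single pair (the helper and the example values are illustrative):

#include <cstdint>
#include <string>

// Builds the byte layout that DecodeMetadata() above would parse back into
// one key/value pair.
std::string EncodeOnePair() {
  std::string out;
  auto append_int32 = [&out](int32_t v) {
    out.append(reinterpret_cast<const char*>(&v), sizeof(v));
  };
  auto append_string = [&](const std::string& s) {
    append_int32(static_cast<int32_t>(s.size()));
    out.append(s);
  };
  append_int32(1);                        // number of key/value pairs
  append_string("ARROW:extension:name");  // kExtensionTypeKeyName
  append_string("example.uuid");          // hypothetical extension name
  return out;
}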
@@ -775,10 +822,9 @@ struct SchemaImporter {
   }
 
   Result<std::shared_ptr<Field>> MakeField() const {
-    ARROW_ASSIGN_OR_RAISE(auto metadata, DecodeMetadata(c_struct_->metadata));
     const char* name = c_struct_->name ? c_struct_->name : "";
     bool nullable = (c_struct_->flags & ARROW_FLAG_NULLABLE) != 0;
-    return field(name, type_, nullable, std::move(metadata));
+    return field(name, type_, nullable, std::move(metadata_.metadata));
   }
 
   Result<std::shared_ptr<Schema>> MakeSchema() const {
@@ -787,8 +833,7 @@ struct SchemaImporter {
           "Cannot import schema: ArrowSchema describes non-struct type ",
           type_->ToString());
     }
-    ARROW_ASSIGN_OR_RAISE(auto metadata, DecodeMetadata(c_struct_->metadata));
-    return schema(type_->fields(), std::move(metadata));
+    return schema(type_->fields(), std::move(metadata_.metadata));
   }
 
   Result<std::shared_ptr<DataType>> MakeType() const { return type_; }
@@ -836,6 +881,20 @@ struct SchemaImporter {
       bool ordered = (c_struct_->flags & ARROW_FLAG_DICTIONARY_ORDERED) != 0;
       type_ = dictionary(type_, dict_importer.type_, ordered);
     }
+
+    // Import metadata
+    ARROW_ASSIGN_OR_RAISE(metadata_, DecodeMetadata(c_struct_->metadata));
+
+    // Detect extension type
+    if (!metadata_.extension_name.empty()) {
+      const auto registered_ext_type = GetExtensionType(metadata_.extension_name);
+      if (registered_ext_type) {
+        ARROW_ASSIGN_OR_RAISE(
+            type_, registered_ext_type->Deserialize(std::move(type_),
+                                                    metadata_.extension_serialized));
+      }
+    }
+
     return Status::OK();
   }
 
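The lookup above only succeeds for names present in the process-wide registry; an unregistered name leaves type_ as the storage type, with the ARROW:extension:* pairs still visible in the imported field metadata. Registration, assuming the hypothetical UuidType sketched earlier:

// Make "example.uuid" resolvable by GetExtensionType() during import.
arrow::Status RegisterUuid() {
  return arrow::RegisterExtensionType(std::make_shared<UuidType>());
}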
@@ -1130,6 +1189,7 @@ struct SchemaImporter {
   int64_t recursion_level_;
   std::vector<SchemaImporter> child_importers_;
   std::shared_ptr<DataType> type_;
+  DecodedMetadata metadata_;
 };
 
 }  // namespace
@@ -1255,8 +1315,15 @@ struct ArrayImporter {
   }
 
   Status DoImport() {
+    // Unwrap extension type
+    const DataType* storage_type = type_.get();
+    if (storage_type->id() == Type::EXTENSION) {
+      storage_type =
+          checked_cast<const ExtensionType&>(*storage_type).storage_type().get();
+    }
+
     // First import children (required for reconstituting parent array data)
-    const auto& fields = type_->fields();
+    const auto& fields = storage_type->fields();
     if (c_struct_->n_children != static_cast<int64_t>(fields.size())) {
       return Status::Invalid("ArrowArray struct has ", c_struct_->n_children,
                              " children, expected ", fields.size(), " for type ",
@@ -1270,15 +1337,15 @@
     }
 
     // Import main data
-    RETURN_NOT_OK(ImportMainData());
+    RETURN_NOT_OK(VisitTypeInline(*storage_type, this));
 
-    bool is_dict_type = (type_->id() == Type::DICTIONARY);
+    bool is_dict_type = (storage_type->id() == Type::DICTIONARY);
     if (c_struct_->dictionary != nullptr) {
       if (!is_dict_type) {
         return Status::Invalid("Import type is ", type_->ToString(),
                                " but dictionary field in ArrowArray struct is not null");
       }
-      const auto& dict_type = checked_cast<const DictionaryType&>(*type_);
+      const auto& dict_type = checked_cast<const DictionaryType&>(*storage_type);
       // Import dictionary values
       ArrayImporter dict_importer(dict_type.value_type());
       RETURN_NOT_OK(dict_importer.ImportDict(this, c_struct_->dictionary));
@@ -1292,13 +1359,11 @@
     return Status::OK();
   }
 
-  Status ImportMainData() { return VisitTypeInline(*type_, this); }
-
   Status Visit(const DataType& type) {
     return Status::NotImplemented("Cannot import array of type ", type_->ToString());
   }
 
-  Status Visit(const FixedWidthType& type) { return ImportFixedSizePrimitive(); }
+  Status Visit(const FixedWidthType& type) { return ImportFixedSizePrimitive(type); }
 
   Status Visit(const NullType& type) {
     RETURN_NOT_OK(CheckNoChildren());
@@ -1352,16 +1417,15 @@ struct ArrayImporter {
     return Status::OK();
   }
 
-  Status ImportFixedSizePrimitive() {
-    const auto& fw_type = checked_cast<const FixedWidthType&>(*type_);
+  Status ImportFixedSizePrimitive(const FixedWidthType& type) {
     RETURN_NOT_OK(CheckNoChildren());
     RETURN_NOT_OK(CheckNumBuffers(2));
    RETURN_NOT_OK(AllocateArrayData());
     RETURN_NOT_OK(ImportNullBitmap());
-    if (BitUtil::IsMultipleOf8(fw_type.bit_width())) {
-      RETURN_NOT_OK(ImportFixedSizeBuffer(1, fw_type.bit_width() / 8));
+    if (BitUtil::IsMultipleOf8(type.bit_width())) {
+      RETURN_NOT_OK(ImportFixedSizeBuffer(1, type.bit_width() / 8));
     } else {
-      DCHECK_EQ(fw_type.bit_width(), 1);
+      DCHECK_EQ(type.bit_width(), 1);
       RETURN_NOT_OK(ImportBitsBuffer(1));
     }
     return Status::OK();
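With the storage type unwrapped, arrays of extension types move through the same buffer-import path as their storage. A round-trip sketch using the public bridge functions (helper name illustrative; assumes the extension type is registered):

// Export an array and its type, then import both back. The ArrowArray only
// carries storage buffers; the extension identity travels in the ArrowSchema
// metadata produced alongside it.
arrow::Status RoundTripArray(const std::shared_ptr<arrow::Array>& array) {
  struct ArrowArray c_array;
  struct ArrowSchema c_schema;
  ARROW_RETURN_NOT_OK(arrow::ExportArray(*array, &c_array, &c_schema));
  ARROW_ASSIGN_OR_RAISE(auto imported, arrow::ImportArray(&c_array, &c_schema));
  return imported->Equals(*array)
             ? arrow::Status::OK()
             : arrow::Status::Invalid("extension array did not round-trip");
}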