
ARROW-13855: [C++][Python] Implement C data interface support for extension types

Closes #11071 from pitrou/ARROW-13855-export-extension

Authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: David Li <li.davidm96@gmail.com>
pitrou authored and lidavidm committed Sep 3, 2021
1 parent a45fc3f commit 5ead375
Showing 15 changed files with 626 additions and 168 deletions.
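
For context before the per-file diff: the sketch below is not part of this commit. It is a minimal, hypothetical example (the UuidType class and the "example.uuid" name are invented for illustration) of how an extension type could round-trip through the C data interface entry points in arrow/c/bridge.h once this support is in place; ExportType(), ImportType(), and RegisterExtensionType() are the existing Arrow C++ APIs the diff builds on.

// Minimal sketch, not part of this commit. "UuidType" and "example.uuid" are
// hypothetical names; ExportType(), ImportType() and RegisterExtensionType()
// are the existing Arrow C++ entry points this change builds on.
#include <memory>
#include <string>

#include "arrow/api.h"
#include "arrow/c/abi.h"
#include "arrow/c/bridge.h"

// A toy extension type with fixed_size_binary(16) storage.
class UuidType : public arrow::ExtensionType {
 public:
  UuidType() : arrow::ExtensionType(arrow::fixed_size_binary(16)) {}

  std::string extension_name() const override { return "example.uuid"; }

  bool ExtensionEquals(const arrow::ExtensionType& other) const override {
    return other.extension_name() == extension_name();
  }

  std::shared_ptr<arrow::Array> MakeArray(
      std::shared_ptr<arrow::ArrayData> data) const override {
    return std::make_shared<arrow::ExtensionArray>(data);
  }

  arrow::Result<std::shared_ptr<arrow::DataType>> Deserialize(
      std::shared_ptr<arrow::DataType> storage_type,
      const std::string& serialized) const override {
    if (!storage_type->Equals(*arrow::fixed_size_binary(16)) || !serialized.empty()) {
      return arrow::Status::Invalid("Unexpected storage type or serialized metadata");
    }
    return std::make_shared<UuidType>();
  }

  std::string Serialize() const override { return ""; }
};

arrow::Status RoundTripThroughCInterface() {
  auto ext_type = std::make_shared<UuidType>();
  // The importer only reconstructs the extension type if it is registered;
  // otherwise it falls back to the storage type.
  ARROW_RETURN_NOT_OK(arrow::RegisterExtensionType(ext_type));

  struct ArrowSchema c_schema;
  ARROW_RETURN_NOT_OK(arrow::ExportType(*ext_type, &c_schema));
  // ImportType releases c_schema and yields the registered extension type again.
  ARROW_ASSIGN_OR_RAISE(auto imported, arrow::ImportType(&c_schema));
  if (imported->id() != arrow::Type::EXTENSION) {
    return arrow::Status::Invalid("Extension type did not round-trip");
  }
  return arrow::Status::OK();
}

This mirrors the mechanism visible in the diff: the exporter unwraps the extension type to its storage type and appends the extension name and serialized parameters under the ARROW:extension:name and ARROW:extension:metadata metadata keys (kExtensionTypeKeyName and kExtensionMetadataKeyName); the importer decodes those keys, looks the name up in the extension type registry, and calls Deserialize() on the registered type, falling back to the plain storage type (keeping the extension keys in the imported field's metadata) when the name is not registered.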
128 changes: 96 additions & 32 deletions cpp/src/arrow/c/bridge.cc
@@ -28,6 +28,7 @@
#include "arrow/buffer.h"
#include "arrow/c/helpers.h"
#include "arrow/c/util_internal.h"
#include "arrow/extension_type.h"
#include "arrow/memory_pool.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
@@ -56,8 +57,6 @@ using internal::ArrayExportTraits;
using internal::SchemaExportGuard;
using internal::SchemaExportTraits;

// TODO export / import Extension types and arrays

namespace {

Status ExportingNotImplemented(const DataType& type) {
@@ -171,23 +170,26 @@ struct SchemaExporter {
export_.name_ = field.name();
flags_ = field.nullable() ? ARROW_FLAG_NULLABLE : 0;

const DataType& type = *field.type();
RETURN_NOT_OK(ExportFormat(type));
RETURN_NOT_OK(ExportChildren(type.fields()));
const DataType* type = UnwrapExtension(field.type().get());
RETURN_NOT_OK(ExportFormat(*type));
RETURN_NOT_OK(ExportChildren(type->fields()));
RETURN_NOT_OK(ExportMetadata(field.metadata().get()));
return Status::OK();
}

Status ExportType(const DataType& type) {
Status ExportType(const DataType& orig_type) {
flags_ = ARROW_FLAG_NULLABLE;

RETURN_NOT_OK(ExportFormat(type));
RETURN_NOT_OK(ExportChildren(type.fields()));
const DataType* type = UnwrapExtension(&orig_type);
RETURN_NOT_OK(ExportFormat(*type));
RETURN_NOT_OK(ExportChildren(type->fields()));
// There may be additional metadata to export
RETURN_NOT_OK(ExportMetadata(nullptr));
return Status::OK();
}

Status ExportSchema(const Schema& schema) {
static StructType dummy_struct_type({});
static const StructType dummy_struct_type({});
flags_ = 0;

RETURN_NOT_OK(ExportFormat(dummy_struct_type));
@@ -232,6 +234,17 @@ struct SchemaExporter {
c_struct->release = ReleaseExportedSchema;
}

const DataType* UnwrapExtension(const DataType* type) {
if (type->id() == Type::EXTENSION) {
const auto& ext_type = checked_cast<const ExtensionType&>(*type);
additional_metadata_.reserve(2);
additional_metadata_.emplace_back(kExtensionTypeKeyName, ext_type.extension_name());
additional_metadata_.emplace_back(kExtensionMetadataKeyName, ext_type.Serialize());
return ext_type.storage_type().get();
}
return type;
}

Status ExportFormat(const DataType& type) {
if (type.id() == Type::DICTIONARY) {
const auto& dict_type = checked_cast<const DictionaryType&>(type);
@@ -259,10 +272,29 @@ struct SchemaExporter {
return Status::OK();
}

Status ExportMetadata(const KeyValueMetadata* metadata) {
if (metadata != nullptr && metadata->size() >= 0) {
ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(*metadata));
Status ExportMetadata(const KeyValueMetadata* orig_metadata) {
static const KeyValueMetadata empty_metadata;

if (orig_metadata == nullptr) {
orig_metadata = &empty_metadata;
}
if (additional_metadata_.empty()) {
if (orig_metadata->size() > 0) {
ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(*orig_metadata));
}
return Status::OK();
}
// Additional metadata needs to be appended to the existing
// (for extension types)
KeyValueMetadata metadata(orig_metadata->keys(), orig_metadata->values());
for (const auto& kv : additional_metadata_) {
// The metadata may already be there => ignore
if (metadata.Contains(kv.first)) {
continue;
}
metadata.Append(kv.first, kv.second);
}
ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(metadata));
return Status::OK();
}

@@ -442,6 +474,7 @@ struct SchemaExporter {

ExportedSchemaPrivateData export_;
int64_t flags_ = 0;
std::vector<std::pair<std::string, std::string>> additional_metadata_;
std::unique_ptr<SchemaExporter> dict_exporter_;
std::vector<SchemaExporter> child_exporters_;
};
@@ -721,7 +754,13 @@ class FormatStringParser {
size_t index_;
};

Result<std::shared_ptr<KeyValueMetadata>> DecodeMetadata(const char* metadata) {
struct DecodedMetadata {
std::shared_ptr<KeyValueMetadata> metadata;
std::string extension_name;
std::string extension_serialized;
};

Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
auto read_int32 = [&](int32_t* out) -> Status {
int32_t v;
memcpy(&v, metadata, 4);
@@ -744,21 +783,29 @@ Result<std::shared_ptr<KeyValueMetadata>> DecodeMetadata(const char* metadata) {
return Status::OK();
};

DecodedMetadata decoded;

if (metadata == nullptr) {
return nullptr;
return decoded;
}
int32_t npairs;
RETURN_NOT_OK(read_int32(&npairs));
if (npairs == 0) {
return nullptr;
return decoded;
}
std::vector<std::string> keys(npairs);
std::vector<std::string> values(npairs);
for (int32_t i = 0; i < npairs; ++i) {
RETURN_NOT_OK(read_string(&keys[i]));
RETURN_NOT_OK(read_string(&values[i]));
if (keys[i] == kExtensionTypeKeyName) {
decoded.extension_name = values[i];
} else if (keys[i] == kExtensionMetadataKeyName) {
decoded.extension_serialized = values[i];
}
}
return key_value_metadata(std::move(keys), std::move(values));
decoded.metadata = key_value_metadata(std::move(keys), std::move(values));
return decoded;
}

struct SchemaImporter {
@@ -775,10 +822,9 @@ struct SchemaImporter {
}

Result<std::shared_ptr<Field>> MakeField() const {
ARROW_ASSIGN_OR_RAISE(auto metadata, DecodeMetadata(c_struct_->metadata));
const char* name = c_struct_->name ? c_struct_->name : "";
bool nullable = (c_struct_->flags & ARROW_FLAG_NULLABLE) != 0;
return field(name, type_, nullable, std::move(metadata));
return field(name, type_, nullable, std::move(metadata_.metadata));
}

Result<std::shared_ptr<Schema>> MakeSchema() const {
@@ -787,8 +833,7 @@ struct SchemaImporter {
"Cannot import schema: ArrowSchema describes non-struct type ",
type_->ToString());
}
ARROW_ASSIGN_OR_RAISE(auto metadata, DecodeMetadata(c_struct_->metadata));
return schema(type_->fields(), std::move(metadata));
return schema(type_->fields(), std::move(metadata_.metadata));
}

Result<std::shared_ptr<DataType>> MakeType() const { return type_; }
@@ -836,6 +881,20 @@ struct SchemaImporter {
bool ordered = (c_struct_->flags & ARROW_FLAG_DICTIONARY_ORDERED) != 0;
type_ = dictionary(type_, dict_importer.type_, ordered);
}

// Import metadata
ARROW_ASSIGN_OR_RAISE(metadata_, DecodeMetadata(c_struct_->metadata));

// Detect extension type
if (!metadata_.extension_name.empty()) {
const auto registered_ext_type = GetExtensionType(metadata_.extension_name);
if (registered_ext_type) {
ARROW_ASSIGN_OR_RAISE(
type_, registered_ext_type->Deserialize(std::move(type_),
metadata_.extension_serialized));
}
}

return Status::OK();
}

@@ -1130,6 +1189,7 @@ struct SchemaImporter {
int64_t recursion_level_;
std::vector<SchemaImporter> child_importers_;
std::shared_ptr<DataType> type_;
DecodedMetadata metadata_;
};

} // namespace
@@ -1255,8 +1315,15 @@ struct ArrayImporter {
}

Status DoImport() {
// Unwrap extension type
const DataType* storage_type = type_.get();
if (storage_type->id() == Type::EXTENSION) {
storage_type =
checked_cast<const ExtensionType&>(*storage_type).storage_type().get();
}

// First import children (required for reconstituting parent array data)
const auto& fields = type_->fields();
const auto& fields = storage_type->fields();
if (c_struct_->n_children != static_cast<int64_t>(fields.size())) {
return Status::Invalid("ArrowArray struct has ", c_struct_->n_children,
" children, expected ", fields.size(), " for type ",
@@ -1270,15 +1337,15 @@ struct ArrayImporter {
}

// Import main data
RETURN_NOT_OK(ImportMainData());
RETURN_NOT_OK(VisitTypeInline(*storage_type, this));

bool is_dict_type = (type_->id() == Type::DICTIONARY);
bool is_dict_type = (storage_type->id() == Type::DICTIONARY);
if (c_struct_->dictionary != nullptr) {
if (!is_dict_type) {
return Status::Invalid("Import type is ", type_->ToString(),
" but dictionary field in ArrowArray struct is not null");
}
const auto& dict_type = checked_cast<const DictionaryType&>(*type_);
const auto& dict_type = checked_cast<const DictionaryType&>(*storage_type);
// Import dictionary values
ArrayImporter dict_importer(dict_type.value_type());
RETURN_NOT_OK(dict_importer.ImportDict(this, c_struct_->dictionary));
@@ -1292,13 +1359,11 @@ struct ArrayImporter {
return Status::OK();
}

Status ImportMainData() { return VisitTypeInline(*type_, this); }

Status Visit(const DataType& type) {
return Status::NotImplemented("Cannot import array of type ", type_->ToString());
}

Status Visit(const FixedWidthType& type) { return ImportFixedSizePrimitive(); }
Status Visit(const FixedWidthType& type) { return ImportFixedSizePrimitive(type); }

Status Visit(const NullType& type) {
RETURN_NOT_OK(CheckNoChildren());
@@ -1352,16 +1417,15 @@ struct ArrayImporter {
return Status::OK();
}

Status ImportFixedSizePrimitive() {
const auto& fw_type = checked_cast<const FixedWidthType&>(*type_);
Status ImportFixedSizePrimitive(const FixedWidthType& type) {
RETURN_NOT_OK(CheckNoChildren());
RETURN_NOT_OK(CheckNumBuffers(2));
RETURN_NOT_OK(AllocateArrayData());
RETURN_NOT_OK(ImportNullBitmap());
if (BitUtil::IsMultipleOf8(fw_type.bit_width())) {
RETURN_NOT_OK(ImportFixedSizeBuffer(1, fw_type.bit_width() / 8));
if (BitUtil::IsMultipleOf8(type.bit_width())) {
RETURN_NOT_OK(ImportFixedSizeBuffer(1, type.bit_width() / 8));
} else {
DCHECK_EQ(fw_type.bit_width(), 1);
DCHECK_EQ(type.bit_width(), 1);
RETURN_NOT_OK(ImportBitsBuffer(1));
}
return Status::OK();