Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Arrow support to cover nullable data. #4049

Merged
merged 13 commits into from
Sep 11, 2024
5 changes: 5 additions & 0 deletions test/src/unit-arrow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ struct CPPArrayFx {
str_attr.set_cell_val_num(TILEDB_VAR_NUM);
attrs.push_back(str_attr);
}
{
auto str_attr = Attribute(ctx, "utf_string3", TILEDB_STRING_UTF8);
str_attr.set_cell_val_num(TILEDB_VAR_NUM);
attrs.push_back(str_attr);
}
{
auto str_attr = Attribute(ctx, "tiledb_char", TILEDB_CHAR);
str_attr.set_cell_val_num(TILEDB_VAR_NUM);
Expand Down
4 changes: 4 additions & 0 deletions test/src/unit_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def create(self):
utf_strings[np.random.randint(0, col_size, size=col_size//2)] = ''
self.data['utf_string2'] = pa.array(utf_strings)

# another version with some cells set to NULL
utf_strings[np.random.randint(0, col_size, size=col_size//2)] = None
self.data['utf_string3'] = pa.array(utf_strings)

self.data['datetime_ns'] = pa.array(rand_datetime64_array(col_size))

##########################################################################
Expand Down
113 changes: 93 additions & 20 deletions tiledb/sm/cpp_api/arrow_io_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
* source: https://arrow.apache.org/docs/format/CDataInterface.html
*/

#ifndef ARROW_C_DATA_INTERFACE
#define ARROW_C_DATA_INTERFACE

#define ARROW_FLAG_DICTIONARY_ORDERED 1
#define ARROW_FLAG_NULLABLE 2
#define ARROW_FLAG_MAP_KEYS_SORTED 4
Expand Down Expand Up @@ -76,6 +79,9 @@ struct ArrowArray {
// Opaque producer-specific data
void* private_data;
};

#endif

/* End Arrow C API */
/* ************************************************************************ */

Expand Down Expand Up @@ -123,17 +129,21 @@ struct TypeInfo {

// is this represented as "Arrow large"
bool arrow_large;

bool nullable;
};

struct BufferInfo {
TypeInfo tdbtype;
bool is_var; // is var-length
bool is_nullable; // is nullable
uint64_t data_num; // number of data elements
void* data; // data pointer
uint64_t data_elem_size; // bytes per data element
uint64_t offsets_num; // number of offsets
void* offsets; // offsets pointer
size_t offsets_elem_size; // bytes per offset element
uint8_t* validity; // optional validity buffer (if is_nullable)
};

/* ****************************** */
Expand Down Expand Up @@ -262,6 +272,7 @@ ArrowInfo tiledb_buffer_arrow_fmt(BufferInfo bufferinfo, bool use_list = true) {

TypeInfo arrow_type_to_tiledb(ArrowSchema* arw_schema) {
auto fmt = std::string(arw_schema->format);
bool nullable = arw_schema->flags & ARROW_FLAG_NULLABLE;
bool large = false;
if (fmt == "+l") {
large = false;
Expand All @@ -274,36 +285,36 @@ TypeInfo arrow_type_to_tiledb(ArrowSchema* arw_schema) {
}

if (fmt == "i")
return {TILEDB_INT32, 4, 1, large};
return {TILEDB_INT32, 4, 1, large, nullable};
else if (fmt == "l")
return {TILEDB_INT64, 8, 1, large};
return {TILEDB_INT64, 8, 1, large, nullable};
else if (fmt == "f")
return {TILEDB_FLOAT32, 4, 1, large};
return {TILEDB_FLOAT32, 4, 1, large, nullable};
else if (fmt == "g")
return {TILEDB_FLOAT64, 8, 1, large};
return {TILEDB_FLOAT64, 8, 1, large, nullable};
else if (fmt == "B")
return {TILEDB_BLOB, 1, 1, large};
return {TILEDB_BLOB, 1, 1, large, nullable};
else if (fmt == "c")
return {TILEDB_INT8, 1, 1, large};
return {TILEDB_INT8, 1, 1, large, nullable};
else if (fmt == "C")
return {TILEDB_UINT8, 1, 1, large};
return {TILEDB_UINT8, 1, 1, large, nullable};
else if (fmt == "s")
return {TILEDB_INT16, 2, 1, large};
return {TILEDB_INT16, 2, 1, large, nullable};
else if (fmt == "S")
return {TILEDB_UINT16, 2, 1, large};
return {TILEDB_UINT16, 2, 1, large, nullable};
else if (fmt == "I")
return {TILEDB_UINT32, 4, 1, large};
return {TILEDB_UINT32, 4, 1, large, nullable};
else if (fmt == "L")
return {TILEDB_UINT64, 8, 1, large};
return {TILEDB_UINT64, 8, 1, large, nullable};
// this is kind of a hack
// technically 'tsn:' is timezone-specific, which we don't support
// however, the blank (no suffix) base is interconvertible w/ np.datetime64
else if (fmt == "tsn:")
return {TILEDB_DATETIME_NS, 8, 1, large};
return {TILEDB_DATETIME_NS, 8, 1, large, nullable};
else if (fmt == "z" || fmt == "Z")
return {TILEDB_CHAR, 1, TILEDB_VAR_NUM, fmt == "Z"};
return {TILEDB_CHAR, 1, TILEDB_VAR_NUM, fmt == "Z", nullable};
else if (fmt == "u" || fmt == "U")
return {TILEDB_STRING_UTF8, 1, TILEDB_VAR_NUM, fmt == "U"};
return {TILEDB_STRING_UTF8, 1, TILEDB_VAR_NUM, fmt == "U", nullable};
else
throw tiledb::TileDBError(
"[TileDB-Arrow]: Unknown or unsupported Arrow format string '" + fmt +
Expand All @@ -314,9 +325,11 @@ TypeInfo tiledb_dt_info(const ArraySchema& schema, const std::string& name) {
if (schema.has_attribute(name)) {
auto attr = schema.attribute(name);
auto retval = TypeInfo();
retval.type = attr.type(),
retval.elem_size = tiledb::impl::type_size(attr.type()),
retval.cell_val_num = attr.cell_val_num(), retval.arrow_large = false;
retval.type = attr.type();
retval.elem_size = tiledb::impl::type_size(attr.type());
retval.cell_val_num = attr.cell_val_num();
retval.arrow_large = false;
retval.nullable = attr.nullable();
return retval;
} else if (schema.domain().has_dimension(name)) {
auto dom = schema.domain();
Expand All @@ -327,6 +340,7 @@ TypeInfo tiledb_dt_info(const ArraySchema& schema, const std::string& name) {
retval.elem_size = tiledb::impl::type_size(dim.type());
retval.cell_val_num = dim.cell_val_num();
retval.arrow_large = false;
retval.nullable = false;
return retval;
} else {
throw TDB_LERROR("Schema does not have attribute named '" + name + "'");
Expand Down Expand Up @@ -604,6 +618,18 @@ ArrowImporter::~ArrowImporter() {
}
}

static inline int8_t bitmap_get(const uint8_t* bits, int64_t i) {
return (bits[i >> 3] >> (i & 0x07)) & 1;
}

static void bitmap_to_bytemap(void* bitmap, int64_t n) {
uint8_t* bmp = static_cast<uint8_t*>(bitmap);
std::vector<uint8_t> valcpy(bmp, bmp + n); // we make as we will overwrite.
for (auto i = 0; i < n; i++) {
bmp[i] = bitmap_get(valcpy.data(), i);
}
}

void ArrowImporter::import_(
std::string name, ArrowArray* arw_array, ArrowSchema* arw_schema) {
auto typeinfo = arrow_type_to_tiledb(arw_schema);
Expand All @@ -630,6 +656,7 @@ void ArrowImporter::import_(
query_->set_data_buffer(name, p_data, data_nbytes);
query_->set_offsets_buffer(
name, static_cast<uint64_t*>(p_offsets), num_offsets + 1);

} else {
// fixed-size attribute (not TILEDB_VAR_NUM)
assert(arw_array->n_buffers == 2);
Expand All @@ -639,6 +666,15 @@ void ArrowImporter::import_(

query_->set_data_buffer(name, static_cast<void*>(p_data), data_num);
}

if (typeinfo.nullable && arw_array->buffers[0] != nullptr) {
bitmap_to_bytemap(
const_cast<void*>(arw_array->buffers[0]), arw_array->length);
query_->set_validity_buffer(
name,
static_cast<uint8_t*>(const_cast<void*>(arw_array->buffers[0])),
arw_array->length);
kounelisagis marked this conversation as resolved.
Show resolved Hide resolved
}
}

/* ****************************** */
Expand Down Expand Up @@ -670,6 +706,7 @@ BufferInfo ArrowExporter::buffer_info(const std::string& name) {
uint64_t* offsets = nullptr;
uint64_t offsets_nelem = 0;
uint64_t elem_size = 0;
uint8_t* validity = nullptr;

auto typeinfo = tiledb_dt_info(query_->array().schema(), name);

Expand Down Expand Up @@ -714,6 +751,10 @@ BufferInfo ArrowExporter::buffer_info(const std::string& name) {
query_->get_data_buffer(name, &data, &data_nelem, &elem_size);
}

if (typeinfo.nullable) {
query_->get_validity_buffer(name, &validity, &data_nelem);
}

auto retval = BufferInfo();
retval.tdbtype = typeinfo;
retval.is_var = is_var;
Expand All @@ -723,6 +764,8 @@ BufferInfo ArrowExporter::buffer_info(const std::string& name) {
retval.offsets_num = (is_var ? offsets_nelem : 1);
retval.offsets = offsets;
retval.offsets_elem_size = offsets_elem_nbytes;
retval.is_nullable = typeinfo.nullable;
retval.validity = validity;

return retval;
}
Expand All @@ -733,8 +776,32 @@ int64_t flags_for_buffer(BufferInfo binfo) {
#define ARROW_FLAG_NULLABLE 2
#define ARROW_FLAG_MAP_KEYS_SORTED 4
*/
(void)binfo;
return 0;
int64_t val = 0;
if (binfo.is_nullable)
val |= ARROW_FLAG_NULLABLE;
return val;
}

int64_t bytemap_to_bitmap(uint8_t* bytemap, int64_t num) {
// helper function from column_buffer class in libtiledbsoma
// note that it transforms bytemap _in place_ by design, as we now own the
// buffer added null count return for convenience
int64_t nulls = 0;
int i_dst = 0;
for (unsigned int i_src = 0; i_src < num; i_src++) {
nulls += bytemap[i_src] == 0;
// Overwrite every 8 bytes with a one-byte bitmap
if (i_src % 8 == 0) {
// Each bit in the bitmap corresponds to one byte in the bytemap
// Note: the bitmap must be byte-aligned (8 bits)
int bitmap = 0;
for (unsigned int i = i_src; i < i_src + 8 && i < num; i++) {
bitmap |= bytemap[i] << (i % 8);
}
bytemap[i_dst++] = bitmap;
}
}
return nulls;
}

void ArrowExporter::export_(
Expand Down Expand Up @@ -763,6 +830,12 @@ void ArrowExporter::export_(
}
cpp_schema->export_ptr(schema);

int64_t null_num = 0;
if (bufferinfo.is_nullable) {
null_num = bytemap_to_bitmap(bufferinfo.validity, bufferinfo.data_num);
buffers[0] = bufferinfo.validity;
}

size_t elem_num = 0;
if (bufferinfo.is_var) {
// adjust for offset unless empty result
Expand All @@ -773,7 +846,7 @@ void ArrowExporter::export_(

auto cpp_arrow_array = new CPPArrowArray(
elem_num, // elem_num
0, // null_num
null_num, // null_num
0, // offset
{}, // children
buffers);
Expand Down
Loading