Skip to content

Commit

Permalink
Merge 5219891 into a196e19
Browse files Browse the repository at this point in the history
  • Loading branch information
zverevgeny authored Jul 28, 2024
2 parents a196e19 + 5219891 commit 8432391
Show file tree
Hide file tree
Showing 94 changed files with 1,539 additions and 935 deletions.
15 changes: 14 additions & 1 deletion ydb/core/formats/arrow/arrow_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ std::shared_ptr<arrow::Scalar> DefaultScalar(const std::shared_ptr<arrow::DataTy
}
return true;
});
Y_ABORT_UNLESS(out);
AFL_VERIFY(out)("type", type->ToString());
return out;
}

Expand Down Expand Up @@ -634,6 +634,19 @@ int ScalarCompare(const std::shared_ptr<arrow::Scalar>& x, const std::shared_ptr
return ScalarCompare(*x, *y);
}

// Three-way comparison of two scalar pointers where either side may be null.
// A null pointer orders before any non-null one, and two null pointers compare equal.
// When both sides are set, the result is whatever ScalarCompare() yields for the pointees.
// Returns -1 / 0 / 1 for the null cases, otherwise the ScalarCompare() result.
int ScalarCompareNullable(const std::shared_ptr<arrow::Scalar>& x, const std::shared_ptr<arrow::Scalar>& y) {
    const bool hasX = !!x;
    const bool hasY = !!y;
    if (hasX != hasY) {
        // Exactly one side is missing: the missing one is considered smaller.
        return hasX ? 1 : -1;
    }
    if (!hasX) {
        // Both are missing.
        return 0;
    }
    return ScalarCompare(*x, *y);
}

std::shared_ptr<arrow::RecordBatch> SortBatch(const std::shared_ptr<arrow::RecordBatch>& batch,
const std::shared_ptr<arrow::Schema>& sortingKey, const bool andUnique) {
auto sortPermutation = MakeSortPermutation(batch, sortingKey, andUnique);
Expand Down
1 change: 1 addition & 0 deletions ydb/core/formats/arrow/arrow_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ std::shared_ptr<arrow::Scalar> GetScalar(const std::shared_ptr<arrow::Array>& ar
bool IsGoodScalar(const std::shared_ptr<arrow::Scalar>& x);
int ScalarCompare(const arrow::Scalar& x, const arrow::Scalar& y);
int ScalarCompare(const std::shared_ptr<arrow::Scalar>& x, const std::shared_ptr<arrow::Scalar>& y);
int ScalarCompareNullable(const std::shared_ptr<arrow::Scalar>& x, const std::shared_ptr<arrow::Scalar>& y);
std::partial_ordering ColumnsCompare(const std::vector<std::shared_ptr<arrow::Array>>& x, const ui32 xRow, const std::vector<std::shared_ptr<arrow::Array>>& y, const ui32 yRow);
bool ScalarLess(const std::shared_ptr<arrow::Scalar>& x, const std::shared_ptr<arrow::Scalar>& y);
bool ScalarLess(const arrow::Scalar& x, const arrow::Scalar& y);
Expand Down
13 changes: 13 additions & 0 deletions ydb/core/formats/arrow/common/accessor.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "accessor.h"
#include <ydb/core/formats/arrow/size_calcer.h>
#include <ydb/core/formats/arrow/switch/compare.h>
#include <ydb/core/formats/arrow/switch/switch_type.h>
#include <ydb/library/actors/core/log.h>
Expand Down Expand Up @@ -94,6 +95,10 @@ class TChunkAccessor {

}

// Raw-size hook for a single-array accessor: reports whatever
// NArrow::GetArrayDataSize computes for the wrapped Array.
std::optional<ui64> TTrivialArray::DoGetRawSize() const {
return NArrow::GetArrayDataSize(Array);
}

std::partial_ordering IChunkedArray::TCurrentChunkAddress::Compare(const ui64 position, const TCurrentChunkAddress& item, const ui64 itemPosition) const {
AFL_VERIFY(StartPosition <= position);
AFL_VERIFY(position < FinishPosition);
Expand All @@ -119,4 +124,12 @@ IChunkedArray::TCurrentChunkAddress TTrivialChunkedArray::DoGetChunk(const std::
return SelectChunk(chunkCurrent, position, accessor);
}

std::optional<ui64> TTrivialChunkedArray::DoGetRawSize() const {
ui64 result = 0;
for (auto&& i : Array->chunks()) {
result += NArrow::GetArrayDataSize(i);
}
return result;
}

}
21 changes: 16 additions & 5 deletions ydb/core/formats/arrow/common/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,23 @@ class IChunkedArray {
YDB_READONLY_DEF(std::shared_ptr<arrow::DataType>, DataType);
YDB_READONLY(ui64, RecordsCount, 0);
YDB_READONLY(EType, Type, EType::Undefined);
virtual std::optional<ui64> DoGetRawSize() const = 0;
protected:
virtual std::shared_ptr<arrow::ChunkedArray> DoGetChunkedArray() const = 0;
virtual TCurrentChunkAddress DoGetChunk(const std::optional<TCurrentChunkAddress>& chunkCurrent, const ui64 position) const = 0;

template <class TChunkAccessor>
TCurrentChunkAddress SelectChunk(const std::optional<TCurrentChunkAddress>& chunkCurrent, const ui64 position, const TChunkAccessor& accessor) const {
if (!chunkCurrent || position >= chunkCurrent->GetStartPosition() + chunkCurrent->GetLength()) {
if (!chunkCurrent || position >= chunkCurrent->GetStartPosition()) {
ui32 startIndex = 0;
ui64 idx = 0;
if (chunkCurrent) {
AFL_VERIFY(chunkCurrent->GetChunkIndex() + 1 < accessor.GetChunksCount());
startIndex = chunkCurrent->GetChunkIndex() + 1;
idx = chunkCurrent->GetStartPosition() + chunkCurrent->GetLength();
if (position < chunkCurrent->GetFinishPosition()) {
return *chunkCurrent;
}
AFL_VERIFY(chunkCurrent->GetChunkIndex() < accessor.GetChunksCount());
startIndex = chunkCurrent->GetChunkIndex();
idx = chunkCurrent->GetStartPosition();
}
for (ui32 i = startIndex; i < accessor.GetChunksCount(); ++i) {
const ui64 nextIdx = idx + accessor.GetChunkLength(i);
Expand All @@ -105,7 +109,7 @@ class IChunkedArray {
}
idx = nextIdx;
}
} else if (position < chunkCurrent->GetStartPosition()) {
} else {
AFL_VERIFY(chunkCurrent->GetChunkIndex() > 0);
ui64 idx = chunkCurrent->GetStartPosition();
for (i32 i = chunkCurrent->GetChunkIndex() - 1; i >= 0; --i) {
Expand Down Expand Up @@ -156,6 +160,10 @@ class IChunkedArray {
TString DebugString(const ui32 position) const;
};

// Forwards to the virtual DoGetRawSize() hook and returns its result unchanged
// (an empty optional is passed through as is).
std::optional<ui64> GetRawSize() const {
return DoGetRawSize();
}

// Forwards to the virtual DoGetChunkedArray() hook and returns its result unchanged.
std::shared_ptr<arrow::ChunkedArray> GetChunkedArray() const {
return DoGetChunkedArray();
}
Expand All @@ -180,6 +188,8 @@ class TTrivialArray: public IChunkedArray {
using TBase = IChunkedArray;
const std::shared_ptr<arrow::Array> Array;
protected:
virtual std::optional<ui64> DoGetRawSize() const override;

// Chunk lookup for the single-array accessor. Both arguments are ignored:
// the result is always built from the one wrapped Array with two zero arguments.
// NOTE(review): the zeros are presumably the start position and chunk index — confirm
// against the TCurrentChunkAddress constructor.
virtual TCurrentChunkAddress DoGetChunk(const std::optional<TCurrentChunkAddress>& /*chunkCurrent*/, const ui64 /*position*/) const override {
return TCurrentChunkAddress(Array, 0, 0);
}
Expand All @@ -204,6 +214,7 @@ class TTrivialChunkedArray: public IChunkedArray {
// Returns the wrapped chunked array as is; no copy or conversion is performed.
virtual std::shared_ptr<arrow::ChunkedArray> DoGetChunkedArray() const override {
return Array;
}
virtual std::optional<ui64> DoGetRawSize() const override;

public:
TTrivialChunkedArray(const std::shared_ptr<arrow::ChunkedArray>& data)
Expand Down
2 changes: 1 addition & 1 deletion ydb/core/formats/arrow/common/adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class TDataBuilderPolicy<TGeneralContainer> {
return batch;
}
[[nodiscard]] static std::shared_ptr<TGeneralContainer> ApplyArrowFilter(const std::shared_ptr<TGeneralContainer>& batch, const std::shared_ptr<arrow::BooleanArray>& filter) {
auto table = batch->BuildTable();
auto table = batch->BuildTableVerified();
return std::make_shared<TGeneralContainer>(TDataBuilderPolicy<arrow::Table>::ApplyArrowFilter(table, filter));
}
[[nodiscard]] static std::shared_ptr<TGeneralContainer> GetEmptySame(const std::shared_ptr<TGeneralContainer>& batch) {
Expand Down
140 changes: 119 additions & 21 deletions ydb/core/formats/arrow/common/container.cpp
Original file line number Diff line number Diff line change
@@ -1,50 +1,60 @@
#include "container.h"
#include <ydb/library/actors/core/log.h>
#include <ydb/core/formats/arrow/arrow_helpers.h>
#include <ydb/core/formats/arrow/simple_arrays_cache.h>

namespace NKikimr::NArrow {

NKikimr::TConclusionStatus TGeneralContainer::MergeColumnsStrictly(const TGeneralContainer& container) {
if (RecordsCount != container.RecordsCount) {
TConclusionStatus TGeneralContainer::MergeColumnsStrictly(const TGeneralContainer& container) {
if (!container.RecordsCount) {
return TConclusionStatus::Success();
}
if (!RecordsCount) {
RecordsCount = container.RecordsCount;
}
if (*RecordsCount != *container.RecordsCount) {
return TConclusionStatus::Fail(TStringBuilder() << "inconsistency records count in additional container: " <<
container.GetSchema()->ToString() << ". expected: " << RecordsCount << ", reality: " << container.GetRecordsCount());
}
for (i32 i = 0; i < container.Schema->num_fields(); ++i) {
auto addFieldResult = AddField(container.Schema->field(i), container.Columns[i]);
if (!addFieldResult) {
if (addFieldResult.IsFail()) {
return addFieldResult;
}
}
return TConclusionStatus::Success();
}

NKikimr::TConclusionStatus TGeneralContainer::AddField(const std::shared_ptr<arrow::Field>& f, const std::shared_ptr<NAccessor::IChunkedArray>& data) {
TConclusionStatus TGeneralContainer::AddField(const std::shared_ptr<arrow::Field>& f, const std::shared_ptr<NAccessor::IChunkedArray>& data) {
AFL_VERIFY(f);
AFL_VERIFY(data);
if (data->GetRecordsCount() != RecordsCount) {
if (RecordsCount && data->GetRecordsCount() != *RecordsCount) {
return TConclusionStatus::Fail(TStringBuilder() << "inconsistency records count in new column: " <<
f->name() << ". expected: " << RecordsCount << ", reality: " << data->GetRecordsCount());
}
if (!data->GetDataType()->Equals(f->type())) {
return TConclusionStatus::Fail("schema and data type are not equals: " + data->GetDataType()->ToString() + " vs " + f->type()->ToString());
}
if (Schema->GetFieldByName(f->name())) {
return TConclusionStatus::Fail("field name duplication: " + f->name());
}
auto resultAdd = Schema->AddField(Schema->num_fields(), f);
if (!resultAdd.ok()) {
return TConclusionStatus::Fail("internal schema error on add field: " + resultAdd.status().ToString());
{
auto conclusion = Schema->AddField(f);
if (conclusion.IsFail()) {
return conclusion;
}
}
Schema = *resultAdd;
RecordsCount = data->GetRecordsCount();
Columns.emplace_back(data);
return TConclusionStatus::Success();
}

TGeneralContainer::TGeneralContainer(const std::shared_ptr<arrow::Schema>& schema, std::vector<std::shared_ptr<NAccessor::IChunkedArray>>&& columns)
: Schema(schema)
, Columns(std::move(columns))
{
AFL_VERIFY(schema);
// Convenience overload: wraps an arrow::ChunkedArray into NAccessor::TTrivialChunkedArray
// and delegates to the IChunkedArray-based AddField, returning its status.
// NOTE(review): `data` is not checked for null before being wrapped — confirm the
// TTrivialChunkedArray constructor tolerates or verifies that.
TConclusionStatus TGeneralContainer::AddField(const std::shared_ptr<arrow::Field>& f, const std::shared_ptr<arrow::ChunkedArray>& data) {
return AddField(f, std::make_shared<NAccessor::TTrivialChunkedArray>(data));
}

// Convenience overload: wraps a plain arrow::Array into NAccessor::TTrivialArray
// and delegates to the IChunkedArray-based AddField, returning its status.
// NOTE(review): `data` is not checked for null before being wrapped — confirm the
// TTrivialArray constructor tolerates or verifies that.
TConclusionStatus TGeneralContainer::AddField(const std::shared_ptr<arrow::Field>& f, const std::shared_ptr<arrow::Array>& data) {
return AddField(f, std::make_shared<NAccessor::TTrivialArray>(data));
}

void TGeneralContainer::Initialize() {
std::optional<ui64> recordsCount;
AFL_VERIFY(Schema->num_fields() == (i32)Columns.size())("schema", Schema->num_fields())("columns", Columns.size());
for (i32 i = 0; i < Schema->num_fields(); ++i) {
Expand All @@ -58,12 +68,34 @@ TGeneralContainer::TGeneralContainer(const std::shared_ptr<arrow::Schema>& schem
}
}
AFL_VERIFY(recordsCount);
AFL_VERIFY(!RecordsCount || *RecordsCount == *recordsCount);
RecordsCount = *recordsCount;
}

// Builds a container from a list of fields and the matching column accessors.
// A fresh NModifier::TSchema is created from `fields`, `columns` is moved into the
// container, and Initialize() then runs its consistency checks.
TGeneralContainer::TGeneralContainer(const std::vector<std::shared_ptr<arrow::Field>>& fields, std::vector<std::shared_ptr<NAccessor::IChunkedArray>>&& columns)
: Schema(std::make_shared<NModifier::TSchema>(fields))
, Columns(std::move(columns))
{
Initialize();
}

// Builds a container from an existing modifier schema and the matching column accessors.
// A new NModifier::TSchema object is constructed from `schema` (the passed pointer itself
// is not stored), `columns` is moved into the container, and Initialize() then runs its
// consistency checks.
// NOTE(review): `schema` is not checked for null here — presumably the TSchema
// constructor verifies it; confirm.
TGeneralContainer::TGeneralContainer(const std::shared_ptr<NModifier::TSchema>& schema, std::vector<std::shared_ptr<NAccessor::IChunkedArray>>&& columns)
: Schema(std::make_shared<NModifier::TSchema>(schema))
, Columns(std::move(columns))
{
Initialize();
}

// Builds a container from an arrow schema and the matching column accessors.
// The arrow schema is wrapped into a new NModifier::TSchema, `columns` is moved into
// the container, and Initialize() then runs its consistency checks.
// NOTE(review): `schema` is not checked for null here — presumably the TSchema
// constructor verifies it; confirm.
TGeneralContainer::TGeneralContainer(const std::shared_ptr<arrow::Schema>& schema, std::vector<std::shared_ptr<NAccessor::IChunkedArray>>&& columns)
: Schema(std::make_shared<NModifier::TSchema>(schema))
, Columns(std::move(columns))
{
Initialize();
}

TGeneralContainer::TGeneralContainer(const std::shared_ptr<arrow::Table>& table) {
AFL_VERIFY(table);
Schema = table->schema();
Schema = std::make_shared<NModifier::TSchema>(table->schema());
RecordsCount = table->num_rows();
for (auto&& i : table->columns()) {
if (i->num_chunks() == 1) {
Expand All @@ -72,15 +104,17 @@ TGeneralContainer::TGeneralContainer(const std::shared_ptr<arrow::Table>& table)
Columns.emplace_back(std::make_shared<NAccessor::TTrivialChunkedArray>(i));
}
}
Initialize();
}

TGeneralContainer::TGeneralContainer(const std::shared_ptr<arrow::RecordBatch>& table) {
AFL_VERIFY(table);
Schema = table->schema();
Schema = std::make_shared<NModifier::TSchema>(table->schema());
RecordsCount = table->num_rows();
for (auto&& i : table->columns()) {
Columns.emplace_back(std::make_shared<NAccessor::TTrivialArray>(i));
}
Initialize();
}

std::shared_ptr<NKikimr::NArrow::NAccessor::IChunkedArray> TGeneralContainer::GetAccessorByNameVerified(const std::string& fieldId) const {
Expand Down Expand Up @@ -110,14 +144,78 @@ std::shared_ptr<arrow::Table> TGeneralContainer::BuildTableOptional(const std::o
if (fields.empty()) {
return nullptr;
}
return arrow::Table::Make(std::make_shared<arrow::Schema>(fields), columns, RecordsCount);
AFL_VERIFY(RecordsCount);
return arrow::Table::Make(std::make_shared<arrow::Schema>(fields), columns, *RecordsCount);
}

std::shared_ptr<arrow::Table> TGeneralContainer::BuildTable(const std::optional<std::set<std::string>>& columnNames /*= {}*/) const {
// Strict variant of BuildTableOptional(): builds the table for the requested columns
// and verifies that a table was actually produced. When an explicit set of column names
// is given, it additionally verifies that the resulting schema has exactly as many fields
// as names were requested. A failed verification goes through AFL_VERIFY.
std::shared_ptr<arrow::Table> TGeneralContainer::BuildTableVerified(const std::optional<std::set<std::string>>& columnNames /*= {}*/) const {
auto result = BuildTableOptional(columnNames);
AFL_VERIFY(result);
// Only checked when the caller restricted the column set.
AFL_VERIFY(!columnNames || result->schema()->num_fields() == (i32)columnNames->size());
return result;
}

// Looks up a column accessor by field name.
// Returns nullptr when the schema reports index -1 for the name; otherwise returns the
// column stored at the reported index, after verifying the index is within Columns.
std::shared_ptr<NArrow::NAccessor::IChunkedArray> TGeneralContainer::GetAccessorByNameOptional(const std::string& fieldId) const {
    const int idx = Schema->GetFieldIndex(fieldId);
    if (idx != -1) {
        // Schema and Columns are expected to stay in sync; a mismatch is a hard error.
        AFL_VERIFY((ui32)idx < Columns.size())("idx", idx)("count", Columns.size());
        return Columns[idx];
    }
    return nullptr;
}

// Rebuilds this container so that its fields follow the order and set of `schema`.
//
// For every field of `schema`:
//   - if the container already has a field with that name, its type must equal the
//     requested type; the owned field and its column are reused;
//   - otherwise a column is produced from `defaultFieldsConstructor` (asked with
//     `forceDefaults`) via the thread-local simple arrays cache, sized to RecordsCount.
//
// Fails (leaving the container untouched) when RecordsCount is not set, when a field is
// missing and no constructor was supplied, when the constructor cannot provide a default
// value, or when the types of same-named fields differ. Schema and Columns are replaced
// only after every field has been processed successfully.
TConclusionStatus TGeneralContainer::SyncSchemaTo(const std::shared_ptr<arrow::Schema>& schema, const IFieldsConstructor* defaultFieldsConstructor, const bool forceDefaults) {
    auto syncedSchema = std::make_shared<NModifier::TSchema>();
    std::vector<std::shared_ptr<NAccessor::IChunkedArray>> syncedColumns;
    if (!RecordsCount) {
        return TConclusionStatus::Fail("original container has not data");
    }
    for (const auto& target : schema->fields()) {
        const int ownIdx = Schema->GetFieldIndex(target->name());
        if (ownIdx != -1) {
            // Field is already present: reuse it, but only if the types agree.
            const auto& ownField = Schema->GetFieldVerified(ownIdx);
            if (!ownField->type()->Equals(target->type())) {
                return TConclusionStatus::Fail("different field types for '" + target->name() + "'. Have " + ownField->type()->ToString() + ", need " + target->type()->ToString());
            }
            syncedSchema->AddField(ownField).Validate();
            syncedColumns.emplace_back(Columns[ownIdx]);
            continue;
        }
        // Field is absent: it can only be filled in from a default value.
        if (!defaultFieldsConstructor) {
            return TConclusionStatus::Fail("haven't field for sync: '" + target->name() + "'");
        }
        syncedSchema->AddField(target).Validate();
        auto defaultValue = defaultFieldsConstructor->GetDefaultColumnElementValue(target, forceDefaults);
        if (defaultValue.IsFail()) {
            return defaultValue;
        }
        syncedColumns.emplace_back(std::make_shared<NAccessor::TTrivialArray>(NArrow::TThreadSimpleArraysCache::Get(target->type(), *defaultValue, *RecordsCount)));
    }
    // Commit the new layout only once everything has succeeded.
    std::swap(Schema, syncedSchema);
    std::swap(syncedColumns, Columns);
    return TConclusionStatus::Success();
}

// Produces a short diagnostic description of the container:
// "records_count=<N>;" (only when RecordsCount is set) followed by "schema=<schema text>;".
TString TGeneralContainer::DebugString() const {
    TStringBuilder description;
    if (RecordsCount) {
        description << "records_count=";
        description << *RecordsCount;
        description << ";";
    }
    description << "schema=";
    description << Schema->ToString();
    description << ";";
    return description;
}

// Resolves the default element value for `field` (which must be non-null).
// The value supplied by DoGetDefaultColumnElementValue() for the field name wins when it
// is set. Otherwise, with `force` the type-level NArrow::DefaultScalar() is used, and
// without `force` a failure status naming the column is returned.
TConclusion<std::shared_ptr<arrow::Scalar>> IFieldsConstructor::GetDefaultColumnElementValue(const std::shared_ptr<arrow::Field>& field, const bool force) const {
    AFL_VERIFY(field);
    if (auto explicitDefault = DoGetDefaultColumnElementValue(field->name())) {
        return explicitDefault;
    }
    if (!force) {
        return TConclusionStatus::Fail("have not default value for column " + field->name());
    }
    return NArrow::DefaultScalar(field->type());
}

}
Loading

0 comments on commit 8432391

Please sign in to comment.