diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc index 716b2191c..7aee5540e 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc @@ -906,38 +906,44 @@ class HashAggregateKernel::Impl { auto length = in[0]->length(); std::vector indices; indices.resize(length, -1); - for (int i = 0; i < length; i++) { - auto aggr_key_validity = true; - arrow::util::string_view aggr_key; - if (aggr_key_unsafe_row) { + + arrow::util::string_view aggr_key; + if (aggr_key_unsafe_row) { + for (int i = 0; i < length; i++) { aggr_key_unsafe_row->reset(); - int idx = 0; + for (auto payload_arr : payloads) { payload_arr->Append(i, &aggr_key_unsafe_row); } aggr_key = arrow::util::string_view(aggr_key_unsafe_row->data, aggr_key_unsafe_row->cursor); - } else { - aggr_key = typed_key_in->GetView(i); - aggr_key_validity = - typed_key_in->null_count() == 0 ? true : !typed_key_in->IsNull(i); - } - - // 3. get key from hash_table - int memo_index = 0; - if (!aggr_key_validity) { - memo_index = aggr_hash_table_->GetOrInsertNull([](int) {}, [](int) {}); - } else { + // FIXME(): all keys are null? aggr_hash_table_->GetOrInsert( - aggr_key, [](int) {}, [](int) {}, &memo_index); + aggr_key, [](int) {}, [](int) {}, &(indices[i])); } + } else { + for (int i = 0; i < length; i++) { + if (typed_key_in->null_count() > 0) { + aggr_key = typed_key_in->GetView(i); + auto aggr_key_validity = + typed_key_in->null_count() == 0 ? true : !typed_key_in->IsNull(i); + + if (!aggr_key_validity) { + indices[i] = aggr_hash_table_->GetOrInsertNull([](int) {}, [](int) {}); + } else { + aggr_hash_table_->GetOrInsert( + aggr_key, [](int) {}, [](int) {}, &(indices[i])); + } + } else { + aggr_key = typed_key_in->GetView(i); - if (memo_index > max_group_id_) { - max_group_id_ = memo_index; + aggr_hash_table_->GetOrInsert( + aggr_key, [](int) {}, [](int) {}, &(indices[i])); + } } - indices[i] = memo_index; } + max_group_id_ = aggr_hash_table_->Size() - 1; total_out_length_ = max_group_id_ + 1; // 4. prepare action func and evaluate std::vector> eval_func_list; diff --git a/native-sql-engine/cpp/src/precompile/hash_map.cc b/native-sql-engine/cpp/src/precompile/hash_map.cc index 991c36b5c..d55156fc6 100644 --- a/native-sql-engine/cpp/src/precompile/hash_map.cc +++ b/native-sql-engine/cpp/src/precompile/hash_map.cc @@ -68,6 +68,7 @@ namespace precompile { void (*on_not_found)(int32_t)) { \ return impl_->GetOrInsertNull(on_found, on_not_found); \ } \ + int32_t HASHMAPNAME::Size() { return impl_->size(); } \ int32_t HASHMAPNAME::Get(const TYPE& value) { return impl_->Get(value); } \ int32_t HASHMAPNAME::GetNull() { return impl_->GetNull(); } diff --git a/native-sql-engine/cpp/src/precompile/hash_map.h b/native-sql-engine/cpp/src/precompile/hash_map.h index df78033e0..de1613557 100644 --- a/native-sql-engine/cpp/src/precompile/hash_map.h +++ b/native-sql-engine/cpp/src/precompile/hash_map.h @@ -28,6 +28,7 @@ namespace precompile { void (*on_not_found)(int32_t), int32_t* out_memo_index); \ int32_t GetOrInsertNull(void (*on_found)(int32_t), void (*on_not_found)(int32_t)); \ int32_t Get(const TYPE& value); \ + int32_t Size(); \ int32_t GetNull(); \ \ private: \