diff --git a/cpp/src/arrow/acero/CMakeLists.txt b/cpp/src/arrow/acero/CMakeLists.txt index b77d52a23eedb..1889f65632b38 100644 --- a/cpp/src/arrow/acero/CMakeLists.txt +++ b/cpp/src/arrow/acero/CMakeLists.txt @@ -170,8 +170,11 @@ add_arrow_acero_test(plan_test add_arrow_acero_test(source_node_test SOURCES source_node_test.cc test_nodes.cc) add_arrow_acero_test(fetch_node_test SOURCES fetch_node_test.cc test_nodes.cc) add_arrow_acero_test(order_by_node_test SOURCES order_by_node_test.cc test_nodes.cc) -add_arrow_acero_test(hash_join_node_test SOURCES hash_join_node_test.cc - bloom_filter_test.cc) +add_arrow_acero_test(hash_join_node_test + SOURCES + hash_join_node_test.cc + bloom_filter_test.cc + swiss_join_test.cc) add_arrow_acero_test(pivot_longer_node_test SOURCES pivot_longer_node_test.cc test_nodes.cc) diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index 3b012cc9da9ba..1c3a7b227567b 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -2244,6 +2244,10 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, match_iterator.SetLookupResult( minibatch_size_next, minibatch_start, match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(), no_duplicate_keys, hash_table_->key_to_payload()); + if (!residual_filter_->IsTrivial()) { + std::memset(filtered_bitvector_buf.mutable_data(), 0, + bit_util::BytesForBits(minibatch_size_next)); + } int num_matches_next; while (match_iterator.GetNextBatch(minibatch_size, &num_matches_next, materialize_batch_ids_buf.mutable_data(), @@ -2256,8 +2260,6 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, materialize_key_ids_buf.mutable_data(), materialize_payload_ids_buf.mutable_data(), /*output_payload_ids=*/true, !(no_duplicate_keys || no_payload_columns), temp_stack, &num_matches_next)); - std::memset(filtered_bitvector_buf.mutable_data(), 0, - bit_util::BytesForBits(minibatch_size_next)); for (int i = 0; i < num_matches_next; ++i) { int bit_idx = materialize_batch_ids_buf.mutable_data()[i] - minibatch_start; bit_util::SetBitTo(filtered_bitvector_buf.mutable_data(), bit_idx, 1); diff --git a/cpp/src/arrow/acero/swiss_join_test.cc b/cpp/src/arrow/acero/swiss_join_test.cc new file mode 100644 index 0000000000000..e51af2d1594c3 --- /dev/null +++ b/cpp/src/arrow/acero/swiss_join_test.cc @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +namespace arrow { +namespace acero { + +TEST(SwissJoin, ResidualFilter) {} + +} // namespace acero +} // namespace arrow \ No newline at end of file