Skip to content

Commit

Permalink
Fix the test to actually use the gated node. Fix the gated node so that it delivers batches in order.
Browse files Browse the repository at this point in the history
  • Loading branch information
westonpace committed Jun 14, 2023
1 parent 4575f91 commit 4af7119
Showing 1 changed file with 57 additions and 31 deletions.
88 changes: 57 additions & 31 deletions cpp/src/arrow/acero/asof_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1513,28 +1513,50 @@ struct GatedNode : public ExecNode, public TracedNode {

// No-op override: performs no work when asked to stop and always reports success.
Status StopProducingImpl() override { return Status::OK(); }

// Drains queued_batches_ front-to-back, forwarding each batch to output_ for as long
// as the future returned by gate_->WaitForNextReleasedBatch() is already finished.
//
// `lock` must be held on entry (it is unlocked below before each downstream call);
// the callers visible here pass a lock on mutex_. The lock is released only around
// output_->InputReceived and re-acquired before the queue is touched again.
// NOTE(review): despite the "Unlocked" suffix, the queue is accessed with the lock
// held -- the name presumably refers to the gate being unlocked; confirm.
//
// If the gate's future is not finished yet, a callback is attached that schedules a
// task to take mutex_ and re-run this method, and the loop stops without sending.
//
// Returns the first non-OK status produced by output_->InputReceived (in which case
// `lock` is left released), otherwise Status::OK().
Status SendBatchesUnlocked(std::unique_lock<std::mutex>&& lock) {
while (!queued_batches_.empty()) {
// Ask the gate for permission to release the next batch. TryAddCallback is given a
// factory for the callback and reports whether the callback was actually attached.
Future<> maybe_unlocked = gate_->WaitForNextReleasedBatch();
bool callback_added = maybe_unlocked.TryAddCallback([this] {
return [this](const Status& st) {
DCHECK_OK(st);
// Resume draining from a scheduled task rather than inline in the callback; the
// task acquires mutex_ itself before re-entering this method.
plan_->query_context()->ScheduleTask(
[this] {
std::unique_lock lk(mutex_);
return SendBatchesUnlocked(std::move(lk));
},
"GatedNode::ResumeAfterNotify");
};
});
if (callback_added) {
// The gate has not released a batch yet; the callback will restart the drain.
break;
}
// Otherwise, the future is already finished which means the gate is unlocked
// and we are allowed to send a batch
ExecBatch next = std::move(queued_batches_.front());
queued_batches_.pop();
// Release the lock while calling downstream so mutex_ is not held across
// output_->InputReceived.
lock.unlock();
ARROW_RETURN_NOT_OK(output_->InputReceived(this, std::move(next)));
// Re-acquire before the loop condition inspects queued_batches_ again.
lock.lock();
}
return Status::OK();
}

Status InputReceived(ExecNode* input, ExecBatch batch) override {
auto scope = TraceInputReceived(batch);
DCHECK_EQ(input, inputs_[0]);

// If we are ready to release the batch, do so immediately.
Future<> maybe_unlocked = gate_->WaitForNextReleasedBatch();
if (maybe_unlocked.is_finished()) {
return output_->InputReceived(this, std::move(batch));
}
// This may be called concurrently by the source and by a restart attempt. Process
// one at a time (this critical section should be pretty small)
std::unique_lock lk(mutex_);
queued_batches_.push(std::move(batch));

// Otherwise, we will wait for the gate to notify us and then check if we are
// ready to release a batch again.
maybe_unlocked.AddCallback([this, input, batch](const Status& st) {
DCHECK_OK(st);
plan_->query_context()->ScheduleTask(
[this, input, batch] { return InputReceived(input, batch); },
"GatedNode::ResumeAfterNotify");
});
return Status::OK();
return SendBatchesUnlocked(std::move(lk));
}

Gate* gate_;
std::queue<ExecBatch> queued_batches_;
std::mutex mutex_;
};

AsyncGenerator<std::optional<ExecBatch>> GetGen(
Expand All @@ -1546,8 +1568,7 @@ AsyncGenerator<std::optional<ExecBatch>> GetGen(BatchesWithSchema bws) {
}

template <typename BatchesMaker>
void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size,
double fast_delay, double slow_delay, bool noisy = false) {
void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size) {
auto l_schema =
schema({field("time", int32()), field("key", int32()), field("l_value", int32())});
auto r0_schema =
Expand All @@ -1571,21 +1592,23 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size,

struct BackpressureSourceConfig {
std::string name_prefix;
bool is_fast;
bool is_gated;
std::shared_ptr<Schema> schema;
decltype(l_batches) batches;

std::string name() const { return name_prefix + ";" + (is_fast ? "fast" : "slow"); }
std::string name() const {
return name_prefix + ";" + (is_gated ? "gated" : "ungated");
}
};

Gate gate;
GatedNodeOptions gate_options(&gate);

// must have at least one fast and one slow
// Two ungated and one gated
std::vector<BackpressureSourceConfig> source_configs = {
{"0", true, l_schema, l_batches},
{"1", false, r0_schema, r0_batches},
{"2", true, r1_schema, r1_batches},
{"0", false, l_schema, l_batches},
{"1", true, r0_schema, r0_batches},
{"2", false, r1_schema, r1_batches},
};

std::vector<BackpressureCounters> bp_counters(source_configs.size());
Expand All @@ -1603,28 +1626,33 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size,
std::vector<Declaration::Input> bp_in = {src_decls.back()};
Declaration bp_decl = {BackpressureCountingNode::kFactoryName, bp_in,
std::move(options)};
if (config.is_gated) {
bp_decl = {GatedNode::kFactoryName, {bp_decl}, gate_options};
}
bp_decls.push_back(bp_decl);
}

Declaration asofjoin = {
"asofjoin", bp_decls,
GetRepeatedOptions(source_configs.size(), "time", {"key"}, 1000)};
Declaration asofjoin = {"asofjoin", bp_decls,
GetRepeatedOptions(source_configs.size(), "time", {"key"}, 0)};

ASSERT_OK_AND_ASSIGN(std::shared_ptr<internal::ThreadPool> tpool,
internal::ThreadPool::Make(1));
ExecContext exec_ctx(default_memory_pool(), tpool.get());
Future<BatchesWithCommonSchema> batches_fut =
DeclarationToExecBatchesAsync(asofjoin, exec_ctx);

BusyWait(10.0, [&] {
auto has_bp_been_applied = [&] {
int total_paused = 0;
for (const auto& counters : bp_counters) {
total_paused += counters.pause_count;
}
// One of the inputs is gated. The other two will eventually be paused by the asof
// join node
return total_paused >= 2;
});
};

BusyWait(10.0, has_bp_been_applied);
ASSERT_TRUE(has_bp_been_applied());

gate.ReleaseAllBatches();
ASSERT_FINISHES_OK_AND_ASSIGN(BatchesWithCommonSchema batches, batches_fut);
Expand All @@ -1637,8 +1665,7 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size,
}

TEST(AsofJoinTest, BackpressureWithBatches) {
return TestBackpressure(MakeIntegerBatches, /*num_batches=*/20, /*batch_size=*/1,
/*fast_delay=*/0.01, /*slow_delay=*/0.1, /*noisy=*/false);
return TestBackpressure(MakeIntegerBatches, /*num_batches=*/20, /*batch_size=*/1);
}

template <typename BatchesMaker>
Expand Down Expand Up @@ -1703,8 +1730,7 @@ T GetEnvValue(const std::string& var, T default_value) {
TEST(AsofJoinTest, BackpressureWithBatchesGen) {
int num_batches = GetEnvValue("ARROW_BACKPRESSURE_DEMO_NUM_BATCHES", 20);
int batch_size = GetEnvValue("ARROW_BACKPRESSURE_DEMO_BATCH_SIZE", 1);
return TestBackpressure(MakeIntegerBatchGenForTest, num_batches, batch_size,
/*fast_delay=*/0.001, /*slow_delay=*/0.01);
return TestBackpressure(MakeIntegerBatchGenForTest, num_batches, batch_size);
}

} // namespace acero
Expand Down

0 comments on commit 4af7119

Please sign in to comment.