fix: Deadlock due to bad logic in new-streaming join sampling #21265

Merged · 1 commit · Feb 14, 2025
213 changes: 118 additions & 95 deletions crates/polars-stream/src/nodes/joins/equi_join.rs
@@ -168,59 +168,108 @@ fn estimate_cardinality(
Ok(sketch.estimate())
}

#[expect(clippy::needless_lifetimes)]
fn insert_cached_into_parallel_stream<'s, 'env>(
cached: &'s Option<ArrayQueue<Morsel>>,
num_pipelines: usize,
recv_port: Option<RecvPort<'_>>,
scope: &'s TaskScope<'s, 'env>,
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
) -> Option<Vec<Receiver<Morsel>>> {
let Some(cached) = cached.as_ref().filter(|c| !c.is_empty()) else {
return recv_port.map(|p| p.parallel());
};

let receivers = if let Some(p) = recv_port {
p.parallel().into_iter().map(Some).collect_vec()
} else {
(0..num_pipelines).map(|_| None).collect_vec()
};

let source_token = SourceToken::new();
let mut out = Vec::new();
for orig_recv in receivers {
let (mut new_send, new_recv) = connector();
out.push(new_recv);
let source_token = source_token.clone();
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
// Act like an InMemorySource node until cached morsels are consumed.
let wait_group = WaitGroup::default();
loop {
let Some(mut morsel) = cached.pop() else {
break;
};
morsel.replace_source_token(source_token.clone());
morsel.set_consume_token(wait_group.token());
if new_send.send(morsel).await.is_err() {
return Ok(());
}
wait_group.wait().await;
if source_token.stop_requested() {
return Ok(());
}
}
struct BufferedStream {
morsels: ArrayQueue<Morsel>,
post_buffer_offset: MorselSeq,
}

if let Some(mut recv) = orig_recv {
while let Ok(morsel) = recv.recv().await {
if new_send.send(morsel).await.is_err() {
impl BufferedStream {
pub fn new(morsels: Vec<Morsel>, start_offset: MorselSeq) -> Self {
// Relabel so we can insert into parallel streams later.
let mut seq = start_offset;
let queue = ArrayQueue::new(morsels.len().max(1));
for mut morsel in morsels {
morsel.set_seq(seq);
queue.push(morsel).unwrap();
seq = seq.successor();
}

Self {
morsels: queue,
post_buffer_offset: seq,
}
}

pub fn is_empty(&self) -> bool {
self.morsels.is_empty()
}

#[expect(clippy::needless_lifetimes)]
pub fn reinsert<'s, 'env>(
&'s self,
num_pipelines: usize,
recv_port: Option<RecvPort<'_>>,
scope: &'s TaskScope<'s, 'env>,
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
) -> Option<Vec<Receiver<Morsel>>> {
let receivers = if let Some(p) = recv_port {
p.parallel().into_iter().map(Some).collect_vec()
} else {
(0..num_pipelines).map(|_| None).collect_vec()
};

let source_token = SourceToken::new();
let mut out = Vec::new();
for orig_recv in receivers {
let (mut new_send, new_recv) = connector();
out.push(new_recv);
let source_token = source_token.clone();
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
// Act like an InMemorySource node until cached morsels are consumed.
let wait_group = WaitGroup::default();
loop {
let Some(mut morsel) = self.morsels.pop() else {
break;
};
morsel.replace_source_token(source_token.clone());
morsel.set_consume_token(wait_group.token());
if new_send.send(morsel).await.is_err() {
return Ok(());
}
wait_group.wait().await;
// TODO: Unfortunately we can't actually stop here without
// re-buffering morsels from the stream that comes after.
// if source_token.stop_requested() {
// break;
// }
}
}
Ok(())
}));

if let Some(mut recv) = orig_recv {
while let Ok(mut morsel) = recv.recv().await {
if source_token.stop_requested() {
morsel.source_token().stop();
}
morsel.set_seq(morsel.seq().offset_by(self.post_buffer_offset));
if new_send.send(morsel).await.is_err() {
break;
}
}
}
Ok(())
}));
}
Some(out)
}
}

impl Default for BufferedStream {
fn default() -> Self {
Self {
morsels: ArrayQueue::new(1),
post_buffer_offset: MorselSeq::default(),
}
}
}

impl Drop for BufferedStream {
fn drop(&mut self) {
POOL.install(|| {
// Parallel drop as the state might be quite big.
(0..self.morsels.len())
.into_par_iter()
.for_each(|_| drop(self.morsels.pop()));
})
}
Some(out)
}

#[derive(Default)]
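
The hunk above replaces the one-off insert_cached_into_parallel_stream helper with a BufferedStream that owns the cached morsels. The key idea, judging from the diff, is sequence bookkeeping: buffered morsels are relabeled starting at start_offset, the first free sequence number is recorded as post_buffer_offset, and morsels that arrive after the buffer are shifted with offset_by(post_buffer_offset) so the merged stream keeps strictly increasing MorselSeq values. The following is a minimal, standalone sketch of that bookkeeping using plain std types; SketchBufferedStream and its fields are illustrative stand-ins, not the actual polars-stream API.

use std::collections::VecDeque;

// Stand-in for a Morsel: (sequence number, payload).
type FakeMorsel = (u64, &'static str);

struct SketchBufferedStream {
    buffered: VecDeque<FakeMorsel>,
    // First sequence number that is free after the buffered morsels.
    post_buffer_offset: u64,
}

impl SketchBufferedStream {
    fn new(payloads: Vec<&'static str>, start_offset: u64) -> Self {
        // Relabel buffered items with consecutive sequence numbers.
        let mut seq = start_offset;
        let mut buffered = VecDeque::new();
        for p in payloads {
            buffered.push_back((seq, p));
            seq += 1;
        }
        Self { buffered, post_buffer_offset: seq }
    }

    // Drain the buffer first, then forward live items with their sequence
    // numbers shifted by post_buffer_offset, keeping the order monotone.
    fn reinsert(mut self, live: Vec<FakeMorsel>) -> Vec<FakeMorsel> {
        let mut out: Vec<FakeMorsel> = self.buffered.drain(..).collect();
        for (seq, payload) in live {
            out.push((seq + self.post_buffer_offset, payload));
        }
        out
    }
}

fn main() {
    let stream = SketchBufferedStream::new(vec!["cached-a", "cached-b"], 0);
    let merged = stream.reinsert(vec![(0, "live-x"), (1, "live-y")]);
    // Prints [(0, "cached-a"), (1, "cached-b"), (2, "live-x"), (3, "live-y")]
    println!("{merged:?}");
}
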
@@ -357,19 +406,10 @@ impl SampleState {
new_chunked_idx_table(params.right_key_schema.clone())
});

fn make_queue(v: Vec<Morsel>) -> Option<ArrayQueue<Morsel>> {
if v.is_empty() {
return None;
}
let queue = ArrayQueue::new(v.len());
for morsel in v {
queue.push(morsel).unwrap();
}
Some(queue)
}

let mut sampled_build_morsels = make_queue(core::mem::take(&mut self.left));
let mut sampled_probe_morsels = make_queue(core::mem::take(&mut self.right));
let mut sampled_build_morsels =
BufferedStream::new(core::mem::take(&mut self.left), MorselSeq::default());
let mut sampled_probe_morsels =
BufferedStream::new(core::mem::take(&mut self.right), MorselSeq::default());
if !left_is_build {
core::mem::swap(&mut sampled_build_morsels, &mut sampled_probe_morsels);
}
@@ -381,18 +421,13 @@ impl SampleState {
};

// Simulate the sample build morsels flowing into the build side.
if sampled_build_morsels.is_some() {
if !sampled_build_morsels.is_empty() {
let state = ExecutionState::new();
crate::async_executor::task_scope(|scope| {
let mut join_handles = Vec::new();
let receivers = insert_cached_into_parallel_stream(
&sampled_build_morsels,
num_pipelines,
None,
scope,
&mut join_handles,
)
.unwrap();
let receivers = sampled_build_morsels
.reinsert(num_pipelines, None, scope, &mut join_handles)
.unwrap();

for (worker_ps, recv) in build_state.partitions_per_worker.iter_mut().zip(receivers)
{
@@ -431,7 +466,7 @@ struct BuildPartition {
#[derive(Default)]
struct BuildState {
partitions_per_worker: Vec<Vec<BuildPartition>>,
sampled_probe_morsels: Option<ArrayQueue<Morsel>>,
sampled_probe_morsels: BufferedStream,
}

impl BuildState {
@@ -594,7 +629,7 @@ struct ProbeTable {
struct ProbeState {
table_per_partition: Vec<ProbeTable>,
max_seq_sent: MorselSeq,
sampled_probe_morsels: Option<ArrayQueue<Morsel>>,
sampled_probe_morsels: BufferedStream,
}

impl ProbeState {
@@ -858,11 +893,6 @@ impl Drop for ProbeState {
POOL.install(|| {
// Parallel drop as the state might be quite big.
self.table_per_partition.par_drain(..).for_each(drop);
if let Some(morsels) = &self.sampled_probe_morsels {
(0..morsels.len())
.into_par_iter()
.for_each(|_| drop(morsels.pop()));
}
})
}
}
@@ -1126,11 +1156,7 @@ impl ComputeNode for EquiJoinNode {
// If we are probing and the probe input is done, emit unmatched if
// necessary, otherwise we're done.
if let EquiJoinState::Probe(probe_state) = &mut self.state {
let samples_consumed = probe_state
.sampled_probe_morsels
.as_ref()
.map(|m| m.is_empty())
.unwrap_or(true);
let samples_consumed = probe_state.sampled_probe_morsels.is_empty();
if samples_consumed && recv[probe_idx] == PortState::Done {
if self.params.emit_unmatched_build() {
if self.params.preserve_order_build {
@@ -1194,11 +1220,7 @@
if recv[probe_idx] != PortState::Done {
core::mem::swap(&mut send[0], &mut recv[probe_idx]);
} else {
let samples_consumed = probe_state
.sampled_probe_morsels
.as_ref()
.map(|m| m.is_empty())
.unwrap_or(true);
let samples_consumed = probe_state.sampled_probe_morsels.is_empty();
send[0] = if samples_consumed {
PortState::Done
} else {
@@ -1319,14 +1341,15 @@ impl ComputeNode for EquiJoinNode {
EquiJoinState::Probe(probe_state) => {
assert!(recv_ports[build_idx].is_none());
let senders = send_ports[0].take().unwrap().parallel();
let receivers = insert_cached_into_parallel_stream(
&probe_state.sampled_probe_morsels,
self.num_pipelines,
recv_ports[probe_idx].take(),
scope,
join_handles,
)
.unwrap();
let receivers = probe_state
.sampled_probe_morsels
.reinsert(
self.num_pipelines,
recv_ports[probe_idx].take(),
scope,
join_handles,
)
.unwrap();

let partitioner = HashPartitioner::new(self.num_pipelines, 0);
let probe_tasks = receivers
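
A second piece of the fix, visible in the later hunks, is that the probe side now treats its buffered sample morsels as outstanding work: samples_consumed is simply sampled_probe_morsels.is_empty(), and the output port is only marked done once the upstream probe port is done and the buffered samples have been re-inserted. The following is a tiny sketch of that rule; PortState and the Blocked variant are stand-ins here, since the excerpt only shows the Done branch.

// Illustrative only: mirrors the "do not report Done while sampled probe
// morsels are still buffered" rule from the diff. Blocked is a stand-in name
// for whatever non-done state the real node reports.
#[derive(Debug, PartialEq)]
enum PortState {
    Blocked,
    Done,
}

fn probe_send_state(upstream_probe_done: bool, sampled_probe_morsels_empty: bool) -> PortState {
    if upstream_probe_done && sampled_probe_morsels_empty {
        PortState::Done
    } else {
        PortState::Blocked
    }
}

fn main() {
    // Samples still buffered: keep the port open so they can be re-inserted.
    assert_eq!(probe_send_state(true, false), PortState::Blocked);
    assert_eq!(probe_send_state(true, true), PortState::Done);
}
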