fix: Input batch to ShuffleRepartitioner.insert_batch should not be larger than configured batch size (apache#523)

* fix: Input batch to ShuffleRepartitioner.insert_batch should not be larger than configured batch size

* Add test

* For review
viirya authored and kazuyukitanimura committed Jul 1, 2024
1 parent dd0887d commit 2253042
Showing 1 changed file with 61 additions and 3 deletions.
64 changes: 61 additions & 3 deletions core/src/execution/datafusion/shuffle_writer.rs
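For context, the fix below works by slicing an oversized input batch into chunks no larger than the configured batch size before the rows are partitioned. The following standalone sketch illustrates that slicing pattern using arrow's zero-copy RecordBatch::slice; the column name, data, and batch_size value are illustrative only and are not taken from the commit.

// Minimal sketch of the slicing pattern, assuming the arrow crate.
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let array = Int32Array::from((0..10).collect::<Vec<i32>>());
    let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap();

    // Stand-in for the configured execution batch size.
    let batch_size = 4;
    let mut start = 0;
    while start < batch.num_rows() {
        let end = (start + batch_size).min(batch.num_rows());
        // `slice` is zero-copy: each chunk shares the parent batch's buffers.
        let chunk = batch.slice(start, end - start);
        assert!(chunk.num_rows() <= batch_size);
        println!("chunk rows: {}", chunk.num_rows());
        start = end;
    }
}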
@@ -575,6 +575,8 @@ struct ShuffleRepartitioner {
hashes_buf: Vec<u32>,
/// Partition ids for each row in the current batch
partition_ids: Vec<u64>,
/// The configured batch size
batch_size: usize,
}

struct ShuffleRepartitionerMetrics {
@@ -642,17 +644,41 @@ impl ShuffleRepartitioner {
reservation,
hashes_buf,
partition_ids,
batch_size,
}
}

/// Shuffles rows in input batch into corresponding partition buffer.
/// This function will slice input batch according to configured batch size and then
/// shuffle rows into corresponding partition buffer.
async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
let mut start = 0;
while start < batch.num_rows() {
let end = (start + self.batch_size).min(batch.num_rows());
let batch = batch.slice(start, end - start);
self.partitioning_batch(batch).await?;
start = end;
}
Ok(())
}

/// Shuffles rows in input batch into corresponding partition buffer.
/// This function first calculates hashes for rows and then takes rows in same
/// partition as a record batch which is appended into partition buffer.
async fn insert_batch(&mut self, input: RecordBatch) -> Result<()> {
/// This should not be called directly. Use `insert_batch` instead.
async fn partitioning_batch(&mut self, input: RecordBatch) -> Result<()> {
if input.num_rows() == 0 {
// skip empty batch
return Ok(());
}

if input.num_rows() > self.batch_size {
return Err(DataFusionError::Internal(
"Input batch size exceeds configured batch size. Call `insert_batch` instead."
.to_string(),
));
}

let _timer = self.metrics.baseline.elapsed_compute().timer();

// NOTE: in shuffle writer exec, the output_rows metrics represents the
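The body of partitioning_batch is mostly collapsed in this hunk. As a rough, standalone illustration of the idea in its doc comment (hash each row, then gather the rows of each partition into their own record batch), the sketch below uses std's DefaultHasher as a stand-in for Comet's actual hash function and arrow's take kernel for the gather step; all names and the partition count are illustrative and not taken from the commit.

// Hash-then-gather sketch under the assumptions stated above.
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::array::{StringArray, UInt32Array};
use arrow::compute::take;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)]));
    let values: Vec<String> = (0..8).map(|i| i.to_string()).collect();
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(StringArray::from(values))]).unwrap();

    let num_partitions = 4;
    let keys = batch
        .column(0)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    // Bucket row indices by partition id (hash of the key column modulo the
    // partition count). DefaultHasher stands in for the real hash function.
    let mut buckets: Vec<Vec<u32>> = vec![Vec::new(); num_partitions];
    for row in 0..batch.num_rows() {
        let mut hasher = DefaultHasher::new();
        keys.value(row).hash(&mut hasher);
        let partition = (hasher.finish() % num_partitions as u64) as usize;
        buckets[partition].push(row as u32);
    }

    // Gather each partition's rows into its own small RecordBatch with `take`.
    for (partition, indices) in buckets.iter().enumerate() {
        let indices = UInt32Array::from(indices.clone());
        let columns = batch
            .columns()
            .iter()
            .map(|c| take(c.as_ref(), &indices, None).unwrap())
            .collect::<Vec<_>>();
        let partition_batch = RecordBatch::try_new(schema.clone(), columns).unwrap();
        println!("partition {partition}: {} rows", partition_batch.num_rows());
    }
}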
@@ -951,8 +977,7 @@ async fn external_shuffle(
);

while let Some(batch) = input.next().await {
let batch = batch?;
repartitioner.insert_batch(batch).await?;
repartitioner.insert_batch(batch?).await?;
}
repartitioner.shuffle_write().await
}
@@ -1387,6 +1412,11 @@ impl RecordBatchStream for EmptyStream {
#[cfg(test)]
mod test {
use super::*;
use datafusion::physical_plan::common::collect;
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::prelude::SessionContext;
use datafusion_physical_expr::expressions::Column;
use tokio::runtime::Runtime;

#[test]
fn test_slot_size() {
@@ -1415,4 +1445,32 @@ mod test {
assert_eq!(slot_size, *expected);
})
}

#[test]
fn test_insert_larger_batch() {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
let mut b = StringBuilder::new();
for i in 0..10000 {
b.append_value(format!("{i}"));
}
let array = b.finish();
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();

let mut batches = Vec::new();
batches.push(batch.clone());

let partitions = &[batches];
let exec = ShuffleWriterExec::try_new(
Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()),
Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16),
"/tmp/data.out".to_string(),
"/tmp/index.out".to_string(),
)
.unwrap();
let ctx = SessionContext::new();
let task_ctx = ctx.task_ctx();
let stream = exec.execute(0, task_ctx).unwrap();
let rt = Runtime::new().unwrap();
rt.block_on(collect(stream)).unwrap();
}
}