Skip to content

Commit

Permalink
fix/11982: resolves projection issue found in with_column window fn usage
Browse files Browse the repository at this point in the history

Signed-off-by: Devan <devandbenz@gmail.com>
  • Loading branch information
devanbenz committed Aug 15, 2024
1 parent b9961c3 commit 8affbdb
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 3 deletions.
40 changes: 40 additions & 0 deletions datafusion-examples/examples/testing_window_bug_will_delete.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use arrow::array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{col, BuiltInWindowFunction, Expr, WindowFunctionDefinition};
use std::sync::Arc;

/// Scratch example for issue 11982: builds a one-column in-memory table and
/// runs the same `row_number` window expression through both `select` and
/// `with_column`, calling `show` on each resulting DataFrame.
#[tokio::main]
async fn main() -> Result<()> {
    // Single nullable Int32 column named "a".
    let column_a = Field::new("a", DataType::Int32, true);
    let table_schema = Arc::new(Schema::new(vec![column_a]));

    let rows = RecordBatch::try_new(
        Arc::clone(&table_schema),
        vec![Arc::new(Int32Array::from(vec![5, 4, 3, 2, 1]))],
    )?;

    // Register the batch as table "t" and open it as a DataFrame.
    let ctx = SessionContext::new();
    let mem_table = MemTable::try_new(table_schema, vec![vec![rows]])?;
    ctx.register_table("t", Arc::new(mem_table))?;
    let df = ctx.table("t").await?;

    // Built-in RowNumber window function with no arguments, aliased "row_num".
    let row_number = WindowFunctionDefinition::BuiltInWindowFunction(
        BuiltInWindowFunction::RowNumber,
    );
    let func =
        Expr::WindowFunction(WindowFunction::new(row_number, vec![])).alias("row_num");

    // First path: explicit projection of "a" plus the window expression.
    let projected = df.clone().select(vec![col("a"), func.clone()])?;
    projected.show().await?;

    // Second path: attach the same expression as column "r" via `with_column`.
    let extended = df.with_column("r", func)?;
    extended.show().await?;

    Ok(())
}
40 changes: 37 additions & 3 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1452,12 +1452,14 @@ impl DataFrame {
let mut fields: Vec<Expr> = plan
.schema()
.iter()
.map(|(qualifier, field)| {
.filter_map(|(qualifier, field)| {
qualifier?;

if field.name() == name {
col_exists = true;
new_column.clone()
Some(new_column.clone())
} else {
col(Column::from((qualifier, field)))
Some(col(Column::from((qualifier, field))))
}
})
.collect();
Expand Down Expand Up @@ -1703,6 +1705,7 @@ mod tests {
use arrow::array::{self, Int32Array};
use datafusion_common::{Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt,
ScalarFunctionImplementation, Volatility, WindowFunctionDefinition,
Expand Down Expand Up @@ -2373,6 +2376,37 @@ mod tests {
Ok(())
}

/// Regression test: adding a `row_number` window expression through
/// `with_column` must yield exactly two output fields.
#[tokio::test]
async fn test_window_function_with_column() -> Result<()> {
    // In-memory table "t" with one nullable Int32 column "a".
    let column_a = Field::new("a", DataType::Int32, true);
    let table_schema = Arc::new(Schema::new(vec![column_a]));
    let rows = RecordBatch::try_new(
        Arc::clone(&table_schema),
        vec![Arc::new(Int32Array::from(vec![5, 4, 3, 2, 1]))],
    )?;

    let ctx = SessionContext::new();
    let mem_table = MemTable::try_new(table_schema, vec![vec![rows]])?;
    ctx.register_table("t", Arc::new(mem_table))?;
    let df = ctx.table("t").await?;

    // Built-in RowNumber window function with no arguments, aliased "row_num".
    let row_number = WindowFunctionDefinition::BuiltInWindowFunction(
        BuiltInWindowFunction::RowNumber,
    );
    let func =
        Expr::WindowFunction(WindowFunction::new(row_number, vec![])).alias("row_num");

    let out = df.with_column("r", func)?;

    // Should only output 'a' and 'r'
    let field_count = out.schema().fields().len();
    assert_eq!(2, field_count);
    Ok(())
}

#[tokio::test]
async fn test_distinct() -> Result<()> {
let t = test_table().await?;
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ impl LogicalPlanBuilder {
.window(window_exprs)?
.build()?;
}

Ok(plan)
}
/// Apply a projection without alias.
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2195,6 +2195,7 @@ impl Window {
.iter()
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.collect();

let input_len = fields.len();
let mut window_fields = fields;
let expr_fields = exprlist_to_fields(window_expr.as_slice(), &input)?;
Expand Down

0 comments on commit 8affbdb

Please sign in to comment.