Merge branch 'main' into 11982/window-projection-bug
devanbenz authored Aug 16, 2024
2 parents 79c9b23 + 300a08c commit 88c3c58
Showing 38 changed files with 1,438 additions and 338 deletions.
20 changes: 9 additions & 11 deletions datafusion/common/src/scalar/mod.rs
@@ -36,7 +36,7 @@ use crate::cast::{
as_decimal128_array, as_decimal256_array, as_dictionary_array,
as_fixed_size_binary_array, as_fixed_size_list_array,
};
-use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
+use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_err};
use crate::hash_utils::create_hashes;
use crate::utils::{
array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array,
@@ -1707,9 +1707,7 @@ impl ScalarValue {
// figure out the type based on the first element
let data_type = match scalars.peek() {
None => {
-return _internal_err!(
-"Empty iterator passed to ScalarValue::iter_to_array"
-);
+return _exec_err!("Empty iterator passed to ScalarValue::iter_to_array");
}
Some(sv) => sv.data_type(),
};
@@ -1723,7 +1721,7 @@
if let ScalarValue::$SCALAR_TY(v) = sv {
Ok(v)
} else {
-_internal_err!(
+_exec_err!(
"Inconsistent types in ScalarValue::iter_to_array. \
Expected {:?}, got {:?}",
data_type, sv
@@ -1743,7 +1741,7 @@
if let ScalarValue::$SCALAR_TY(v, _) = sv {
Ok(v)
} else {
-_internal_err!(
+_exec_err!(
"Inconsistent types in ScalarValue::iter_to_array. \
Expected {:?}, got {:?}",
data_type, sv
@@ -1765,7 +1763,7 @@
if let ScalarValue::$SCALAR_TY(v) = sv {
Ok(v)
} else {
-_internal_err!(
+_exec_err!(
"Inconsistent types in ScalarValue::iter_to_array. \
Expected {:?}, got {:?}",
data_type, sv
@@ -1908,11 +1906,11 @@ impl ScalarValue {
if &inner_key_type == key_type {
Ok(*scalar)
} else {
_internal_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})")
_exec_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})")
}
}
_ => {
-_internal_err!(
+_exec_err!(
"Expected scalar of type {value_type} but found: {scalar} {scalar:?}"
)
}
@@ -1940,7 +1938,7 @@ impl ScalarValue {
if let ScalarValue::FixedSizeBinary(_, v) = sv {
Ok(v)
} else {
-_internal_err!(
+_exec_err!(
"Inconsistent types in ScalarValue::iter_to_array. \
Expected {data_type:?}, got {sv:?}"
)
@@ -1965,7 +1963,7 @@
| DataType::RunEndEncoded(_, _)
| DataType::ListView(_)
| DataType::LargeListView(_) => {
-return _internal_err!(
+return _not_impl_err!(
"Unsupported creation of {:?} array from ScalarValue {:?}",
data_type,
scalars.peek()
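The reclassifications above change which DataFusionError variant callers observe. A minimal standalone sketch of the new behavior (not part of the commit; it assumes only the datafusion_common crate): an empty iterator now yields DataFusionError::Execution via _exec_err!, and unsupported types yield DataFusionError::NotImplemented via _not_impl_err!, leaving _internal_err! for genuine invariant violations.

use datafusion_common::{DataFusionError, ScalarValue};

fn main() {
    // An empty iterator is a caller-input problem, so it now surfaces as an
    // execution error rather than an internal (bug-in-DataFusion) error.
    match ScalarValue::iter_to_array(vec![]) {
        Err(DataFusionError::Execution(msg)) => println!("execution error: {msg}"),
        _ => unreachable!("empty input should yield an execution error"),
    }
}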
88 changes: 87 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
@@ -1712,11 +1712,13 @@ mod tests {
use datafusion_expr::window_function::row_number;
use datafusion_expr::{
cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt,
-ScalarFunctionImplementation, Volatility, WindowFunctionDefinition,
+ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound,
+WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
+use sqlparser::ast::NullTreatment;

// Get string representation of the plan
async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) {
@@ -2362,6 +2364,90 @@ mod tests {
Ok(())
}

+#[tokio::test]
+async fn window_using_aggregates() -> Result<()> {
+// build plan using DataFrame API
+let df = test_table().await?.filter(col("c1").eq(lit("a")))?;
+let mut aggr_expr = vec![
+(
+datafusion_functions_aggregate::first_last::first_value_udaf(),
+"first_value",
+),
+(
+datafusion_functions_aggregate::first_last::last_value_udaf(),
+"last_val",
+),
+(
+datafusion_functions_aggregate::approx_distinct::approx_distinct_udaf(),
+"approx_distinct",
+),
+(
+datafusion_functions_aggregate::approx_median::approx_median_udaf(),
+"approx_median",
+),
+(
+datafusion_functions_aggregate::median::median_udaf(),
+"median",
+),
+(datafusion_functions_aggregate::min_max::max_udaf(), "max"),
+(datafusion_functions_aggregate::min_max::min_udaf(), "min"),
+]
+.into_iter()
+.map(|(func, name)| {
+let w = WindowFunction::new(
+WindowFunctionDefinition::AggregateUDF(func),
+vec![col("c3")],
+);
+
+Expr::WindowFunction(w)
+.null_treatment(NullTreatment::IgnoreNulls)
+.order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)])
+.window_frame(WindowFrame::new_bounds(
+WindowFrameUnits::Rows,
+WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
+))
+.build()
+.unwrap()
+.alias(name)
+})
+.collect::<Vec<_>>();
+aggr_expr.extend_from_slice(&[col("c2"), col("c3")]);
+
+let df: Vec<RecordBatch> = df.select(aggr_expr)?.collect().await?;
+
+assert_batches_sorted_eq!(
+["+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
+"| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |",
+"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
+"| | | | | | | | 1 | -85 |",
+"| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |",
+"| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |",
+"| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |",
+"| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |",
+"| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |",
+"| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |",
+"| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |",
+"| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |",
+"| -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 |",
+"| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |",
+"| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |",
+"| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |",
+"| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |",
+"| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |",
+"| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |",
+"| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |",
+"| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |",
+"| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |",
+"| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |",
+"| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |",
+"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+"],
+&df
+);
+
+Ok(())
+}
+
// Test issue: https://github.com/apache/datafusion/issues/10346
#[tokio::test]
async fn test_select_over_aggregate_schema() -> Result<()> {
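A note on the frame the new test builds for every aggregate, as a standalone sketch (not part of the commit; the imports mirror the ones the test itself adds): WindowFrameBound::Preceding(ScalarValue::UInt64(None)) encodes UNBOUNDED PRECEDING and Preceding(ScalarValue::UInt64(Some(1))) encodes 1 PRECEDING, so each row's aggregate sees only the rows strictly before it. That is why the first expected row above is all NULLs: its frame is empty.

use datafusion_common::ScalarValue;
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};

fn main() {
    // Equivalent to SQL: ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING.
    // The frame ends one row before the current row, so a row never
    // contributes to its own aggregate value.
    let frame = WindowFrame::new_bounds(
        WindowFrameUnits::Rows,
        WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
    );
    println!("{frame:?}");
}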
98 changes: 90 additions & 8 deletions datafusion/core/src/execution/session_state.rs
@@ -987,8 +987,24 @@ impl SessionStateBuilder {

/// Returns a new [SessionStateBuilder] based on an existing [SessionState]
/// The session id for the new builder will be unset; all other fields will
-/// be cloned from what is set in the provided session state
+/// be cloned from what is set in the provided session state. If the default
+/// catalog exists in the existing session state, the new session state will
+/// not create a default catalog and schema.
pub fn new_from_existing(existing: SessionState) -> Self {
+let default_catalog_exist = existing
+.catalog_list()
+.catalog(&existing.config.options().catalog.default_catalog)
+.is_some();
+// The new `with_create_default_catalog_and_schema` should be false if the default catalog exists
+let create_default_catalog_and_schema = existing
+.config
+.options()
+.catalog
+.create_default_catalog_and_schema
+&& !default_catalog_exist;
+let new_config = existing
+.config
+.with_create_default_catalog_and_schema(create_default_catalog_and_schema);
Self {
session_id: None,
analyzer: Some(existing.analyzer),
@@ -1005,7 +1021,7 @@
window_functions: Some(existing.window_functions.into_values().collect_vec()),
serializer_registry: Some(existing.serializer_registry),
file_formats: Some(existing.file_formats.into_values().collect_vec()),
-config: Some(existing.config),
+config: Some(new_config),
table_options: Some(existing.table_options),
execution_props: Some(existing.execution_props),
table_factories: Some(existing.table_factories),
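In practice, the guard added above means that rebuilding a state preserves whatever its default catalog already contains instead of shadowing it with a fresh empty one. A minimal usage sketch (not part of the commit; the import path is an assumption about the datafusion crate layout):

use datafusion::execution::session_state::{SessionState, SessionStateBuilder};

// If `existing` already has its default catalog registered (including any
// tables in it), the rebuilt state reuses that catalog rather than creating
// a new empty default catalog and schema over it.
fn rebuild(existing: SessionState) -> SessionState {
    SessionStateBuilder::new_from_existing(existing).build()
}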
@@ -1801,17 +1817,19 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> {

#[cfg(test)]
mod tests {
-use std::collections::HashMap;
-
+use super::{SessionContextProvider, SessionStateBuilder};
+use crate::catalog_common::MemoryCatalogProviderList;
+use crate::datasource::MemTable;
+use crate::execution::context::SessionState;
+use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::DFSchema;
+use datafusion_common::Result;
+use datafusion_execution::config::SessionConfig;
use datafusion_expr::Expr;
use datafusion_sql::planner::{PlannerContext, SqlToRel};
-
-use crate::execution::context::SessionState;
-
-use super::{SessionContextProvider, SessionStateBuilder};
+use std::collections::HashMap;
use std::sync::Arc;

#[test]
fn test_session_state_with_default_features() {
@@ -1841,4 +1859,68 @@ mod tests {

assert!(sql_to_expr(&state).is_err())
}

+#[test]
+fn test_from_existing() -> Result<()> {
+fn employee_batch() -> RecordBatch {
+let name: ArrayRef =
+Arc::new(StringArray::from_iter_values(["Andy", "Andrew"]));
+let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22]));
+RecordBatch::try_from_iter(vec![("name", name), ("age", age)]).unwrap()
+}
+let batch = employee_batch();
+let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
+
+let session_state = SessionStateBuilder::new()
+.with_catalog_list(Arc::new(MemoryCatalogProviderList::new()))
+.build();
+let table_ref = session_state.resolve_table_ref("employee").to_string();
+session_state
+.schema_for_ref(&table_ref)?
+.register_table("employee".to_string(), Arc::new(table))?;
+
+let default_catalog = session_state
+.config
+.options()
+.catalog
+.default_catalog
+.clone();
+let default_schema = session_state
+.config
+.options()
+.catalog
+.default_schema
+.clone();
+let is_exist = session_state
+.catalog_list()
+.catalog(default_catalog.as_str())
+.unwrap()
+.schema(default_schema.as_str())
+.unwrap()
+.table_exist("employee");
+assert!(is_exist);
+let new_state = SessionStateBuilder::new_from_existing(session_state).build();
+assert!(new_state
+.catalog_list()
+.catalog(default_catalog.as_str())
+.unwrap()
+.schema(default_schema.as_str())
+.unwrap()
+.table_exist("employee"));
+
+// if `with_create_default_catalog_and_schema` is disabled, the new one shouldn't create default catalog and schema
+let disable_create_default =
+SessionConfig::default().with_create_default_catalog_and_schema(false);
+let without_default_state = SessionStateBuilder::new()
+.with_config(disable_create_default)
+.build();
+assert!(without_default_state
+.catalog_list()
+.catalog(&default_catalog)
+.is_none());
+let new_state =
+SessionStateBuilder::new_from_existing(without_default_state).build();
+assert!(new_state.catalog_list().catalog(&default_catalog).is_none());
+Ok(())
+}
}