Skip to content

Commit

Permalink
fix to_dtype for aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 13, 2024
1 parent 8ea9c5c commit 8439a9e
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 91 deletions.
4 changes: 4 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ impl DataType {
}
}

/// Wrap this data type in one `List` level, so that `T` becomes `List(T)`.
///
/// Consumes `self`; the original type is boxed and becomes the inner type of the
/// returned list type.
#[must_use]
pub fn implode(self) -> DataType {
    DataType::List(Box::new(self))
}

/// Convert to the physical data type
#[must_use]
pub fn to_physical(&self) -> DataType {
Expand Down
176 changes: 101 additions & 75 deletions crates/polars-plan/src/logical_plan/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,47 @@ impl AExpr {
}

/// Get Field result of the expression. The schema is the input data.
#[recursive]
pub fn to_field(
    &self,
    schema: &Schema,
    ctxt: Context,
    arena: &Arena<AExpr>,
) -> PolarsResult<Field> {
    // In an aggregation context a non-aggregated column gains one list level:
    //     col(foo: i64)       -> list[i64]
    // whereas an aggregated one does not:
    //     col(foo: i64).sum() -> i64
    // `depth` is the number of nesting levels still to be added; it starts at 1
    // for aggregations and is updated by `to_field_impl` while resolving the tree.
    let mut depth = u8::from(matches!(ctxt, Context::Aggregation));
    let mut resolved = self.to_field_impl(schema, arena, &mut depth)?;

    if depth > 0 {
        // Re-apply the outstanding list level on top of the resolved dtype.
        let imploded = resolved.data_type().clone().implode();
        resolved.coerce(imploded);
    }
    Ok(resolved)
}

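/// Get Field result of the expression. The schema is the input data.
///
/// `nested` is the number of list nesting levels the caller still has to add to the
/// returned field; it is updated in place while the expression tree is resolved.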
#[recursive]
pub fn to_field_impl(
&self,
schema: &Schema,
arena: &Arena<AExpr>,
nested: &mut u8,
) -> PolarsResult<Field> {
use AExpr::*;
use DataType::*;
match self {
Len => Ok(Field::new(LEN, IDX_DTYPE)),
Len => {
*nested = 0;
Ok(Field::new(LEN, IDX_DTYPE))
},
Window { function, .. } => {
let e = arena.get(*function);
e.to_field(schema, ctxt, arena)
e.to_field_impl(schema, arena, nested)
},
Explode(expr) => {
let field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let field = arena.get(*expr).to_field_impl(schema, arena, nested)?;

if let List(inner) = field.data_type() {
Ok(Field::new(field.name(), *inner.clone()))
Expand All @@ -47,22 +71,11 @@ impl AExpr {
},
Alias(expr, name) => Ok(Field::new(
name,
arena.get(*expr).get_type(schema, ctxt, arena)?,
arena.get(*expr).to_field_impl(schema, arena, nested)?.dtype,
)),
Column(name) => {
let field = schema
.get_field(name)
.ok_or_else(|| PolarsError::ColumnNotFound(name.to_string().into()));

match ctxt {
Context::Default => field,
Context::Aggregation => field.map(|mut field| {
let dtype = List(Box::new(field.data_type().clone()));
field.coerce(dtype);
field
}),
}
},
Column(name) => schema
.get_field(name)
.ok_or_else(|| PolarsError::ColumnNotFound(name.to_string().into())),
Literal(sv) => Ok(match sv {
LiteralValue::Series(s) => s.field().into_owned(),
_ => Field::new(sv.output_name(), sv.get_datatype()),
Expand All @@ -83,45 +96,43 @@ impl AExpr {
| Operator::LogicalOr => {
let out_field;
let out_name = {
out_field = arena.get(*left).to_field(schema, ctxt, arena)?;
out_field = arena.get(*left).to_field_impl(schema, arena, nested)?;
out_field.name().as_str()
};
Field::new(out_name, Boolean)
},
Operator::TrueDivide => return get_truediv_field(*left, arena, ctxt, schema),
_ => return get_arithmetic_field(*left, *right, arena, *op, ctxt, schema),
Operator::TrueDivide => return get_truediv_field(*left, arena, schema, nested),
_ => return get_arithmetic_field(*left, *right, arena, *op, schema, nested),
};

Ok(field)
},
Sort { expr, .. } => arena.get(*expr).to_field(schema, ctxt, arena),
Sort { expr, .. } => arena.get(*expr).to_field_impl(schema, arena, nested),
Gather {
expr,
returns_scalar,
..
} => {
let ctxt = if *returns_scalar {
Context::Default
} else {
ctxt
};
arena.get(*expr).to_field(schema, ctxt, arena)
if *returns_scalar {
*nested = nested.saturating_sub(1);
}
arena.get(*expr).to_field_impl(schema, arena, nested)
},
SortBy { expr, .. } => arena.get(*expr).to_field(schema, ctxt, arena),
Filter { input, .. } => arena.get(*input).to_field(schema, ctxt, arena),
SortBy { expr, .. } => arena.get(*expr).to_field_impl(schema, arena, nested),
Filter { input, .. } => arena.get(*input).to_field_impl(schema, arena, nested),
Agg(agg) => {
use IRAggExpr::*;
match agg {
Max { input: expr, .. }
| Min { input: expr, .. }
| First(expr)
| Last(expr) => {
// default context because `col()` would return a list in aggregation context
arena.get(*expr).to_field(schema, Context::Default, arena)
*nested = nested.saturating_sub(1);
arena.get(*expr).to_field_impl(schema, arena, nested)
},
Sum(expr) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let dt = match field.data_type() {
Boolean => Some(IDX_DTYPE),
UInt8 | Int8 | Int16 | UInt16 => Some(Int64),
Expand All @@ -133,62 +144,61 @@ impl AExpr {
Ok(field)
},
Median(expr) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
match field.dtype {
Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)),
_ => float_type(&mut field),
}
Ok(field)
},
Mean(expr) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
match field.dtype {
Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)),
_ => float_type(&mut field),
}
Ok(field)
},
Implode(expr) => {
// default context because `col()` would return a list in aggregation context
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
field.coerce(DataType::List(field.data_type().clone().into()));
Ok(field)
},
Std(expr, _) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
float_type(&mut field);
Ok(field)
},
Var(expr, _) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
float_type(&mut field);
Ok(field)
},
NUnique(expr) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
*nested = 0;
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
field.coerce(IDX_DTYPE);
Ok(field)
},
Count(expr, _) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
*nested = 0;
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
field.coerce(IDX_DTYPE);
Ok(field)
},
AggGroups(expr) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
*nested = 1;
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
field.coerce(List(IDX_DTYPE.into()));
Ok(field)
},
Quantile { expr, .. } => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
float_type(&mut field);
Ok(field)
},
Expand All @@ -197,20 +207,35 @@ impl AExpr {
Cast {
expr, data_type, ..
} => {
let field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
Ok(Field::new(field.name(), data_type.clone()))
},
Ternary { truthy, falsy, .. } => {
let mut truthy = arena.get(*truthy).to_field(schema, ctxt, arena)?;
let falsy = arena.get(*falsy).to_field(schema, ctxt, arena)?;
if let DataType::Null = *truthy.data_type() {
truthy.coerce(falsy.data_type().clone());
Ok(truthy)
let mut nested_truthy = *nested;
let mut nested_falsy = *nested;

// During aggregation:
// left: col(foo): list<T> nesting: 1
// right: col(foo).first(): T nesting: 0
// col(foo) + col(foo).first() will have nesting 1 as we still maintain the groups list.
let mut truthy =
arena
.get(*truthy)
.to_field_impl(schema, arena, &mut nested_truthy)?;
let falsy = arena
.get(*falsy)
.to_field_impl(schema, arena, &mut nested_falsy)?;

let st = if let DataType::Null = *truthy.data_type() {
falsy.data_type().clone()
} else {
let st = try_get_supertype(truthy.data_type(), falsy.data_type())?;
truthy.coerce(st);
Ok(truthy)
}
try_get_supertype(truthy.data_type(), falsy.data_type())?
};

*nested = std::cmp::max(nested_truthy, nested_falsy);

truthy.coerce(st);
Ok(truthy)
},
AnonymousFunction {
output_type,
Expand All @@ -221,18 +246,18 @@ impl AExpr {
} => {
let tmp = function.get_output();
let output_type = tmp.as_ref().unwrap_or(output_type);
let fields = func_args_to_fields(input, schema, arena)?;
let fields = func_args_to_fields(input, schema, arena, nested)?;
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str);
Ok(output_type.get_field(schema, ctxt, &fields))
Ok(output_type.get_field(schema, Context::Default, &fields))
},
Function {
function, input, ..
} => {
let fields = func_args_to_fields(input, schema, arena)?;
let fields = func_args_to_fields(input, schema, arena, nested)?;
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function);
function.get_field(schema, ctxt, &fields)
function.get_field(schema, Context::Default, &fields)
},
Slice { input, .. } => arena.get(*input).to_field(schema, ctxt, arena),
Slice { input, .. } => arena.get(*input).to_field_impl(schema, arena, nested),
Wildcard => {
polars_bail!(ComputeError: "wildcard column selection not supported at this point")
},
Expand All @@ -247,14 +272,15 @@ fn func_args_to_fields(
input: &[ExprIR],
schema: &Schema,
arena: &Arena<AExpr>,
nested: &mut u8,
) -> PolarsResult<Vec<Field>> {
input
.iter()
// Default context because `col()` would return a list in aggregation context
.map(|e| {
arena
.get(e.node())
.to_field(schema, Context::Default, arena)
.to_field_impl(schema, arena, nested)
.map(|mut field| {
field.name = e.output_name().into();
field
Expand All @@ -268,8 +294,8 @@ fn get_arithmetic_field(
right: Node,
arena: &Arena<AExpr>,
op: Operator,
ctxt: Context,
schema: &Schema,
nested: &mut u8,
) -> PolarsResult<Field> {
use DataType::*;
let left_ae = arena.get(left);
Expand All @@ -283,11 +309,11 @@ fn get_arithmetic_field(
// leading to quadratic behavior. # 4736
//
// further right_type is only determined when needed.
let mut left_field = left_ae.to_field(schema, ctxt, arena)?;
let mut left_field = left_ae.to_field_impl(schema, arena, nested)?;

let super_type = match op {
Operator::Minus if left_field.dtype.is_temporal() => {
let right_type = right_ae.get_type(schema, ctxt, arena)?;
let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype;
match (&left_field.dtype, right_type) {
// T - T != T if T is a datetime / date
(Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, &tur)),
Expand All @@ -302,7 +328,7 @@ fn get_arithmetic_field(
IDX_DTYPE
},
_ => {
let right_type = right_ae.get_type(schema, ctxt, arena)?;
let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype;

match (&left_field.dtype, &right_type) {
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -345,10 +371,10 @@ fn get_arithmetic_field(
fn get_truediv_field(
left: Node,
arena: &Arena<AExpr>,
ctxt: Context,
schema: &Schema,
nested: &mut u8,
) -> PolarsResult<Field> {
let mut left_field = arena.get(left).to_field(schema, ctxt, arena)?;
let mut left_field = arena.get(left).to_field_impl(schema, arena, nested)?;
use DataType::*;
let out_type = match left_field.data_type() {
Float32 => Float32,
Expand Down
9 changes: 0 additions & 9 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,6 @@ def test_concat() -> None:
assert s.len() == 3


@pytest.mark.parametrize("dtype", [pl.Int64, pl.Float64, pl.String, pl.Boolean])
def test_eq_missing_list_and_primitive(dtype: PolarsDataType) -> None:
s1 = pl.Series([None, None], dtype=dtype)
s2 = pl.Series([None, None], dtype=pl.List(dtype))

expected = pl.Series([True, True])
assert_series_equal(s1.eq_missing(s2), expected)


def test_to_frame() -> None:
s1 = pl.Series([1, 2])
s2 = pl.Series("s", [1, 2])
Expand Down
7 changes: 0 additions & 7 deletions py-polars/tests/unit/testing/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,13 +561,6 @@ def test_assert_series_equal_full_null_nested_list() -> None:
assert_series_equal(s, s)


def test_assert_series_equal_full_null_nested_not_nested() -> None:
s1 = pl.Series([None, None], dtype=pl.List(pl.Float64))
s2 = pl.Series([None, None], dtype=pl.Float64)

assert_series_equal(s1, s2, check_dtypes=False)


def test_assert_series_equal_nested_list_nan() -> None:
s = pl.Series([[1.0, 2.0], [3.0, nan]], dtype=pl.List(pl.Float64))
assert_series_equal(s, s)
Expand Down

0 comments on commit 8439a9e

Please sign in to comment.