Skip to content

Commit

Permalink
refactor: Get Column into polars-expr (#19660)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Nov 6, 2024
1 parent 1ae9d24 commit d34a3e1
Show file tree
Hide file tree
Showing 32 changed files with 580 additions and 338 deletions.
9 changes: 9 additions & 0 deletions crates/polars-core/src/chunked_array/from_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ where
}
}

impl FromIterator<Option<Column>> for ListChunked {
fn from_iter<T: IntoIterator<Item = Option<Column>>>(iter: T) -> Self {
ListChunked::from_iter(
iter.into_iter()
.map(|c| c.map(|c| c.take_materialized_series())),
)
}
}

impl FromIterator<Option<Series>> for ListChunked {
#[inline]
fn from_iter<I: IntoIterator<Item = Option<Series>>>(iter: I) -> Self {
Expand Down
69 changes: 3 additions & 66 deletions crates/polars-core/src/frame/column/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,70 +1,7 @@
use num_traits::{Num, NumCast};
use polars_error::{polars_bail, PolarsResult};
use polars_error::PolarsResult;

use super::{Column, ScalarColumn, Series};
use crate::utils::Container;

fn output_length(a: &Column, b: &Column) -> PolarsResult<usize> {
match (a.len(), b.len()) {
// broadcasting
(1, o) | (o, 1) => Ok(o),
// equal
(a, b) if a == b => Ok(a),
// unequal
(a, b) => {
polars_bail!(InvalidOperation: "cannot do arithmetic operation on series of different lengths: got {} and {}", a, b)
},
}
}

fn unit_series_op<F: Fn(&Series, &Series) -> PolarsResult<Series>>(
l: &Series,
r: &Series,
op: F,
length: usize,
) -> PolarsResult<Column> {
debug_assert!(l.len() <= 1);
debug_assert!(r.len() <= 1);

op(l, r)
.map(|s| ScalarColumn::from_single_value_series(s, length))
.map(Column::from)
}

fn op_with_broadcast<F: Fn(&Series, &Series) -> PolarsResult<Series>>(
l: &Column,
r: &Column,
op: F,
) -> PolarsResult<Column> {
// Here we rely on the underlying broadcast operations.

let length = output_length(l, r)?;
match (l, r) {
(Column::Series(l), Column::Scalar(r)) => {
let r = r.as_single_value_series();
if l.len() == 1 {
unit_series_op(l, &r, op, length)
} else {
op(l, &r).map(Column::from)
}
},
(Column::Scalar(l), Column::Series(r)) => {
let l = l.as_single_value_series();
if r.len() == 1 {
unit_series_op(&l, r, op, length)
} else {
op(&l, r).map(Column::from)
}
},
(Column::Scalar(l), Column::Scalar(r)) => unit_series_op(
&l.as_single_value_series(),
&r.as_single_value_series(),
op,
length,
),
(l, r) => op(l.as_materialized_series(), r.as_materialized_series()).map(Column::from),
}
}

fn num_op_with_broadcast<T: Num + NumCast, F: Fn(&Series, T) -> Series>(
c: &'_ Column,
Expand All @@ -90,7 +27,7 @@ macro_rules! broadcastable_ops {

#[inline]
fn $op(self, rhs: Self) -> Self::Output {
op_with_broadcast(&self, &rhs, |l, r| l.$op(r))
self.try_apply_broadcasting_binary_elementwise(&rhs, |l, r| l.$op(r))
}
}

Expand All @@ -99,7 +36,7 @@ macro_rules! broadcastable_ops {

#[inline]
fn $op(self, rhs: Self) -> Self::Output {
op_with_broadcast(self, rhs, |l, r| l.$op(r))
self.try_apply_broadcasting_binary_elementwise(rhs, |l, r| l.$op(r))
}
}
)+
Expand Down
Loading

0 comments on commit d34a3e1

Please sign in to comment.