-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add array function to polars expressions #19079
Changes from 9 commits
53e0d16
cb895b2
3462cfe
950fc01
a0bb53f
6316d22
0543f3d
4f3c354
4ae4316
18dc375
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,11 @@ | ||
[env] | ||
# Tune jemalloc (https://github.com/pola-rs/polars/issues/18088). | ||
JEMALLOC_SYS_WITH_MALLOC_CONF = "dirty_decay_ms:500,muzzy_decay_ms:-1" | ||
|
||
[target.'cfg(all())'] | ||
rustflags = [ | ||
"-C", "link-arg=-Wl,-rpath,.../pyenv.git/versions/3.10.15/lib", | ||
"-C", "link-arg=-L.../pyenv.git/versions/3.10.15/lib", | ||
"-C", "link-arg=-lpython3.10", | ||
] | ||
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,8 +41,9 @@ rayon = { workspace = true } | |
recursive = { workspace = true } | ||
regex = { workspace = true, optional = true } | ||
serde = { workspace = true, features = ["rc"], optional = true } | ||
serde_json = { workspace = true, optional = true } | ||
serde_json = { workspace = true, optional = false } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We shouldn't add a new ArrayVec dependency. And Adding dependencies is not something we want ideally. |
||
strum_macros = { workspace = true } | ||
arrayvec = { version = "0.7.6" , features = ["serde"]} | ||
|
||
[build-dependencies] | ||
version_check = { workspace = true } | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
use arrayvec::ArrayString; | ||
use polars_core::prelude::*; | ||
#[cfg(feature = "array_to_struct")] | ||
use polars_ops::chunked_array::array::{ | ||
arr_default_struct_name_gen, ArrToStructNameGenerator, ToStruct, | ||
}; | ||
|
||
use crate::dsl::function_expr::ArrayFunction; | ||
use crate::dsl::function_expr::{ArrayFunction, ArrayKwargs}; | ||
use crate::prelude::*; | ||
|
||
/// Specialized expressions for [`Series`] of [`DataType::Array`]. | ||
|
@@ -194,3 +195,32 @@ impl ArrayNameSpace { | |
) | ||
} | ||
} | ||
|
||
pub fn array_from_expr<E: AsRef<[IE]>, IE: Into<Expr> + Clone>( | ||
s: E, | ||
dtype_str: &str, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we get a string here? |
||
) -> PolarsResult<Expr> { | ||
let s: Vec<_> = s.as_ref().iter().map(|e| e.clone().into()).collect(); | ||
|
||
polars_ensure!(!s.is_empty(), ComputeError: "`array` needs one or more expressions"); | ||
|
||
// let mut kwargs = ArrayKwargs::default(); | ||
// const max_sz: usize = kwargs.dtype_expr.capacity(); | ||
const MAX_SZ: usize = 256; // hardcode for now, plan to replace this anyway | ||
let mut trunc_str = dtype_str.to_string(); | ||
trunc_str.truncate(MAX_SZ); | ||
let fixed_string = ArrayString::<{ MAX_SZ }>::from(&trunc_str).unwrap(); | ||
let kwargs = ArrayKwargs { | ||
dtype_expr: fixed_string, | ||
}; | ||
|
||
Ok(Expr::Function { | ||
input: s, | ||
function: FunctionExpr::ArrayExpr(ArrayFunction::Array(kwargs)), | ||
options: FunctionOptions { | ||
collect_groups: ApplyOptions::ElementWise, | ||
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, | ||
..Default::default() | ||
}, | ||
}) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,21 @@ | ||
use std::collections::HashMap; | ||
|
||
use arrayvec::ArrayString; | ||
use arrow::array::{FixedSizeListArray, PrimitiveArray}; | ||
use arrow::bitmap::MutableBitmap; | ||
use polars_core::with_match_physical_numeric_polars_type; | ||
use polars_ops::chunked_array::array::*; | ||
|
||
use super::*; | ||
use crate::{map, map_as_slice}; | ||
|
||
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash, Default, Serialize, Deserialize)] | ||
pub struct ArrayKwargs { | ||
// Not sure how to get a serializable DataType here | ||
// For prototype, use fixed size string | ||
pub dtype_expr: ArrayString<256>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we have a string here? |
||
} | ||
|
||
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] | ||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] | ||
pub enum ArrayFunction { | ||
|
@@ -19,6 +32,7 @@ pub enum ArrayFunction { | |
Any, | ||
#[cfg(feature = "array_any_all")] | ||
All, | ||
Array(ArrayKwargs), | ||
Sort(SortOptions), | ||
Reverse, | ||
ArgMin, | ||
|
@@ -46,6 +60,7 @@ impl ArrayFunction { | |
Median => mapper.map_to_float_dtype(), | ||
#[cfg(feature = "array_any_all")] | ||
Any | All => mapper.with_dtype(DataType::Boolean), | ||
Array(kwargs) => array_output_type(mapper.args(), kwargs), | ||
Sort(_) => mapper.with_same_dtype(), | ||
Reverse => mapper.with_same_dtype(), | ||
ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE), | ||
|
@@ -60,6 +75,49 @@ impl ArrayFunction { | |
} | ||
} | ||
|
||
fn deserialize_dtype(dtype_expr: &str) -> PolarsResult<Option<DataType>> { | ||
match dtype_expr.len() { | ||
0 => Ok(None), | ||
_ => match serde_json::from_str::<Expr>(dtype_expr) { | ||
Ok(Expr::DtypeColumn(dtypes)) if dtypes.len() == 1 => Ok(Some(dtypes[0].clone())), | ||
Ok(_) => Err( | ||
polars_err!(ComputeError: "Expected a DtypeColumn expression with a single dtype"), | ||
), | ||
Err(_) => Err(polars_err!(ComputeError: "Could not deserialize dtype expression")), | ||
}, | ||
} | ||
} | ||
|
||
fn get_expected_dtype(inputs: &[DataType], kwargs: &ArrayKwargs) -> PolarsResult<DataType> { | ||
// Decide what dtype to use for the constructed array | ||
// For now, the logic is to use the dtype in kwargs, if specified | ||
// Otherwise, use the type of the first column. | ||
// | ||
// An alternate idea could be to call try_get_supertype for the types. | ||
// Or logic like DataFrame::get_supertype_all | ||
// The problem is, I think this cast may be too general and we may only want to support primitive types | ||
// Also, we don't support String yet. | ||
let expected_dtype = deserialize_dtype(&kwargs.dtype_expr)?.unwrap_or(inputs[0].clone()); | ||
Ok(expected_dtype) | ||
} | ||
|
||
fn array_output_type(input_fields: &[Field], kwargs: &ArrayKwargs) -> PolarsResult<Field> { | ||
// Expected target type is either the provided dtype or the type of the first column | ||
let dtypes: Vec<DataType> = input_fields.iter().map(|f| f.dtype().clone()).collect(); | ||
let expected_dtype = get_expected_dtype(&dtypes, kwargs)?; | ||
|
||
for field in input_fields.iter() { | ||
if !field.dtype().is_numeric() { | ||
polars_bail!(ComputeError: "all input fields must be numeric") | ||
} | ||
} | ||
|
||
Ok(Field::new( | ||
PlSmallStr::from_static("array"), | ||
DataType::Array(Box::new(expected_dtype), input_fields.len()), | ||
)) | ||
} | ||
|
||
fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> { | ||
if let DataType::Array(inner, _) = datatype { | ||
Ok(DataType::List(inner.clone())) | ||
|
@@ -85,6 +143,7 @@ impl Display for ArrayFunction { | |
Any => "any", | ||
#[cfg(feature = "array_any_all")] | ||
All => "all", | ||
Array(_) => "array", | ||
Sort(_) => "sort", | ||
Reverse => "reverse", | ||
ArgMin => "arg_min", | ||
|
@@ -118,6 +177,7 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> { | |
Any => map!(any), | ||
#[cfg(feature = "array_any_all")] | ||
All => map!(all), | ||
Array(kwargs) => map_as_slice!(array_new, kwargs), | ||
Sort(options) => map!(sort, options), | ||
Reverse => map!(reverse), | ||
ArgMin => map!(arg_min), | ||
|
@@ -133,6 +193,100 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> { | |
} | ||
} | ||
|
||
// Create a new array from a slice of series | ||
fn array_new(inputs: &[Column], kwargs: ArrayKwargs) -> PolarsResult<Column> { | ||
array_internal(inputs, kwargs) | ||
} | ||
fn array_internal(inputs: &[Column], kwargs: ArrayKwargs) -> PolarsResult<Column> { | ||
let dtypes: Vec<DataType> = inputs.iter().map(|f| f.dtype().clone()).collect(); | ||
let expected_dtype = get_expected_dtype(&dtypes, &kwargs)?; | ||
|
||
// This conversion is yuck, there is probably a standard way to go from &[Column] to &[Series] | ||
let series: Vec<Series> = inputs | ||
.iter() | ||
.map(|col| col.clone().take_materialized_series()) | ||
.collect(); | ||
|
||
// Convert dtype to native numeric type and invoke array_numeric | ||
let res_series = with_match_physical_numeric_polars_type!(expected_dtype, |$T| { | ||
array_numeric::<$T>(&series[..], &expected_dtype) | ||
})?; | ||
|
||
Ok(res_series.into_column()) | ||
} | ||
|
||
// Combine numeric series into an array | ||
fn array_numeric<T: PolarsNumericType>( | ||
inputs: &[Series], | ||
dtype: &DataType, | ||
) -> PolarsResult<Series> { | ||
let rows = inputs[0].len(); | ||
let cols = inputs.len(); | ||
let capacity = cols * rows; | ||
|
||
let mut values: Vec<T::Native> = vec![T::Native::default(); capacity]; | ||
|
||
// Support for casting | ||
// Cast fields to the target dtype as needed | ||
let mut casts = HashMap::new(); | ||
for j in 0..cols { | ||
if inputs[j].dtype() != dtype { | ||
let cast_input = inputs[j].cast(dtype)?; | ||
casts.insert(j, cast_input); | ||
} | ||
} | ||
|
||
let mut cols_ca = Vec::new(); | ||
for j in 0..cols { | ||
if inputs[j].dtype() != dtype { | ||
cols_ca.push(casts.get(&j).expect("expect conversion").unpack::<T>()?); | ||
} else { | ||
cols_ca.push(inputs[j].unpack::<T>()?); | ||
} | ||
} | ||
|
||
for i in 0..rows { | ||
for j in 0..cols { | ||
values[i * cols + j] = unsafe { cols_ca[j].value_unchecked(i) }; | ||
} | ||
} | ||
|
||
let validity = if cols_ca.iter().any(|col| col.has_nulls()) { | ||
let mut validity = MutableBitmap::from_len_zeroed(capacity); | ||
for (j, col) in cols_ca.iter().enumerate() { | ||
let mut row_offset = 0; | ||
for chunk in col.chunks() { | ||
if let Some(chunk_validity) = chunk.validity() { | ||
for set_bit in chunk_validity.true_idx_iter() { | ||
validity.set(cols * (row_offset + set_bit) + j, true); | ||
} | ||
} else { | ||
for chunk_row in 0..chunk.len() { | ||
validity.set(cols * (row_offset + chunk_row) + j, true); | ||
} | ||
} | ||
row_offset += chunk.len(); | ||
} | ||
} | ||
Some(validity.into()) | ||
} else { | ||
None | ||
}; | ||
|
||
let values_array = PrimitiveArray::from_vec(values).with_validity(validity); | ||
let dtype = DataType::Array(Box::new(dtype.clone()), cols); | ||
let arrow_dtype = dtype.to_arrow(CompatLevel::newest()); | ||
let array = FixedSizeListArray::try_new( | ||
arrow_dtype.clone(), | ||
values_array.len(), | ||
Box::new(values_array), | ||
None, | ||
)?; | ||
Ok(unsafe { | ||
Series::_try_from_arrow_unchecked("Array".into(), vec![Box::new(array)], &arrow_dtype)? | ||
}) | ||
} | ||
|
||
pub(super) fn max(s: &Column) -> PolarsResult<Column> { | ||
Ok(s.array()?.array_max().into()) | ||
} | ||
|
@@ -249,3 +403,102 @@ pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> { | |
|
||
ca.array_shift(n.as_materialized_series()).map(Column::from) | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use polars_core::datatypes::Field; | ||
use polars_core::frame::DataFrame; | ||
use polars_core::prelude::{Column, Series}; | ||
|
||
use super::*; | ||
|
||
#[test] | ||
fn test_array_f64() { | ||
println!("\ntest_array_f64"); | ||
let f1 = Series::new("f1".into(), &[1.0, 2.0]); | ||
let f2 = Series::new("f2".into(), &[3.0, 4.0]); | ||
|
||
let mut cols: Vec<Column> = Vec::new(); | ||
cols.push(Column::Series(f1)); | ||
cols.push(Column::Series(f2)); | ||
|
||
let array_df = DataFrame::new(cols.clone()).unwrap(); | ||
println!("input df\n{}\n", &array_df); | ||
|
||
let mut fields: Vec<Field> = Vec::new(); | ||
for col in &cols { | ||
let f: Field = (col.field().to_mut()).clone(); | ||
fields.push(f); | ||
} | ||
let kwargs = crate::dsl::function_expr::array::ArrayKwargs { | ||
dtype_expr: "{\"DtypeColumn\":[\"Float64\"]}".to_string(), | ||
}; | ||
let expected_result = | ||
crate::dsl::function_expr::array::array_output_type(&fields, kwargs.clone()).unwrap(); | ||
println!("expected result\n{:?}\n", &expected_result); | ||
|
||
let new_arr = | ||
crate::dsl::function_expr::array::array_internal(array_df.get_columns(), kwargs); | ||
println!("actual result\n{:?}", &new_arr); | ||
|
||
assert!(new_arr.is_ok()); | ||
assert_eq!(new_arr.unwrap().dtype(), expected_result.dtype()); | ||
} | ||
|
||
fn i32_series() -> (Vec<Column>, Vec<Field>, DataFrame) { | ||
let f1 = Series::new("f1".into(), &[1, 2]); | ||
let f2 = Series::new("f2".into(), &[3, 4]); | ||
|
||
let mut cols: Vec<Column> = Vec::new(); | ||
cols.push(Column::Series(f1)); | ||
cols.push(Column::Series(f2)); | ||
|
||
let array_df = DataFrame::new(cols.clone()).unwrap(); | ||
println!("input df\n{}\n", &array_df); | ||
|
||
let mut fields: Vec<Field> = Vec::new(); | ||
for col in &cols { | ||
let f: Field = (col.field().to_mut()).clone(); | ||
fields.push(f); | ||
} | ||
(cols, fields, array_df) | ||
} | ||
|
||
#[test] | ||
fn test_array_i32() { | ||
println!("\ntest_array_i32"); | ||
let (_cols, fields, array_df) = i32_series(); | ||
let kwargs = crate::dsl::function_expr::array::ArrayKwargs { | ||
dtype_expr: "{\"DtypeColumn\":[\"Int32\"]}".to_string(), | ||
}; | ||
let expected_result = | ||
crate::dsl::function_expr::array::array_output_type(&fields, kwargs.clone()).unwrap(); | ||
println!("expected result\n{:?}\n", &expected_result); | ||
|
||
let new_arr = | ||
crate::dsl::function_expr::array::array_internal(array_df.get_columns(), kwargs); | ||
println!("actual result\n{:?}", &new_arr); | ||
|
||
assert!(new_arr.is_ok()); | ||
assert_eq!(new_arr.unwrap().dtype(), expected_result.dtype()); | ||
} | ||
|
||
#[test] | ||
fn test_array_i32_converted() { | ||
println!("\ntest_array_i32_converted"); | ||
let (_cols, fields, array_df) = i32_series(); | ||
let kwargs = crate::dsl::function_expr::array::ArrayKwargs { | ||
dtype_expr: "{\"DtypeColumn\":[\"Float64\"]}".to_string(), | ||
}; | ||
let expected_result = | ||
crate::dsl::function_expr::array::array_output_type(&fields, kwargs.clone()).unwrap(); | ||
println!("expected result\n{:?}\n", &expected_result); | ||
|
||
let new_arr = | ||
crate::dsl::function_expr::array::array_internal(array_df.get_columns(), kwargs); | ||
println!("actual result\n{:?}", &new_arr); | ||
|
||
assert!(new_arr.is_ok()); | ||
assert_eq!(new_arr.unwrap().dtype(), expected_result.dtype()); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI. This is temporary developer scaffolding so I can run and debug via
cargo test
directly. See e.g. https://stackoverflow.com/questions/78204333/how-to-run-rust-library-unit-tests-with-maturinThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be removed. This shouldn't be commited.