diff --git a/crates/polars-ops/src/series/ops/clip.rs b/crates/polars-ops/src/series/ops/clip.rs index 917b2a24654d..37fc8b3030e8 100644 --- a/crates/polars-ops/src/series/ops/clip.rs +++ b/crates/polars-ops/src/series/ops/clip.rs @@ -1,4 +1,3 @@ -use num_traits::{clamp, clamp_max, clamp_min}; use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise}; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; @@ -25,7 +24,7 @@ pub fn clip(s: &Series, min: &Series, max: &Series) -> PolarsResult { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref(); let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref(); - let out = clip_helper(ca, min, max).into_series(); + let out = clip_helper_both_bounds(ca, min, max).into_series(); if original_type.is_logical() { out.cast(original_type) } else { @@ -54,7 +53,7 @@ pub fn clip_max(s: &Series, max: &Series) -> PolarsResult { with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref(); - let out = clip_min_max_helper(ca, max, clamp_max).into_series(); + let out = clip_helper_single_bound(ca, max, num_traits::clamp_max).into_series(); if original_type.is_logical() { out.cast(original_type) } else { @@ -83,7 +82,7 @@ pub fn clip_min(s: &Series, min: &Series) -> PolarsResult { with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref(); - let out = clip_min_max_helper(ca, min, clamp_min).into_series(); + let out = clip_helper_single_bound(ca, min, num_traits::clamp_min).into_series(); if original_type.is_logical() { out.cast(original_type) } else { @@ -95,7 +94,7 @@ pub fn clip_min(s: &Series, min: &Series) -> PolarsResult { } } -fn clip_helper( +fn clip_helper_both_bounds( ca: &ChunkedArray, min: &ChunkedArray, max: &ChunkedArray, @@ -106,35 +105,24 @@ where { match (min.len(), max.len()) { (1, 1) => match (min.get(0), max.get(0)) { - (Some(min), Some(max)) => { - ca.apply_generic(|s| s.map(|s| num_traits::clamp(s, min, max))) - }, - _ => ChunkedArray::::full_null(ca.name(), ca.len()), + (Some(min), Some(max)) => clip_unary(ca, |v| num_traits::clamp(v, min, max)), + (Some(min), None) => clip_unary(ca, |v| num_traits::clamp_min(v, min)), + (None, Some(max)) => clip_unary(ca, |v| num_traits::clamp_max(v, max)), + (None, None) => ca.clone(), }, (1, _) => match min.get(0) { - Some(min) => binary_elementwise(ca, max, |opt_s, opt_max| match (opt_s, opt_max) { - (Some(s), Some(max)) => Some(clamp(s, min, max)), - _ => None, - }), - _ => ChunkedArray::::full_null(ca.name(), ca.len()), + Some(min) => clip_binary(ca, max, |v, b| num_traits::clamp(v, min, b)), + None => clip_binary(ca, max, num_traits::clamp_max), }, (_, 1) => match max.get(0) { - Some(max) => binary_elementwise(ca, min, |opt_s, opt_min| match (opt_s, opt_min) { - (Some(s), Some(min)) => Some(clamp(s, min, max)), - _ => None, - }), - _ => ChunkedArray::::full_null(ca.name(), ca.len()), + Some(max) => clip_binary(ca, min, |v, b| num_traits::clamp(v, b, max)), + None => clip_binary(ca, min, num_traits::clamp_min), }, - _ => ternary_elementwise(ca, min, max, |opt_s, opt_min, opt_max| { - match (opt_s, opt_min, opt_max) { - (Some(s), Some(min), Some(max)) => Some(clamp(s, min, max)), - _ => None, - } - }), + _ => clip_ternary(ca, min, max), } } -fn clip_min_max_helper( +fn clip_helper_single_bound( ca: &ChunkedArray, bound: &ChunkedArray, op: F, @@ -146,12 +134,50 @@ where { match bound.len() { 1 => match bound.get(0) { - Some(bound) => ca.apply_generic(|s| s.map(|s| op(s, bound))), - _ => ChunkedArray::::full_null(ca.name(), ca.len()), + Some(bound) => clip_unary(ca, |v| op(v, bound)), + None => ca.clone(), }, - _ => binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) { - (Some(s), Some(bound)) => Some(op(s, bound)), - _ => None, - }), + _ => clip_binary(ca, bound, op), } } + +fn clip_unary(ca: &ChunkedArray, op: F) -> ChunkedArray +where + T: PolarsNumericType, + F: Fn(T::Native) -> T::Native + Copy, +{ + ca.apply_generic(|v| v.map(op)) +} + +fn clip_binary(ca: &ChunkedArray, bound: &ChunkedArray, op: F) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, + F: Fn(T::Native, T::Native) -> T::Native, +{ + binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) { + (Some(s), Some(bound)) => Some(op(s, bound)), + (Some(s), None) => Some(s), + (None, _) => None, + }) +} + +fn clip_ternary( + ca: &ChunkedArray, + min: &ChunkedArray, + max: &ChunkedArray, +) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, +{ + ternary_elementwise(ca, min, max, |opt_v, opt_min, opt_max| { + match (opt_v, opt_min, opt_max) { + (Some(v), Some(min), Some(max)) => Some(num_traits::clamp(v, min, max)), + (Some(v), Some(min), None) => Some(num_traits::clamp_min(v, min)), + (Some(v), None, Some(max)) => Some(num_traits::clamp_max(v, max)), + (Some(v), None, None) => Some(v), + (None, _, _) => None, + } + }) +}