Skip to content

Commit

Permalink
fix offset trait
Browse files (browse the repository at this point in the history)
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed Dec 9, 2023
1 parent ef98cf6 commit 29ac936
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 33 deletions.
54 changes: 27 additions & 27 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2071,13 +2071,13 @@ impl ScalarValue {
///
/// assert_eq!(scalar_vec, expected);
/// ```
pub fn convert_list_array_to_scalar_vec<O: OffsetSizeTrait>(array: &dyn Array) -> Result<Vec<Vec<Self>>> {


if as_list_array(array).is_ok() {
Self::convert_list_array_to_scalar_vec_internal(array)
pub fn convert_list_array_to_scalar_vec<O: OffsetSizeTrait>(
array: &dyn Array,
) -> Result<Vec<Vec<Self>>> {
if array.as_list_opt::<O>().is_some() {
Self::convert_list_array_to_scalar_vec_internal::<O>(array)
} else {
_internal_err!("Expected ListArray but found: {array:?}")
_internal_err!("Expected GenericListArray but found: {array:?}")
}
}

Expand All @@ -2086,18 +2086,18 @@ impl ScalarValue {
) -> Result<Vec<Vec<Self>>> {
let mut scalars_vec = Vec::with_capacity(array.len());

let list_arr = as_generic_list_array::<O>(array);

if let Ok(list_arr) = as_list_array(array) {
if let Some(list_arr) = array.as_list_opt::<O>() {
for index in 0..list_arr.len() {
let scalars = match list_arr.is_null(index) {
true => Vec::new(),
false => {
let nested_array = list_arr.value(index);
Self::convert_list_array_to_scalar_vec_internal(&nested_array)?
.into_iter()
.flatten()
.collect()
Self::convert_list_array_to_scalar_vec_internal::<O>(
&nested_array,
)?
.into_iter()
.flatten()
.collect()
}
};
scalars_vec.push(scalars);
Expand All @@ -2106,6 +2106,7 @@ impl ScalarValue {
let scalars = ScalarValue::convert_non_list_array_to_scalars(array)?;
scalars_vec.push(scalars);
}

Ok(scalars_vec)
}

Expand Down Expand Up @@ -2134,16 +2135,16 @@ impl ScalarValue {
/// assert_eq!(scalar_vec, expected);
/// ```
pub fn convert_non_list_array_to_scalars(array: &dyn Array) -> Result<Vec<Self>> {
if as_list_array(array).is_ok() {
_internal_err!("Expected non-ListArray but found: {array:?}")
} else {
let mut scalars = Vec::with_capacity(array.len());
for index in 0..array.len() {
let scalar = ScalarValue::try_from_array(array, index)?;
scalars.push(scalar);
}
Ok(scalars)
if array.as_list_opt::<i32>().is_some() || array.as_list_opt::<i64>().is_some() {
return _internal_err!("Expected non ListArray but found: {array:?}");
}

let mut scalars = Vec::with_capacity(array.len());
for index in 0..array.len() {
let scalar = ScalarValue::try_from_array(array, index)?;
scalars.push(scalar);
}
Ok(scalars)
}

// TODO: Support more types after other ScalarValue is wrapped with ArrayRef
Expand Down Expand Up @@ -2194,7 +2195,7 @@ impl ScalarValue {
typed_cast!(array, index, LargeStringArray, LargeUtf8)?
}
DataType::List(_) => {
let list_array = as_list_array(array)?;
let list_array = as_list_array(array);
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(array_into_list_array(nested_array));
Expand Down Expand Up @@ -3163,7 +3164,6 @@ impl ScalarType<i64> for TimestampNanosecondType {
}

#[cfg(test)]
#[cfg(feature = "parquet")]
mod tests {
use super::*;

Expand Down Expand Up @@ -3202,7 +3202,7 @@ mod tests {
let l12 = arrays_into_list_array([l1, l2]).unwrap();
let arr = Arc::new(l12) as ArrayRef;

let actual = ScalarValue::convert_list_array_to_scalar_vec(&arr).unwrap();
let actual = ScalarValue::convert_list_array_to_scalar_vec::<i32>(&arr).unwrap();
let expected = vec![
vec![
ScalarValue::Int32(Some(1)),
Expand Down Expand Up @@ -3232,7 +3232,7 @@ mod tests {
let actual_arr = sv
.to_array_of_size(2)
.expect("Failed to convert to array of size");
let actual_list_arr = as_list_array(&actual_arr).unwrap();
let actual_list_arr = as_list_array(&actual_arr);

let arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), None, Some(2)]),
Expand Down Expand Up @@ -3272,7 +3272,7 @@ mod tests {
];

let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8);
let result = as_list_array(&array).unwrap();
let result = as_list_array(&array);

let expected = array_into_list_array(Arc::new(StringArray::from(vec![
"rust",
Expand Down
8 changes: 6 additions & 2 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,17 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
let column = actual[0].column(0);
assert_eq!(column.len(), 1);

let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec(&column)?;
let mut scalars = scalar_vec[0].clone();
// 1 row
let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec::<i32>(&column)?;

// workaround lack of Ord of ScalarValue
let cmp = |a: &ScalarValue, b: &ScalarValue| {
a.partial_cmp(b).expect("Can compare ScalarValues")
};

let mut scalars = scalar_vec.first().unwrap().to_owned();
scalars.sort_by(cmp);

assert_eq!(
scalars,
vec![
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl Accumulator for DistinctArrayAggAccumulator {
let array = &values[0];
match array.data_type() {
DataType::List(_) => {
let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec(array)?;
let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec::<i32>(array)?;
for scalars in scalar_vec {
self.values.extend(scalars);
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
partition_ordering_values.push(self.ordering_values.clone());

let array_agg_res =
ScalarValue::convert_list_array_to_scalar_vec(array_agg_values)?;
ScalarValue::convert_list_array_to_scalar_vec::<i32>(array_agg_values)?;

for v in array_agg_res.into_iter() {
partition_values.push(v);
}

let orderings = ScalarValue::convert_list_array_to_scalar_vec(agg_orderings)?;
let orderings = ScalarValue::convert_list_array_to_scalar_vec::<i32>(agg_orderings)?;
// Ordering requirement expression values for each entry in the ARRAY_AGG list
let other_ordering_values = self.convert_array_agg_to_orderings(orderings)?;
for v in other_ordering_values.into_iter() {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/aggregate/count_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ impl Accumulator for DistinctCountAccumulator {
return Ok(());
}
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec(&states[0])?;
let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec::<i32>(&states[0])?;
for scalars in scalar_vec.into_iter() {
self.values.extend(scalars)
}
Expand Down

0 comments on commit 29ac936

Please sign in to comment.