diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 34b781079b13..a3b4d6bb5a9c 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -221,6 +221,14 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Timestamp(_, _), Timestamp(_, _) | Date32 | Date64) => true, // date64 to timestamp might not make sense, (Int64, Duration(_)) => true, + (Duration(_), Int64) => true, + (Interval(from_type), Int64) => { + match from_type{ + IntervalUnit::YearMonth => true, + IntervalUnit::DayTime => true, + IntervalUnit::MonthDayNano => false, // Native type is i128 + } + } (_, _) => false, } } @@ -358,7 +366,6 @@ macro_rules! cast_decimal_to_float { /// * To or from `StructArray` /// * List to primitive /// * Utf8 to boolean -/// * Interval and duration pub fn cast_with_options( array: &ArrayRef, to_type: &DataType, @@ -1111,6 +1118,17 @@ pub fn cast_with_options( } } } + (Duration(_), Int64) => cast_array_data::(array, to_type.clone()), + (Interval(from_type), Int64) => match from_type { + IntervalUnit::YearMonth => { + cast_numeric_arrays::(array) + } + IntervalUnit::DayTime => cast_array_data::(array, to_type.clone()), + IntervalUnit::MonthDayNano => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, (_, _) => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, @@ -2765,6 +2783,44 @@ mod tests { assert!(c.is_null(2)); } + #[test] + fn test_cast_duration_to_i64() { + let base = vec![5, 6, 7, 8, 100000000]; + + let duration_arrays = vec![ + Arc::new(DurationNanosecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationMicrosecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationMillisecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationSecondArray::from(base.clone())) as ArrayRef, + ]; + + for arr in duration_arrays { + assert!(can_cast_types(arr.data_type(), &DataType::Int64)); + let result = cast(&arr, &DataType::Int64).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(base.as_slice(), result.values()); + } + } + + #[test] + fn test_cast_interval_to_i64() { + let base = vec![5, 6, 7, 8]; + + let interval_arrays = vec![ + Arc::new(IntervalDayTimeArray::from(base.clone())) as ArrayRef, + Arc::new(IntervalYearMonthArray::from( + base.iter().map(|x| *x as i32).collect::>(), + )) as ArrayRef, + ]; + + for arr in interval_arrays { + assert!(can_cast_types(arr.data_type(), &DataType::Int64)); + let result = cast(&arr, &DataType::Int64).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(base.as_slice(), result.values()); + } + } + #[test] fn test_cast_to_strings() { let a = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;