diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index 0f5689ff9990..cee9cbaf84df 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -331,94 +331,70 @@ fn take_bytes( let data_len = indices.len(); let bytes_offset = (data_len + 1) * std::mem::size_of::(); - let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset); + let mut offsets = MutableBuffer::new(bytes_offset); + offsets.push(T::Offset::default()); - let offsets = offsets_buffer.typed_data_mut(); let mut values = MutableBuffer::new(0); - let mut length_so_far = T::Offset::from_usize(0).unwrap(); - offsets[0] = length_so_far; let nulls; if array.null_count() == 0 && indices.null_count() == 0 { - for (i, offset) in offsets.iter_mut().skip(1).enumerate() { - let index = indices.value(i).to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - - let s = array.value(index); - - let s: &[u8] = s.as_ref(); - length_so_far += T::Offset::from_usize(s.len()).unwrap(); + offsets.extend(indices.values().iter().map(|index| { + let s: &[u8] = array.value(index.as_usize()).as_ref(); values.extend_from_slice(s); - *offset = length_so_far; - } + T::Offset::usize_as(values.len()) + })); nulls = None } else if indices.null_count() == 0 { let num_bytes = bit_util::ceil(data_len, 8); let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); let null_slice = null_buf.as_slice_mut(); - - for (i, offset) in offsets.iter_mut().skip(1).enumerate() { - let index = indices.value(i).to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - + offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { + let index = index.as_usize(); if array.is_valid(index) { let s: &[u8] = array.value(index).as_ref(); - - length_so_far += T::Offset::from_usize(s.len()).unwrap(); values.extend_from_slice(s.as_ref()); } else { bit_util::unset_bit(null_slice, i); } - *offset = length_so_far; - } + T::Offset::usize_as(values.len()) + })); nulls = Some(null_buf.into()); } else if array.null_count() == 0 { - for (i, offset) in offsets.iter_mut().skip(1).enumerate() { + offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { if indices.is_valid(i) { - let index = indices.value(i).to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - - let s: &[u8] = array.value(index).as_ref(); - - length_so_far += T::Offset::from_usize(s.len()).unwrap(); + let s: &[u8] = array.value(index.as_usize()).as_ref(); values.extend_from_slice(s); } - *offset = length_so_far; - } + T::Offset::usize_as(values.len()) + })); nulls = indices.nulls().map(|b| b.inner().sliced()); } else { let num_bytes = bit_util::ceil(data_len, 8); let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); let null_slice = null_buf.as_slice_mut(); - - for (i, offset) in offsets.iter_mut().skip(1).enumerate() { - let index = indices.value(i).to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - - if array.is_valid(index) && indices.is_valid(i) { + offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { + // check index is valid before using index. The value in + // NULL index slots may not be within bounds of array + let index = index.as_usize(); + if indices.is_valid(i) && array.is_valid(index) { let s: &[u8] = array.value(index).as_ref(); - - length_so_far += T::Offset::from_usize(s.len()).unwrap(); values.extend_from_slice(s); } else { // set null bit bit_util::unset_bit(null_slice, i); } - *offset = length_so_far; - } - + T::Offset::usize_as(values.len()) + })); nulls = Some(null_buf.into()) } + T::Offset::from_usize(values.len()).expect("offset overflow"); + let array_data = ArrayData::builder(T::DATA_TYPE) .len(data_len) - .add_buffer(offsets_buffer.into()) + .add_buffer(offsets.into()) .add_buffer(values.into()) .null_bit_buffer(nulls); @@ -1937,6 +1913,7 @@ mod tests { #[test] fn test_take_null_indices() { + // Build indices with values that are out of bounds, but masked by null mask let indices = Int32Array::new( vec![1, 2, 400, 400].into(), Some(NullBuffer::from(vec![true, true, false, false])), @@ -1949,4 +1926,16 @@ mod tests { .collect::>(); assert_eq!(&values, &[Some(23), Some(4), None, None]) } + + #[test] + fn test_take_bytes_null_indices() { + let indices = Int32Array::new( + vec![0, 1, 400, 400].into(), + Some(NullBuffer::from_iter(vec![true, true, false, false])), + ); + let values = StringArray::from(vec![Some("foo"), None]); + let r = take(&values, &indices, None).unwrap(); + let values = r.as_string::().iter().collect::>(); + assert_eq!(&values, &[Some("foo"), None, None, None]) + } }