Skip to content

Commit

Permalink
Fix take_bytes Null and Overflow Handling (#4576) (#4579)
Browse files Browse the repository at this point in the history
* Cleanup take_bytes

* Use extend

* Tweak

* Review feedback
  • Loading branch information
tustvold authored Jul 29, 2023
1 parent 8c85d34 commit 18385e5
Showing 1 changed file with 37 additions and 48 deletions.
85 changes: 37 additions & 48 deletions arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,94 +331,70 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
let data_len = indices.len();

let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset);
let mut offsets = MutableBuffer::new(bytes_offset);
offsets.push(T::Offset::default());

let offsets = offsets_buffer.typed_data_mut();
let mut values = MutableBuffer::new(0);
let mut length_so_far = T::Offset::from_usize(0).unwrap();
offsets[0] = length_so_far;

let nulls;
if array.null_count() == 0 && indices.null_count() == 0 {
for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
let index = indices.value(i).to_usize().ok_or_else(|| {
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;

let s = array.value(index);

let s: &[u8] = s.as_ref();
length_so_far += T::Offset::from_usize(s.len()).unwrap();
offsets.extend(indices.values().iter().map(|index| {
let s: &[u8] = array.value(index.as_usize()).as_ref();
values.extend_from_slice(s);
*offset = length_so_far;
}
T::Offset::usize_as(values.len())
}));
nulls = None
} else if indices.null_count() == 0 {
let num_bytes = bit_util::ceil(data_len, 8);

let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = null_buf.as_slice_mut();

for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
let index = indices.value(i).to_usize().ok_or_else(|| {
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;

offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
let index = index.as_usize();
if array.is_valid(index) {
let s: &[u8] = array.value(index).as_ref();

length_so_far += T::Offset::from_usize(s.len()).unwrap();
values.extend_from_slice(s.as_ref());
} else {
bit_util::unset_bit(null_slice, i);
}
*offset = length_so_far;
}
T::Offset::usize_as(values.len())
}));
nulls = Some(null_buf.into());
} else if array.null_count() == 0 {
for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
if indices.is_valid(i) {
let index = indices.value(i).to_usize().ok_or_else(|| {
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;

let s: &[u8] = array.value(index).as_ref();

length_so_far += T::Offset::from_usize(s.len()).unwrap();
let s: &[u8] = array.value(index.as_usize()).as_ref();
values.extend_from_slice(s);
}
*offset = length_so_far;
}
T::Offset::usize_as(values.len())
}));
nulls = indices.nulls().map(|b| b.inner().sliced());
} else {
let num_bytes = bit_util::ceil(data_len, 8);

let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = null_buf.as_slice_mut();

for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
let index = indices.value(i).to_usize().ok_or_else(|| {
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;

if array.is_valid(index) && indices.is_valid(i) {
offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
// check index is valid before using index. The value in
// NULL index slots may not be within bounds of array
let index = index.as_usize();
if indices.is_valid(i) && array.is_valid(index) {
let s: &[u8] = array.value(index).as_ref();

length_so_far += T::Offset::from_usize(s.len()).unwrap();
values.extend_from_slice(s);
} else {
// set null bit
bit_util::unset_bit(null_slice, i);
}
*offset = length_so_far;
}

T::Offset::usize_as(values.len())
}));
nulls = Some(null_buf.into())
}

T::Offset::from_usize(values.len()).expect("offset overflow");

let array_data = ArrayData::builder(T::DATA_TYPE)
.len(data_len)
.add_buffer(offsets_buffer.into())
.add_buffer(offsets.into())
.add_buffer(values.into())
.null_bit_buffer(nulls);

Expand Down Expand Up @@ -1937,6 +1913,7 @@ mod tests {

#[test]
fn test_take_null_indices() {
// Build indices with values that are out of bounds, but masked by null mask
let indices = Int32Array::new(
vec![1, 2, 400, 400].into(),
Some(NullBuffer::from(vec![true, true, false, false])),
Expand All @@ -1949,4 +1926,16 @@ mod tests {
.collect::<Vec<_>>();
assert_eq!(&values, &[Some(23), Some(4), None, None])
}

#[test]
fn test_take_bytes_null_indices() {
    // The last two index slots hold 400, which is far outside the
    // two-element values array below, but both slots are masked as null.
    let index_validity = NullBuffer::from_iter(vec![true, true, false, false]);
    let indices = Int32Array::new(vec![0, 1, 400, 400].into(), Some(index_validity));

    // Slot 0 is a valid string, slot 1 is a null value.
    let source = StringArray::from(vec![Some("foo"), None]);

    let taken = take(&source, &indices, None).unwrap();
    let actual = taken.as_string::<i32>().iter().collect::<Vec<_>>();

    // A valid index onto a null value, and every null index, must yield None.
    let expected = [Some("foo"), None, None, None];
    assert_eq!(&actual, &expected)
}
}

0 comments on commit 18385e5

Please sign in to comment.