diff --git a/crates/polars-arrow/src/bitmap/immutable.rs b/crates/polars-arrow/src/bitmap/immutable.rs index a896651467d2..5b8d510dfe6c 100644 --- a/crates/polars-arrow/src/bitmap/immutable.rs +++ b/crates/polars-arrow/src/bitmap/immutable.rs @@ -5,7 +5,7 @@ use std::sync::LazyLock; use either::Either; use polars_error::{polars_bail, PolarsResult}; -use super::utils::{count_zeros, fmt, get_bit_unchecked, BitChunk, BitChunks, BitmapIter}; +use super::utils::{self, count_zeros, fmt, get_bit_unchecked, BitChunk, BitChunks, BitmapIter}; use super::{chunk_iter_to_vec, intersects_with, num_intersections_with, IntoIter, MutableBitmap}; use crate::array::Splitable; use crate::bitmap::aligned::AlignedBitmapSlice; @@ -532,6 +532,104 @@ impl Bitmap { pub fn num_edges(&self) -> usize { super::bitmap_ops::num_edges(self) } + + /// Returns the number of zero bits from the start before a one bit is seen + pub fn leading_zeros(&self) -> usize { + utils::leading_zeros(&self.storage, self.offset, self.length) + } + /// Returns the number of one bits from the start before a zero bit is seen + pub fn leading_ones(&self) -> usize { + utils::leading_ones(&self.storage, self.offset, self.length) + } + /// Returns the number of zero bits from the back before a one bit is seen + pub fn trailing_zeros(&self) -> usize { + utils::trailing_zeros(&self.storage, self.offset, self.length) + } + /// Returns the number of one bits from the back before a zero bit is seen + pub fn trailing_ones(&self) -> usize { + utils::trailing_ones(&self.storage, self.offset, self.length) + } + + /// Take all `0` bits at the start of the [`Bitmap`] before a `1` is seen, returning how many + /// bits were taken + pub fn take_leading_zeros(&mut self) -> usize { + if self + .lazy_unset_bits() + .is_some_and(|unset_bits| unset_bits == self.length) + { + let leading_zeros = self.length; + self.offset += self.length; + self.length = 0; + *self.unset_bit_count_cache.get_mut() = 0; + return leading_zeros; + } + + let 
leading_zeros = self.leading_zeros(); + self.offset += leading_zeros; + self.length -= leading_zeros; + if has_cached_unset_bit_count(*self.unset_bit_count_cache.get_mut()) { + *self.unset_bit_count_cache.get_mut() -= leading_zeros as u64; + } + leading_zeros + } + /// Take all `1` bits at the start of the [`Bitmap`] before a `0` is seen, returning how many + /// bits were taken + pub fn take_leading_ones(&mut self) -> usize { + if self + .lazy_unset_bits() + .is_some_and(|unset_bits| unset_bits == 0) + { + let leading_ones = self.length; + self.offset += self.length; + self.length = 0; + *self.unset_bit_count_cache.get_mut() = 0; + return leading_ones; + } + + let leading_ones = self.leading_ones(); + self.offset += leading_ones; + self.length -= leading_ones; + // @NOTE: the unset_bit_count_cache remains unchanged + leading_ones + } + /// Take all `0` bits at the back of the [`Bitmap`] before a `1` is seen, returning how many + /// bits were taken + pub fn take_trailing_zeros(&mut self) -> usize { + if self + .lazy_unset_bits() + .is_some_and(|unset_bits| unset_bits == self.length) + { + let trailing_zeros = self.length; + self.length = 0; + *self.unset_bit_count_cache.get_mut() = 0; + return trailing_zeros; + } + + let trailing_zeros = self.trailing_zeros(); + self.length -= trailing_zeros; + if has_cached_unset_bit_count(*self.unset_bit_count_cache.get_mut()) { + *self.unset_bit_count_cache.get_mut() -= trailing_zeros as u64; + } + trailing_zeros + } + /// Take all `1` bits at the back of the [`Bitmap`] before a `0` is seen, returning how many + /// bits were taken + pub fn take_trailing_ones(&mut self) -> usize { + if self + .lazy_unset_bits() + .is_some_and(|unset_bits| unset_bits == 0) + { + let trailing_ones = self.length; + self.length = 0; + *self.unset_bit_count_cache.get_mut() = 0; + return trailing_ones; + } + + let trailing_ones = self.trailing_ones(); + self.length -= trailing_ones; + // @NOTE: the unset_bit_count_cache remains unchanged + 
trailing_ones + } } impl> From

for Bitmap { diff --git a/crates/polars-arrow/src/bitmap/utils/mod.rs b/crates/polars-arrow/src/bitmap/utils/mod.rs index 7979fbdca448..47729f94afa6 100644 --- a/crates/polars-arrow/src/bitmap/utils/mod.rs +++ b/crates/polars-arrow/src/bitmap/utils/mod.rs @@ -85,3 +85,213 @@ pub fn count_zeros(slice: &[u8], offset: usize, len: usize) -> usize { let ones_in_suffix = aligned.suffix().count_ones() as usize; len - ones_in_prefix - ones_in_bulk - ones_in_suffix } + +/// Returns the number of zero bits before seeing a one bit in the slice offset by `offset` and +/// a length of `len`. +/// +/// # Panics +/// This function panics iff `offset + len > 8 * slice.len()`. +pub fn leading_zeros(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + } + + assert!(8 * slice.len() >= offset + len); + + let aligned = AlignedBitmapSlice::<u64>::new(slice, offset, len); + let leading_zeros_in_prefix = + (aligned.prefix().trailing_zeros() as usize).min(aligned.prefix_bitlen()); + if leading_zeros_in_prefix < aligned.prefix_bitlen() { + return leading_zeros_in_prefix; + } + if let Some(full_zero_bulk_words) = aligned.bulk_iter().position(|w| w != 0) { + return aligned.prefix_bitlen() + + full_zero_bulk_words * 64 + + aligned.bulk()[full_zero_bulk_words].trailing_zeros() as usize; + } + + aligned.prefix_bitlen() + + aligned.bulk_bitlen() + + (aligned.suffix().trailing_zeros() as usize).min(aligned.suffix_bitlen()) +} + +/// Returns the number of one bits before seeing a zero bit in the slice offset by `offset` and +/// a length of `len`. +/// +/// # Panics +/// This function panics iff `offset + len > 8 * slice.len()`. 
+pub fn leading_ones(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + } + + assert!(8 * slice.len() >= offset + len); + + let aligned = AlignedBitmapSlice::<u64>::new(slice, offset, len); + let leading_ones_in_prefix = aligned.prefix().trailing_ones() as usize; + if leading_ones_in_prefix < aligned.prefix_bitlen() { + return leading_ones_in_prefix; + } + if let Some(full_one_bulk_words) = aligned.bulk_iter().position(|w| w != u64::MAX) { + return aligned.prefix_bitlen() + + full_one_bulk_words * 64 + + aligned.bulk()[full_one_bulk_words].trailing_ones() as usize; + } + + aligned.prefix_bitlen() + aligned.bulk_bitlen() + aligned.suffix().trailing_ones() as usize +} + +/// Returns the number of zero bits from the back before seeing a one bit in the slice offset by `offset` and +/// a length of `len`. +/// +/// # Panics +/// This function panics iff `offset + len > 8 * slice.len()`. +pub fn trailing_zeros(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + } + + assert!(8 * slice.len() >= offset + len); + + let aligned = AlignedBitmapSlice::<u64>::new(slice, offset, len); + let trailing_zeros_in_suffix = ((aligned.suffix() << ((64 - aligned.suffix_bitlen()) % 64)) + .leading_zeros() as usize) + .min(aligned.suffix_bitlen()); + if trailing_zeros_in_suffix < aligned.suffix_bitlen() { + return trailing_zeros_in_suffix; + } + if let Some(full_zero_bulk_words) = aligned.bulk_iter().rev().position(|w| w != 0) { + return aligned.suffix_bitlen() + + full_zero_bulk_words * 64 + + aligned.bulk()[aligned.bulk().len() - full_zero_bulk_words - 1].leading_zeros() + as usize; + } + + let trailing_zeros_in_prefix = ((aligned.prefix() << ((64 - aligned.prefix_bitlen()) % 64)) + .leading_zeros() as usize) + .min(aligned.prefix_bitlen()); + aligned.suffix_bitlen() + aligned.bulk_bitlen() + trailing_zeros_in_prefix +} + +/// Returns the number of one bits from the back before seeing a zero bit in the slice offset by `offset` and +/// a length of 
`len`. +/// +/// # Panics +/// This function panics iff `offset + len > 8 * slice.len()`. +pub fn trailing_ones(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + } + + assert!(8 * slice.len() >= offset + len); + + let aligned = AlignedBitmapSlice::<u64>::new(slice, offset, len); + let trailing_ones_in_suffix = + (aligned.suffix() << ((64 - aligned.suffix_bitlen()) % 64)).leading_ones() as usize; + if trailing_ones_in_suffix < aligned.suffix_bitlen() { + return trailing_ones_in_suffix; + } + if let Some(full_one_bulk_words) = aligned.bulk_iter().rev().position(|w| w != u64::MAX) { + return aligned.suffix_bitlen() + + full_one_bulk_words * 64 + + aligned.bulk()[aligned.bulk().len() - full_one_bulk_words - 1].leading_ones() + as usize; + } + + let trailing_ones_in_prefix = + (aligned.prefix() << ((64 - aligned.prefix_bitlen()) % 64)).leading_ones() as usize; + aligned.suffix_bitlen() + aligned.bulk_bitlen() + trailing_ones_in_prefix +} + +#[cfg(test)] +mod tests { + use rand::Rng; + + use super::*; + use crate::bitmap::Bitmap; + + #[test] + fn leading_trailing() { + macro_rules! 
testcase { + ($slice:expr, $offset:expr, $length:expr => lz=$lz:expr,lo=$lo:expr,tz=$tz:expr,to=$to:expr) => { + assert_eq!( + leading_zeros($slice, $offset, $length), + $lz, + "leading_zeros" + ); + assert_eq!(leading_ones($slice, $offset, $length), $lo, "leading_ones"); + assert_eq!( + trailing_zeros($slice, $offset, $length), + $tz, + "trailing_zeros" + ); + assert_eq!( + trailing_ones($slice, $offset, $length), + $to, + "trailing_ones" + ); + }; + } + + testcase!(&[], 0, 0 => lz=0,lo=0,tz=0,to=0); + testcase!(&[0], 0, 1 => lz=1,lo=0,tz=1,to=0); + testcase!(&[1], 0, 1 => lz=0,lo=1,tz=0,to=1); + + testcase!(&[0b010], 0, 3 => lz=1,lo=0,tz=1,to=0); + testcase!(&[0b101], 0, 3 => lz=0,lo=1,tz=0,to=1); + testcase!(&[0b100], 0, 3 => lz=2,lo=0,tz=0,to=1); + testcase!(&[0b110], 0, 3 => lz=1,lo=0,tz=0,to=2); + testcase!(&[0b001], 0, 3 => lz=0,lo=1,tz=2,to=0); + testcase!(&[0b011], 0, 3 => lz=0,lo=2,tz=1,to=0); + + testcase!(&[0b010], 1, 2 => lz=0,lo=1,tz=1,to=0); + testcase!(&[0b101], 1, 2 => lz=1,lo=0,tz=0,to=1); + testcase!(&[0b100], 1, 2 => lz=1,lo=0,tz=0,to=1); + testcase!(&[0b110], 1, 2 => lz=0,lo=2,tz=0,to=2); + testcase!(&[0b001], 1, 2 => lz=2,lo=0,tz=2,to=0); + testcase!(&[0b011], 1, 2 => lz=0,lo=1,tz=1,to=0); + } + + #[ignore = "Fuzz test. 
Too slow"] + #[test] + fn leading_trailing_fuzz() { + let mut rng = rand::thread_rng(); + + const SIZE: usize = 1000; + const REPEATS: usize = 10_000; + + let mut v = Vec::::with_capacity(SIZE); + + for _ in 0..REPEATS { + v.clear(); + let offset = rng.gen_range(0..SIZE); + let length = rng.gen_range(0..SIZE - offset); + let extra_padding = rng.gen_range(0..64); + + let mut num_remaining = usize::min(SIZE, offset + length + extra_padding); + while num_remaining > 0 { + let chunk_size = rng.gen_range(1..=num_remaining); + v.extend( + rng.clone() + .sample_iter(rand::distributions::Slice::new(&[false, true]).unwrap()) + .take(chunk_size), + ); + num_remaining -= chunk_size; + } + + let v_slice = &v[offset..offset + length]; + let lz = v_slice.iter().take_while(|&v| !*v).count(); + let lo = v_slice.iter().take_while(|&v| *v).count(); + let tz = v_slice.iter().rev().take_while(|&v| !*v).count(); + let to = v_slice.iter().rev().take_while(|&v| *v).count(); + + let bm = Bitmap::from_iter(v.iter().copied()); + let (slice, _, _) = bm.as_slice(); + + assert_eq!(leading_zeros(slice, offset, length), lz); + assert_eq!(leading_ones(slice, offset, length), lo); + assert_eq!(trailing_zeros(slice, offset, length), tz); + assert_eq!(trailing_ones(slice, offset, length), to); + } + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs index 9e843f673072..909e68f5af37 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs @@ -88,30 +88,27 @@ pub fn decode_aligned_bytes_dispatch( match (filter, page_validity) { (None, None) => decode_required(values, target), - (None, Some(page_validity)) => decode_optional(values, &page_validity, target), + (None, Some(page_validity)) => decode_optional(values, page_validity, target), - (Some(Filter::Range(rng)), None) => decode_required( - unsafe { 
values.slice_unchecked(rng.start, rng.end) }, - target, - ), + (Some(Filter::Range(rng)), None) => { + decode_required(values.slice(rng.start, rng.len()), target) + }, (Some(Filter::Range(rng)), Some(mut page_validity)) => { - let prevalidity; - (prevalidity, page_validity) = page_validity.split_at(rng.start); - - (page_validity, _) = page_validity.split_at(rng.len()); - - let values_start = prevalidity.set_bits(); + let mut values = values; + if rng.start > 0 { + let prevalidity; + (prevalidity, page_validity) = page_validity.split_at(rng.start); + page_validity.slice(0, rng.len()); + let values_start = prevalidity.set_bits(); + values = values.slice(values_start, values.len() - values_start); + } - decode_optional( - unsafe { values.slice_unchecked(values_start, values.len()) }, - &page_validity, - target, - ) + decode_optional(values, page_validity, target) }, - (Some(Filter::Mask(filter)), None) => decode_masked_required(values, &filter, target), + (Some(Filter::Mask(filter)), None) => decode_masked_required(values, filter, target), (Some(Filter::Mask(filter)), Some(page_validity)) => { - decode_masked_optional(values, &page_validity, &filter, target) + decode_masked_optional(values, page_validity, filter, target) }, }?; @@ -150,22 +147,29 @@ fn decode_required( #[inline(never)] fn decode_optional( values: ArrayChunks<'_, B>, - validity: &Bitmap, + mut validity: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { - let num_values = validity.set_bits(); + target.reserve(validity.len()); + // Handle the leading and trailing zeros. This may allow dispatch to a faster kernel or + // possibly removes iterations from the lower kernel. + let num_leading_nulls = validity.take_leading_zeros(); + target.resize(target.len() + num_leading_nulls, B::zeroed()); + let num_trailing_nulls = validity.take_trailing_zeros(); + + // Dispatch to a faster kernel if possible. 
+ let num_values = validity.set_bits(); if num_values == validity.len() { - return decode_required(values.truncate(validity.len()), target); + decode_required(values.truncate(validity.len()), target)?; + target.resize(target.len() + num_trailing_nulls, B::zeroed()); + return Ok(()); } - let mut limit = validity.len(); - assert!(num_values <= values.len()); let start_length = target.len(); - let end_length = target.len() + limit; - target.reserve(limit); + let end_length = target.len() + validity.len(); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; let mut validity_iter = validity.fast_iter_u56(); @@ -206,20 +210,22 @@ fn decode_optional( } }; + let mut num_remaining = validity.len(); for v in validity_iter.by_ref() { - if limit < 56 { - iter(v, limit); + if num_remaining < 56 { + iter(v, num_remaining); } else { iter(v, 56); } - limit -= 56; + num_remaining -= 56; } let (v, vl) = validity_iter.remainder(); - iter(v, vl.min(limit)); + iter(v, vl.min(num_remaining)); unsafe { target.set_len(end_length) }; + target.resize(target.len() + num_trailing_nulls, B::zeroed()); Ok(()) } @@ -227,11 +233,17 @@ fn decode_optional( #[inline(never)] fn decode_masked_required( values: ArrayChunks<'_, B>, - mask: &Bitmap, + mut mask: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { - let num_rows = mask.set_bits(); + // Remove leading or trailing filtered values. This may allow dispatch to a faster kernel or + // may remove iterations from the slower kernel below. + let num_leading_filtered = mask.take_leading_zeros(); + mask.take_trailing_zeros(); + let values = values.slice(num_leading_filtered, mask.len()); + // Dispatch to a faster kernel if possible. 
+ let num_rows = mask.set_bits(); if num_rows == mask.len() { return decode_required(values.truncate(num_rows), target); } @@ -287,9 +299,7 @@ fn decode_masked_required( break; } } - let (f, fl) = mask_iter.remainder(); - iter(f, fl); unsafe { target.set_len(start_length + num_rows) }; @@ -300,17 +310,27 @@ fn decode_masked_required( #[inline(never)] fn decode_masked_optional( values: ArrayChunks<'_, B>, - validity: &Bitmap, - mask: &Bitmap, + mut validity: Bitmap, + mut mask: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { + assert_eq!(validity.len(), mask.len()); + + let num_leading_filtered = mask.take_leading_zeros(); + mask.take_trailing_zeros(); + let leading_validity; + (leading_validity, validity) = validity.split_at(num_leading_filtered); + validity.slice(0, mask.len()); + let num_rows = mask.set_bits(); let num_values = validity.set_bits(); + let values = values.slice(leading_validity.set_bits(), num_values); + + // Dispatch to a faster kernel if possible. if num_rows == mask.len() { return decode_optional(values, validity, target); } - if num_values == validity.len() { return decode_masked_required(values, mask, target); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs index d201db813628..33932223a04c 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs @@ -37,12 +37,17 @@ impl<'a, B: AlignedBytes> ArrayChunks<'a, B> { } } - pub unsafe fn slice_unchecked(&self, start: usize, end: usize) -> ArrayChunks<'a, B> { - debug_assert!(start <= self.bytes.len()); - debug_assert!(end <= self.bytes.len()); + pub fn slice(&self, start: usize, length: usize) -> ArrayChunks<'a, B> { + assert!(start <= self.bytes.len()); + assert!(start + length <= self.bytes.len()); + unsafe { self.slice_unchecked(start, length) } + } + pub unsafe fn 
slice_unchecked(&self, start: usize, length: usize) -> ArrayChunks<'a, B> { + debug_assert!(start <= self.bytes.len()); + debug_assert!(start + length <= self.bytes.len()); Self { - bytes: unsafe { self.bytes.get_unchecked(start..end) }, + bytes: unsafe { self.bytes.get_unchecked(start..start + length) }, } } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs index fdb6a7dab1d2..879bfd85615e 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs @@ -105,24 +105,24 @@ pub fn decode_dict_dispatch( values.limit_to(rng.end); decode_required_dict(values, dict, target) }, - (None, Some(page_validity)) => decode_optional_dict(values, dict, &page_validity, target), + (None, Some(page_validity)) => decode_optional_dict(values, dict, page_validity, target), (Some(Filter::Range(rng)), Some(page_validity)) if rng.start == 0 => { - decode_optional_dict(values, dict, &page_validity, target) + decode_optional_dict(values, dict, page_validity, target) }, (Some(Filter::Mask(filter)), None) => { - decode_masked_required_dict(values, dict, &filter, target) + decode_masked_required_dict(values, dict, filter, target) }, (Some(Filter::Mask(filter)), Some(page_validity)) => { - decode_masked_optional_dict(values, dict, &filter, &page_validity, target) + decode_masked_optional_dict(values, dict, filter, page_validity, target) }, (Some(Filter::Range(rng)), None) => { - decode_masked_required_dict(values, dict, &filter_from_range(rng.clone()), target) + decode_masked_required_dict(values, dict, filter_from_range(rng.clone()), target) }, (Some(Filter::Range(rng)), Some(page_validity)) => decode_masked_optional_dict( values, dict, - &filter_from_range(rng.clone()), - &page_validity, + filter_from_range(rng.clone()), + page_validity, target, ), }?; @@ -234,7 +234,7 @@ pub fn 
decode_required_dict( pub fn decode_optional_dict( mut values: HybridRleDecoder<'_>, dict: &[B], - validity: &Bitmap, + validity: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { let num_valid_values = validity.set_bits(); @@ -257,7 +257,7 @@ pub fn decode_optional_dict( let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; values.limit_to(num_valid_values); - let mut validity = BitMask::from_bitmap(validity); + let mut validity = BitMask::from_bitmap(&validity); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; @@ -413,8 +413,8 @@ pub fn decode_optional_dict( pub fn decode_masked_optional_dict( mut values: HybridRleDecoder<'_>, dict: &[B], - filter: &Bitmap, - validity: &Bitmap, + filter: Bitmap, + validity: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { let num_rows = filter.set_bits(); @@ -441,8 +441,8 @@ pub fn decode_masked_optional_dict( target.reserve(num_rows); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - let mut filter = BitMask::from_bitmap(filter); - let mut validity = BitMask::from_bitmap(validity); + let mut filter = BitMask::from_bitmap(&filter); + let mut validity = BitMask::from_bitmap(&validity); values.limit_to(num_valid_values); let mut values_buffer = [0u32; 128]; @@ -636,7 +636,7 @@ pub fn decode_masked_optional_dict( pub fn decode_masked_required_dict( mut values: HybridRleDecoder<'_>, dict: &[B], - filter: &Bitmap, + filter: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { let num_rows = filter.set_bits(); @@ -656,7 +656,7 @@ pub fn decode_masked_required_dict( target.reserve(num_rows); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - let mut filter = BitMask::from_bitmap(filter); + let mut filter = BitMask::from_bitmap(&filter); values.limit_to(filter.len()); let mut values_buffer = [0u32; 128];