diff --git a/crates/polars-arrow/src/bitmap/bitmask.rs b/crates/polars-arrow/src/bitmap/bitmask.rs index 2e4e45195266..637ae06a8e6b 100644 --- a/crates/polars-arrow/src/bitmap/bitmask.rs +++ b/crates/polars-arrow/src/bitmap/bitmask.rs @@ -4,7 +4,7 @@ use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount}; use polars_utils::slice::load_padded_le_u64; use super::iterator::FastU56BitmapIter; -use super::utils::{count_zeros, BitmapIter}; +use super::utils::{count_zeros, fmt, BitmapIter}; use crate::bitmap::Bitmap; /// Returns the nth set bit in w, if n+1 bits are set. The indexing is @@ -79,6 +79,15 @@ pub struct BitMask<'a> { len: usize, } +impl std::fmt::Debug for BitMask<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { bytes, offset, len } = self; + let offset_num_bytes = offset / 8; + let offset_in_byte = offset % 8; + fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f) + } +} + impl<'a> BitMask<'a> { pub fn from_bitmap(bitmap: &'a Bitmap) -> Self { let (bytes, offset, len) = bitmap.as_slice(); @@ -92,6 +101,13 @@ impl<'a> BitMask<'a> { self.len } + #[inline] + pub fn advance_by(&mut self, idx: usize) { + assert!(idx <= self.len); + self.offset += idx; + self.len -= idx; + } + #[inline] pub fn split_at(&self, idx: usize) -> (Self, Self) { assert!(idx <= self.len); diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs index f9b553bea397..a73e5e784ab5 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs @@ -8,11 +8,11 @@ use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::buffer::Buffer; use arrow::datatypes::{ArrowDataType, PhysicalType}; -use super::utils::dict_encoded::{append_validity, constrain_page_validity}; +use super::dictionary_encoded::{append_validity, constrain_page_validity}; use super::utils::{ dict_indices_decoder, filter_from_range, freeze_validity, unspecialized_decode, }; -use super::Filter; +use super::{dictionary_encoded, Filter}; use crate::parquet::encoding::{delta_byte_array, delta_length_byte_array, hybrid_rle, Encoding}; use crate::parquet::error::{ParquetError, ParquetResult}; use crate::parquet::page::{split_buffer, DataPage, DictPage}; @@ -521,7 +521,7 @@ impl utils::Decoder for BinViewDecoder { let start_length = decoded.0.views().len(); - utils::dict_encoded::decode_dict( + dictionary_encoded::decode_dict( indexes.clone(), dict, state.is_optional, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs index 11019b3ab614..56dda77cf37a 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs @@ -5,7 +5,7 @@ use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::datatypes::ArrowDataType; use polars_compute::filter::filter_boolean_kernel; -use super::utils::dict_encoded::{append_validity, constrain_page_validity}; +use super::dictionary_encoded::{append_validity, constrain_page_validity}; use super::utils::{ self, decode_hybrid_rle_into_bitmap, filter_from_range, freeze_validity, Decoder, ExactSize, }; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs new file mode 100644 index 000000000000..91ad273ee04b --- /dev/null +++ 
b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs
@@ -0,0 +1,233 @@
+use arrow::bitmap::bitmask::BitMask;
+use arrow::bitmap::{Bitmap, MutableBitmap};
+use arrow::types::{AlignedBytes, NativeType};
+use polars_compute::filter::filter_boolean_kernel;
+
+use super::ParquetError;
+use crate::parquet::encoding::hybrid_rle::HybridRleDecoder;
+use crate::parquet::error::ParquetResult;
+use crate::read::Filter;
+
+mod optional;
+mod optional_masked_dense;
+mod required;
+mod required_masked_dense;
+
+pub fn decode_dict<T: NativeType>(
+    values: HybridRleDecoder<'_>,
+    dict: &[T],
+    is_optional: bool,
+    page_validity: Option<&Bitmap>,
+    filter: Option<Filter>,
+    validity: &mut MutableBitmap,
+    target: &mut Vec<T>,
+) -> ParquetResult<()> {
+    decode_dict_dispatch(
+        values,
+        bytemuck::cast_slice(dict),
+        is_optional,
+        page_validity,
+        filter,
+        validity,
+        <T::AlignedBytes as AlignedBytes>::cast_vec_ref_mut(target),
+    )
+}
+
+#[inline(never)]
+pub fn decode_dict_dispatch<B: AlignedBytes>(
+    mut values: HybridRleDecoder<'_>,
+    dict: &[B],
+    is_optional: bool,
+    page_validity: Option<&Bitmap>,
+    filter: Option<Filter>,
+    validity: &mut MutableBitmap,
+    target: &mut Vec<B>,
+) -> ParquetResult<()> {
+    if cfg!(debug_assertions) && is_optional {
+        assert_eq!(target.len(), validity.len());
+    }
+
+    if is_optional {
+        append_validity(page_validity, filter.as_ref(), validity, values.len());
+    }
+
+    let page_validity = constrain_page_validity(values.len(), page_validity, filter.as_ref());
+
+    match (filter, page_validity) {
+        (None, None) => required::decode(values, dict, target, 0),
+        (Some(Filter::Range(rng)), None) => {
+            values.limit_to(rng.end);
+            required::decode(values, dict, target, rng.start)
+        },
+        (None, Some(page_validity)) => optional::decode(values, dict, page_validity, target, 0),
+        (Some(Filter::Range(rng)), Some(page_validity)) => {
+            optional::decode(values, dict, page_validity, target, rng.start)
+        },
+        (Some(Filter::Mask(filter)), None) => {
+            required_masked_dense::decode(values, dict, filter, target)
+        },
+        (Some(Filter::Mask(filter)), Some(page_validity)) => {
+            optional_masked_dense::decode(values, dict, filter, page_validity, target)
+        },
+    }?;
+
+    if cfg!(debug_assertions) && is_optional {
+        assert_eq!(target.len(), validity.len());
+    }
+
+    Ok(())
+}
+
+pub(crate) fn append_validity(
+    page_validity: Option<&Bitmap>,
+    filter: Option<&Filter>,
+    validity: &mut MutableBitmap,
+    values_len: usize,
+) {
+    match (page_validity, filter) {
+        (None, None) => validity.extend_constant(values_len, true),
+        (None, Some(f)) => validity.extend_constant(f.num_rows(), true),
+        (Some(page_validity), None) => validity.extend_from_bitmap(page_validity),
+        (Some(page_validity), Some(Filter::Range(rng))) => {
+            let page_validity = page_validity.clone();
+            validity.extend_from_bitmap(&page_validity.clone().sliced(rng.start, rng.len()))
+        },
+        (Some(page_validity), Some(Filter::Mask(mask))) => {
+            validity.extend_from_bitmap(&filter_boolean_kernel(page_validity, mask))
+        },
+    }
+}
+
+pub(crate) fn constrain_page_validity(
+    values_len: usize,
+    page_validity: Option<&Bitmap>,
+    filter: Option<&Filter>,
+) -> Option<Bitmap> {
+    let num_unfiltered_rows = match (filter.as_ref(), page_validity) {
+        (None, None) => values_len,
+        (None, Some(pv)) => {
+            debug_assert!(pv.len() >= values_len);
+            pv.len()
+        },
+        (Some(f), v) => {
+            if cfg!(debug_assertions) {
+                if let Some(v) = v {
+                    assert!(v.len() >= f.max_offset());
+                }
+            }
+
+            f.max_offset()
+        },
+    };
+
+    page_validity.map(|pv| {
+        if pv.len() > num_unfiltered_rows {
+            pv.clone().sliced(0, num_unfiltered_rows)
+        } else {
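+            // The validity already fits within the unfiltered rows; reuse it as-is.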
+            pv.clone()
+        }
+    })
+}
+
+#[cold]
+fn oob_dict_idx() -> ParquetError {
+    ParquetError::oos("Dictionary Index is out-of-bounds")
+}
+
+#[cold]
+fn no_more_bitpacked_values() -> ParquetError {
+    ParquetError::oos("Bitpacked Hybrid-RLE ran out before all values were served")
+}
+
+#[inline(always)]
+fn verify_dict_indices(indices: &[u32; 32], dict_size: usize) -> ParquetResult<()> {
+    let mut is_valid = true;
+    for &idx in indices {
+        is_valid &= (idx as usize) < dict_size;
+    }
+
+    if is_valid {
+        return Ok(());
+    }
+
+    Err(oob_dict_idx())
+}
+
+#[inline(always)]
+fn verify_dict_indices_slice(indices: &[u32], dict_size: usize) -> ParquetResult<()> {
+    let mut is_valid = true;
+    for &idx in indices {
+        is_valid &= (idx as usize) < dict_size;
+    }
+
+    if is_valid {
+        return Ok(());
+    }
+
+    Err(oob_dict_idx())
+}
+
+/// Skip whole chunks in a [`HybridRleDecoder`] as long as the skipped chunks together contain no
+/// more than `num_values_to_skip` values.
+#[inline(always)]
+fn required_skip_whole_chunks(
+    values: &mut HybridRleDecoder<'_>,
+    num_values_to_skip: &mut usize,
+) -> ParquetResult<()> {
+    if *num_values_to_skip == 0 {
+        return Ok(());
+    }
+
+    loop {
+        let mut values_clone = values.clone();
+        let Some(chunk_len) = values_clone.next_chunk_length()? else {
+            break;
+        };
+        if *num_values_to_skip < chunk_len {
+            break;
+        }
+        *values = values_clone;
+        *num_values_to_skip -= chunk_len;
+    }
+
+    Ok(())
+}
+
+/// Skip whole chunks in a [`HybridRleDecoder`] as long as the skipped chunks together contain no
+/// more than `num_values_to_skip` values. The validity mask is advanced and `num_rows_to_skip` is
+/// reduced to match the rows covered by the skipped values.
+#[inline(always)]
+fn optional_skip_whole_chunks(
+    values: &mut HybridRleDecoder<'_>,
+    validity: &mut BitMask<'_>,
+    num_rows_to_skip: &mut usize,
+    num_values_to_skip: &mut usize,
+) -> ParquetResult<()> {
+    if *num_values_to_skip == 0 {
+        return Ok(());
+    }
+
+    let mut total_num_skipped_values = 0;
+
+    loop {
+        let mut values_clone = values.clone();
+        let Some(chunk_len) = values_clone.next_chunk_length()? else {
+            break;
+        };
+        if *num_values_to_skip < chunk_len {
+            break;
+        }
+        *values = values_clone;
+        *num_values_to_skip -= chunk_len;
+        total_num_skipped_values += chunk_len;
+    }
+
+    if total_num_skipped_values > 0 {
+        let offset = validity
+            .nth_set_bit_idx(total_num_skipped_values - 1, 0)
+            .map_or(validity.len(), |v| v + 1);
+        *num_rows_to_skip -= offset;
+        validity.advance_by(offset);
+    }
+
+    Ok(())
+}
diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs
new file mode 100644
index 000000000000..f0546027549e
--- /dev/null
+++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs
@@ -0,0 +1,200 @@
+use arrow::bitmap::bitmask::BitMask;
+use arrow::bitmap::Bitmap;
+use arrow::types::AlignedBytes;
+
+use super::{
+    no_more_bitpacked_values, oob_dict_idx, optional_skip_whole_chunks, verify_dict_indices,
+};
+use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder};
+use crate::parquet::error::ParquetResult;
+
+/// Decoding kernel for optional dictionary-encoded data.
+#[inline(never)]
+pub fn decode<B: AlignedBytes>(
+    mut values: HybridRleDecoder<'_>,
+    dict: &[B],
+    mut validity: Bitmap,
+    target: &mut Vec<B>,
+    mut num_rows_to_skip: usize,
+) -> ParquetResult<()> {
+    debug_assert!(num_rows_to_skip <= validity.len());
+
+    let num_rows = validity.len() - num_rows_to_skip;
+    let end_length = target.len() + num_rows;
+
+    target.reserve(num_rows);
+
+    // Remove any leading and trailing nulls. This has two benefits:
+    // 1. It increases the chance of dispatching to the faster kernel (e.g. for sorted data)
+    // 2. It reduces the number of iterations in the main loop, replacing them with `memset`s
+    let leading_nulls = validity.take_leading_zeros();
+    let trailing_nulls = validity.take_trailing_zeros();
+
+    // Special case: all values are skipped, just add the trailing nulls.
+    if num_rows_to_skip >= leading_nulls + validity.len() {
+        target.resize(end_length, B::zeroed());
+        return Ok(());
+    }
+
+    values.limit_to(validity.set_bits());
+
+    // Add the leading nulls
+    if num_rows_to_skip < leading_nulls {
+        target.resize(target.len() + leading_nulls - num_rows_to_skip, B::zeroed());
+        num_rows_to_skip = 0;
+    } else {
+        num_rows_to_skip -= leading_nulls;
+    }
+
+    if validity.set_bits() == validity.len() {
+        // Dispatch to the required kernel if all rows are valid anyway.
+        super::required::decode(values, dict, target, num_rows_to_skip)?;
+    } else {
+        if dict.is_empty() {
+            return Err(oob_dict_idx());
+        }
+
+        let mut num_values_to_skip = 0;
+        if num_rows_to_skip > 0 {
+            num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits();
+        }
+
+        let mut validity = BitMask::from_bitmap(&validity);
+        let mut values_buffer = [0u32; 128];
+        let values_buffer = &mut values_buffer;
+
+        // Skip over any whole HybridRleChunks
+        optional_skip_whole_chunks(
+            &mut values,
+            &mut validity,
+            &mut num_rows_to_skip,
+            &mut num_values_to_skip,
+        )?;
+
+        while let Some(chunk) = values.next_chunk()? {
+            debug_assert!(num_values_to_skip < chunk.len() || chunk.len() == 0);
+
+            match chunk {
+                HybridRleChunk::Rle(value, size) => {
+                    if size == 0 {
+                        continue;
+                    }
+
+                    // We know that we have `size` times `value` that we can append, but there
+                    // might be nulls in between those values.
+                    //
+                    // 1. See how many `num_rows = valid + invalid` values `size` would entail.
+                    //    This is done with `nth_set_bit_idx` on the validity mask.
+                    // 2. Fill `num_rows` values into the target buffer.
+                    // 3. Advance the validity mask by `num_rows` values.
+
+                    let Some(&value) = dict.get(value as usize) else {
+                        return Err(oob_dict_idx());
+                    };
+
+                    let num_chunk_rows =
+                        validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len());
+                    validity.advance_by(num_chunk_rows);
+
+                    target.resize(target.len() + num_chunk_rows - num_rows_to_skip, value);
+                },
+                HybridRleChunk::Bitpacked(mut decoder) => {
+                    let num_rows_for_decoder = validity
+                        .nth_set_bit_idx(decoder.len(), 0)
+                        .unwrap_or(validity.len());
+
+                    let mut chunked = decoder.chunked();
+
+                    let mut buffer_part_idx = 0;
+                    let mut values_offset = 0;
+                    let mut num_buffered: usize = 0;
+
+                    let mut decoder_validity;
+                    (decoder_validity, validity) = validity.split_at(num_rows_for_decoder);
+
+                    // Skip over any remaining values.
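+                    // Whole-block skips happen below at a granularity of 32 values; the
+                    // remainder of the first partial block is decoded and discarded.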
+                    if num_rows_to_skip > 0 {
+                        decoder_validity.advance_by(num_rows_to_skip);
+
+                        chunked.decoder.skip_chunks(num_values_to_skip / 32);
+                        num_values_to_skip %= 32;
+
+                        if num_values_to_skip > 0 {
+                            let buffer_part = <&mut [u32; 32]>::try_from(
+                                &mut values_buffer[buffer_part_idx * 32..][..32],
+                            )
+                            .unwrap();
+                            let Some(num_added) = chunked.next_into(buffer_part) else {
+                                return Err(no_more_bitpacked_values());
+                            };
+
+                            debug_assert!(num_values_to_skip <= num_added);
+                            verify_dict_indices(buffer_part, dict.len())?;
+
+                            values_offset += num_values_to_skip;
+                            num_buffered += num_added - num_values_to_skip;
+                            buffer_part_idx += 1;
+                        }
+                    }
+
+                    let mut iter = |v: u64, n: usize| {
+                        while num_buffered < v.count_ones() as usize {
+                            buffer_part_idx %= 4;
+
+                            let buffer_part = <&mut [u32; 32]>::try_from(
+                                &mut values_buffer[buffer_part_idx * 32..][..32],
+                            )
+                            .unwrap();
+                            let Some(num_added) = chunked.next_into(buffer_part) else {
+                                return Err(no_more_bitpacked_values());
+                            };
+
+                            verify_dict_indices(buffer_part, dict.len())?;
+
+                            num_buffered += num_added;
+
+                            buffer_part_idx += 1;
+                        }
+
+                        let mut num_read = 0;
+
+                        target.extend((0..n).map(|i| {
+                            let idx = values_buffer[(values_offset + num_read) % 128];
+                            num_read += ((v >> i) & 1) as usize;
+
+                            // SAFETY:
+                            // 1. `values_buffer` starts out as only zeros, which we know is in the
+                            //    dictionary following the original `dict.is_empty` check.
+                            // 2. Each time we write to `values_buffer`, it is followed by a
+                            //    `verify_dict_indices`.
+                            *unsafe { dict.get_unchecked(idx as usize) }
+                        }));
+
+                        values_offset += num_read;
+                        values_offset %= 128;
+                        num_buffered -= num_read;
+
+                        ParquetResult::Ok(())
+                    };
+
+                    let mut v_iter = decoder_validity.fast_iter_u56();
+                    for v in v_iter.by_ref() {
+                        iter(v, 56)?;
+                    }
+
+                    let (v, vl) = v_iter.remainder();
+                    iter(v, vl)?;
+                },
+            }
+
+            num_rows_to_skip = 0;
+            num_values_to_skip = 0;
+        }
+    }
+
+    // Add back the trailing nulls
+    debug_assert_eq!(target.len(), end_length - trailing_nulls);
+    target.resize(end_length, B::zeroed());
+
+    Ok(())
+}
diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs
new file mode 100644
index 000000000000..c779c3b61b0e
--- /dev/null
+++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs
@@ -0,0 +1,219 @@
+use arrow::bitmap::bitmask::BitMask;
+use arrow::bitmap::Bitmap;
+use arrow::types::AlignedBytes;
+
+use super::{oob_dict_idx, verify_dict_indices};
+use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder};
+use crate::parquet::error::ParquetResult;
+
+#[inline(never)]
+pub fn decode<B: AlignedBytes>(
+    mut values: HybridRleDecoder<'_>,
+    dict: &[B],
+    mut filter: Bitmap,
+    mut validity: Bitmap,
+    target: &mut Vec<B>,
+) -> ParquetResult<()> {
+    // @NOTE: We don't skip leading filtered values here, because that is a bit more involved
+    // than in the other kernels. We could probably do it after trying to dispatch to the faster
+    // kernels, but that would lose quite a bit of the benefit.
+    filter.take_trailing_zeros();
+    validity = validity.sliced(0, filter.len());
+
+    let num_rows = filter.set_bits();
+    let num_valid_values = validity.set_bits();
+
+    assert_eq!(filter.len(), validity.len());
+    assert!(num_valid_values <= values.len());
+
+    // Dispatch to the non-filter kernel if all rows are needed anyway.
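+    // (A fully-set mask would make the filter bookkeeping below pure overhead.)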
+    if num_rows == filter.len() {
+        return super::optional::decode(values, dict, validity, target, 0);
+    }
+
+    // Dispatch to the required kernel if all rows are valid anyway.
+    if num_valid_values == validity.len() {
+        return super::required_masked_dense::decode(values, dict, filter, target);
+    }
+
+    if dict.is_empty() && num_valid_values > 0 {
+        return Err(oob_dict_idx());
+    }
+
+    target.reserve(num_rows);
+
+    let end_length = target.len() + num_rows;
+
+    let mut filter = BitMask::from_bitmap(&filter);
+    let mut validity = BitMask::from_bitmap(&validity);
+
+    values.limit_to(num_valid_values);
+    let mut values_buffer = [0u32; 128];
+    let values_buffer = &mut values_buffer;
+
+    let mut num_rows_left = num_rows;
+
+    for chunk in values.into_chunk_iter() {
+        // Early stop if we have no more rows to load.
+        if num_rows_left == 0 {
+            break;
+        }
+
+        match chunk? {
+            HybridRleChunk::Rle(value, size) => {
+                if size == 0 {
+                    continue;
+                }
+
+                // We know that we have `size` times `value` that we can append, but there might
+                // be nulls in between those values.
+                //
+                // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is
+                //    done with `nth_set_bit_idx` on the validity mask.
+                // 2. Fill `num_rows` values into the target buffer.
+                // 3. Advance the validity mask by `num_rows` values.
+
+                let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len());
+
+                let current_filter;
+                (current_filter, filter) = filter.split_at(num_chunk_values);
+                validity.advance_by(num_chunk_values);
+
+                let num_chunk_rows = current_filter.set_bits();
+
+                let Some(&value) = dict.get(value as usize) else {
+                    return Err(oob_dict_idx());
+                };
+
+                target.resize(target.len() + num_chunk_rows, value);
+            },
+            HybridRleChunk::Bitpacked(mut decoder) => {
+                // For bitpacked we do the following:
+                // 1. See how many rows are encoded by this `decoder`.
+                // 2. Go through the filter and validity 56 bits at a time and:
+                //    0. If filter bits are 0, skip the chunk entirely.
+                //    1. Buffer enough values so that we can branchlessly decode with the filter
+                //       and validity.
+                //    2. Decode with filter and validity.
+                // 3. Decode remainder.
+
+                let size = decoder.len();
+                let mut chunked = decoder.chunked();
+
+                let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len());
+
+                let mut buffer_part_idx = 0;
+                let mut values_offset = 0;
+                let mut num_buffered: usize = 0;
+                let mut skip_values = 0;
+
+                let current_filter;
+                let current_validity;
+
+                (current_filter, filter) = unsafe { filter.split_at_unchecked(num_chunk_values) };
+                (current_validity, validity) =
+                    unsafe { validity.split_at_unchecked(num_chunk_values) };
+
+                let mut iter = |mut f: u64, mut v: u64| {
+                    // Skip this chunk if we don't need any values from it.
+                    if f == 0 {
+                        skip_values += v.count_ones() as usize;
+                        return ParquetResult::Ok(());
+                    }
+
+                    // Skip over already buffered items.
+                    let num_buffered_skipped = skip_values.min(num_buffered);
+                    values_offset += num_buffered_skipped;
+                    num_buffered -= num_buffered_skipped;
+                    skip_values -= num_buffered_skipped;
+
+                    // If we skipped plenty already, just skip decoding those chunks instead of
+                    // decoding them and throwing them away.
+                    chunked.decoder.skip_chunks(skip_values / 32);
+                    // The leftover values still have to be decoded, but they are skipped while
+                    // buffering below.
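+                    // (Bitpacked runs decode in blocks of 32 indices, hence the whole-block skip
+                    // above and the in-block remainder here.)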
+                    skip_values %= 32;
+
+                    while num_buffered < v.count_ones() as usize {
+                        let buffer_part = <&mut [u32; 32]>::try_from(
+                            &mut values_buffer[buffer_part_idx * 32..][..32],
+                        )
+                        .unwrap();
+                        let num_added = chunked.next_into(buffer_part).unwrap();
+
+                        verify_dict_indices(buffer_part, dict.len())?;
+
+                        let skip_chunk_values = skip_values.min(num_added);
+
+                        values_offset += skip_chunk_values;
+                        num_buffered += num_added - skip_chunk_values;
+                        skip_values -= skip_chunk_values;
+
+                        buffer_part_idx += 1;
+                        buffer_part_idx %= 4;
+                    }
+
+                    let mut num_read = 0;
+                    let mut num_written = 0;
+                    let target_ptr = unsafe { target.as_mut_ptr().add(target.len()) };
+
+                    while f != 0 {
+                        let offset = f.trailing_zeros();
+
+                        num_read += (v & (1u64 << offset).wrapping_sub(1)).count_ones() as usize;
+                        v >>= offset;
+
+                        let idx = values_buffer[(values_offset + num_read) % 128];
+                        // SAFETY:
+                        // 1. `values_buffer` starts out as only zeros, which we know is in the
+                        //    dictionary following the original `dict.is_empty` check.
+                        // 2. Each time we write to `values_buffer`, it is followed by a
+                        //    `verify_dict_indices`.
+                        let value = unsafe { dict.get_unchecked(idx as usize) };
+                        let value = *value;
+                        unsafe { target_ptr.add(num_written).write(value) };
+
+                        num_written += 1;
+                        num_read += (v & 1) as usize;
+
+                        f >>= offset + 1; // Clear least significant bit.
+                        v >>= 1;
+                    }
+
+                    num_read += v.count_ones() as usize;
+
+                    values_offset += num_read;
+                    values_offset %= 128;
+                    num_buffered -= num_read;
+                    unsafe {
+                        target.set_len(target.len() + num_written);
+                    }
+                    num_rows_left -= num_written;
+
+                    ParquetResult::Ok(())
+                };
+
+                let mut f_iter = current_filter.fast_iter_u56();
+                let mut v_iter = current_validity.fast_iter_u56();
+
+                for (f, v) in f_iter.by_ref().zip(v_iter.by_ref()) {
+                    iter(f, v)?;
+                }
+
+                let (f, fl) = f_iter.remainder();
+                let (v, vl) = v_iter.remainder();
+
+                assert_eq!(fl, vl);
+
+                iter(f, v)?;
+            },
+        }
+    }
+
+    if cfg!(debug_assertions) {
+        assert_eq!(validity.set_bits(), 0);
+    }
+
+    target.resize(end_length, B::zeroed());
+
+    Ok(())
+}
diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs
new file mode 100644
index 000000000000..a7053944d130
--- /dev/null
+++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs
@@ -0,0 +1,90 @@
+use arrow::types::AlignedBytes;
+
+use super::{
+    oob_dict_idx, required_skip_whole_chunks, verify_dict_indices, verify_dict_indices_slice,
+};
+use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder};
+use crate::parquet::error::ParquetResult;
+
+/// Decoding kernel for required dictionary-encoded data.
+#[inline(never)]
+pub fn decode<B: AlignedBytes>(
+    mut values: HybridRleDecoder<'_>,
+    dict: &[B],
+    target: &mut Vec<B>,
+    mut num_rows_to_skip: usize,
+) -> ParquetResult<()> {
+    debug_assert!(num_rows_to_skip <= values.len());
+
+    let num_rows = values.len() - num_rows_to_skip;
+    let end_length = target.len() + num_rows;
+
+    if num_rows == 0 {
+        return Ok(());
+    }
+
+    target.reserve(num_rows);
+
+    if dict.is_empty() {
+        return Err(oob_dict_idx());
+    }
+
+    // Skip over whole HybridRleChunks
+    required_skip_whole_chunks(&mut values, &mut num_rows_to_skip)?;
+
+    while let Some(chunk) = values.next_chunk()? {
+        debug_assert!(num_rows_to_skip < chunk.len() || chunk.len() == 0);
+
+        match chunk {
+            HybridRleChunk::Rle(value, size) => {
+                if size == 0 {
+                    continue;
+                }
+
+                let Some(&value) = dict.get(value as usize) else {
+                    return Err(oob_dict_idx());
+                };
+
+                target.resize(target.len() + size - num_rows_to_skip, value);
+            },
+            HybridRleChunk::Bitpacked(mut decoder) => {
+                if num_rows_to_skip > 0 {
+                    decoder.skip_chunks(num_rows_to_skip / 32);
+                    num_rows_to_skip %= 32;
+
+                    if let Some((chunk, chunk_size)) = decoder.chunked().next_inexact() {
+                        let chunk = &chunk[num_rows_to_skip..chunk_size];
+                        verify_dict_indices_slice(chunk, dict.len())?;
+                        target.extend(chunk.iter().map(|&idx| {
+                            // SAFETY: The dict indices were verified before.
+                            *unsafe { dict.get_unchecked(idx as usize) }
+                        }));
+                    }
+                }
+
+                let mut chunked = decoder.chunked();
+                for chunk in chunked.by_ref() {
+                    verify_dict_indices(&chunk, dict.len())?;
+                    target.extend(chunk.iter().map(|&idx| {
+                        // SAFETY: The dict indices were verified before.
+                        *unsafe { dict.get_unchecked(idx as usize) }
+                    }));
+                }
+
+                if let Some((chunk, chunk_size)) = chunked.remainder() {
+                    verify_dict_indices_slice(&chunk[..chunk_size], dict.len())?;
+                    target.extend(chunk[..chunk_size].iter().map(|&idx| {
+                        // SAFETY: The dict indices were verified before.
+                        *unsafe { dict.get_unchecked(idx as usize) }
+                    }));
+                }
+            },
+        }
+
+        num_rows_to_skip = 0;
+    }
+
+    debug_assert_eq!(target.len(), end_length);
+
+    Ok(())
+}
diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs
new file mode 100644
index 000000000000..04aed3dbfa1e
--- /dev/null
+++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs
@@ -0,0 +1,177 @@
+use arrow::bitmap::bitmask::BitMask;
+use arrow::bitmap::Bitmap;
+use arrow::types::AlignedBytes;
+
+use super::{oob_dict_idx, required_skip_whole_chunks, verify_dict_indices};
+use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder};
+use crate::parquet::error::ParquetResult;
+
+#[inline(never)]
+pub fn decode<B: AlignedBytes>(
+    mut values: HybridRleDecoder<'_>,
+    dict: &[B],
+    mut filter: Bitmap,
+    target: &mut Vec<B>,
+) -> ParquetResult<()> {
+    assert!(values.len() >= filter.len());
+
+    let mut num_rows_to_skip = filter.take_leading_zeros();
+    filter.take_trailing_zeros();
+
+    let num_rows = filter.set_bits();
+
+    values.limit_to(num_rows_to_skip + filter.len());
+
+    // Dispatch to the non-filter kernel if all rows are needed anyway.
+    if num_rows == filter.len() {
+        return super::required::decode(values, dict, target, num_rows_to_skip);
+    }
+
+    if dict.is_empty() && !filter.is_empty() {
+        return Err(oob_dict_idx());
+    }
+
+    target.reserve(num_rows);
+
+    let mut filter = BitMask::from_bitmap(&filter);
+
+    let mut values_buffer = [0u32; 128];
+    let values_buffer = &mut values_buffer;
+
+    // Skip over whole HybridRleChunks
+    required_skip_whole_chunks(&mut values, &mut num_rows_to_skip)?;
+
+    while let Some(chunk) = values.next_chunk()? {
+        debug_assert!(num_rows_to_skip < chunk.len() || chunk.len() == 0);
+
+        match chunk {
+            HybridRleChunk::Rle(value, size) => {
+                if size == 0 {
+                    continue;
+                }
+
+                // We know that we have `size` times `value` that we can append, but the filter
+                // may drop some of those rows.
+                //
+                // 1. Split off the filter bits that correspond to this run.
+                // 2. Count the set bits to see how many of its rows survive the filter.
+                // 3. Fill that many values into the target buffer.
+
+                let current_filter;
+
+                (current_filter, filter) = filter.split_at(size - num_rows_to_skip);
+                let num_chunk_rows = current_filter.set_bits();
+
+                if num_chunk_rows > 0 {
+                    let Some(&value) = dict.get(value as usize) else {
+                        return Err(oob_dict_idx());
+                    };
+
+                    target.resize(target.len() + num_chunk_rows, value);
+                }
+            },
+            HybridRleChunk::Bitpacked(mut decoder) => {
+                let size = decoder.len().min(filter.len());
+                let mut chunked = decoder.chunked();
+
+                let mut buffer_part_idx = 0;
+                let mut values_offset = 0;
+                let mut num_buffered: usize = 0;
+                let mut skip_values = num_rows_to_skip;
+
+                let current_filter;
+
+                (current_filter, filter) = filter.split_at(size);
+
+                let mut iter = |mut f: u64, len: usize| {
+                    debug_assert!(len <= 64);
+
+                    // Skip this chunk if we don't need any values from it.
+                    if f == 0 {
+                        skip_values += len;
+                        return ParquetResult::Ok(());
+                    }
+
+                    // Skip over already buffered items.
+                    let num_buffered_skipped = skip_values.min(num_buffered);
+                    values_offset += num_buffered_skipped;
+                    num_buffered -= num_buffered_skipped;
+                    skip_values -= num_buffered_skipped;
+
+                    // If we skipped plenty already, just skip decoding those chunks instead of
+                    // decoding them and throwing them away.
+                    chunked.decoder.skip_chunks(skip_values / 32);
+                    // The leftover values still have to be decoded, but they are skipped while
+                    // buffering below.
+                    skip_values %= 32;
+
+                    while num_buffered < len {
+                        let buffer_part = <&mut [u32; 32]>::try_from(
+                            &mut values_buffer[buffer_part_idx * 32..][..32],
+                        )
+                        .unwrap();
+                        let num_added = chunked.next_into(buffer_part).unwrap();
+
+                        verify_dict_indices(buffer_part, dict.len())?;
+
+                        let skip_chunk_values = skip_values.min(num_added);
+
+                        values_offset += skip_chunk_values;
+                        num_buffered += num_added - skip_chunk_values;
+                        skip_values -= skip_chunk_values;
+
+                        buffer_part_idx += 1;
+                        buffer_part_idx %= 4;
+                    }
+
+                    let mut num_read = 0;
+                    let mut num_written = 0;
+                    let target_ptr = unsafe { target.as_mut_ptr().add(target.len()) };
+
+                    while f != 0 {
+                        let offset = f.trailing_zeros() as usize;
+
+                        num_read += offset;
+
+                        let idx = values_buffer[(values_offset + num_read) % 128];
+                        // SAFETY:
+                        // 1. `values_buffer` starts out as only zeros, which we know is in the
+                        //    dictionary following the original `dict.is_empty` check.
+                        // 2. Each time we write to `values_buffer`, it is followed by a
+                        //    `verify_dict_indices`.
+                        let value = *unsafe { dict.get_unchecked(idx as usize) };
+                        unsafe { target_ptr.add(num_written).write(value) };
+
+                        num_written += 1;
+                        num_read += 1;
+
+                        f >>= offset + 1; // Clear least significant bit.
+ } + + values_offset += len; + values_offset %= 128; + num_buffered -= len; + unsafe { + target.set_len(target.len() + num_written); + } + + ParquetResult::Ok(()) + }; + + let mut f_iter = current_filter.fast_iter_u56(); + + for f in f_iter.by_ref() { + iter(f, 56)?; + } + + let (f, fl) = f_iter.remainder(); + + iter(f, fl)?; + }, + } + + num_rows_to_skip = 0; + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs index 7ae39b3366ad..d968e369b7c0 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs @@ -10,16 +10,16 @@ use arrow::types::{ Bytes4Alignment4, Bytes8Alignment8, }; +use super::dictionary_encoded::append_validity; use super::utils::array_chunks::ArrayChunks; -use super::utils::dict_encoded::append_validity; use super::utils::{dict_indices_decoder, freeze_validity, Decoder}; use super::Filter; use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; use crate::parquet::encoding::{hybrid_rle, Encoding}; use crate::parquet::error::{ParquetError, ParquetResult}; use crate::parquet::page::{split_buffer, DataPage, DictPage}; +use crate::read::deserialize::dictionary_encoded::constrain_page_validity; use crate::read::deserialize::utils; -use crate::read::deserialize::utils::dict_encoded::constrain_page_validity; #[allow(clippy::large_enum_variant)] #[derive(Debug)] @@ -271,7 +271,7 @@ fn decode_fsb_dict( macro_rules! decode_static_size { ($dict:ident, $target:ident) => {{ - super::utils::dict_encoded::decode_dict_dispatch( + super::dictionary_encoded::decode_dict_dispatch( values, $dict, is_optional, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs index 3bc1beb30973..77089f1da486 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs @@ -3,6 +3,7 @@ mod binview; mod boolean; mod dictionary; +mod dictionary_encoded; mod fixed_size_binary; mod nested; mod nested_utils; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs index 225738b0c1fd..e5961e4baffe 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs @@ -12,6 +12,7 @@ use crate::parquet::encoding::{byte_stream_split, hybrid_rle, Encoding}; use crate::parquet::error::ParquetResult; use crate::parquet::page::{split_buffer, DataPage, DictPage}; use crate::parquet::types::{decode, NativeType as ParquetNativeType}; +use crate::read::deserialize::dictionary_encoded; use crate::read::deserialize::utils::{ dict_indices_decoder, freeze_validity, unspecialized_decode, }; @@ -170,7 +171,7 @@ where &mut decoded.0, self.0.decoder, ), - StateTranslation::Dictionary(ref mut indexes) => utils::dict_encoded::decode_dict( + StateTranslation::Dictionary(ref mut indexes) => dictionary_encoded::decode_dict( indexes.clone(), state.dict.unwrap(), state.is_optional, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs index 087fc1c447d5..5eb4b438741b 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs +++ 
b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs @@ -12,6 +12,7 @@ use crate::parquet::encoding::{byte_stream_split, delta_bitpacked, hybrid_rle, E use crate::parquet::error::ParquetResult; use crate::parquet::page::{split_buffer, DataPage, DictPage}; use crate::parquet::types::{decode, NativeType as ParquetNativeType}; +use crate::read::deserialize::dictionary_encoded; use crate::read::deserialize::utils::{ dict_indices_decoder, freeze_validity, unspecialized_decode, }; @@ -201,7 +202,7 @@ where &mut decoded.0, self.0.decoder, ), - StateTranslation::Dictionary(ref mut indexes) => utils::dict_encoded::decode_dict( + StateTranslation::Dictionary(ref mut indexes) => dictionary_encoded::decode_dict( indexes.clone(), state.dict.unwrap(), state.is_optional, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs index 909e68f5af37..fd7c8833e3c3 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs @@ -5,8 +5,8 @@ use arrow::types::{AlignedBytes, NativeType}; use super::DecoderFunction; use crate::parquet::error::ParquetResult; use crate::parquet::types::NativeType as ParquetNativeType; +use crate::read::deserialize::dictionary_encoded::{append_validity, constrain_page_validity}; use crate::read::deserialize::utils::array_chunks::ArrayChunks; -use crate::read::deserialize::utils::dict_encoded::{append_validity, constrain_page_validity}; use crate::read::{Filter, ParquetError}; #[allow(clippy::too_many_arguments)] diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs deleted file mode 100644 index 879bfd85615e..000000000000 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs +++ /dev/null @@ -1,818 +0,0 @@ -use arrow::bitmap::bitmask::BitMask; -use arrow::bitmap::{Bitmap, MutableBitmap}; -use arrow::types::{AlignedBytes, NativeType}; -use polars_compute::filter::filter_boolean_kernel; - -use super::filter_from_range; -use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; -use crate::parquet::error::ParquetResult; -use crate::read::{Filter, ParquetError}; - -pub fn decode_dict( - values: HybridRleDecoder<'_>, - dict: &[T], - is_optional: bool, - page_validity: Option<&Bitmap>, - filter: Option, - validity: &mut MutableBitmap, - target: &mut Vec, -) -> ParquetResult<()> { - decode_dict_dispatch( - values, - bytemuck::cast_slice(dict), - is_optional, - page_validity, - filter, - validity, - ::cast_vec_ref_mut(target), - ) -} - -pub(crate) fn append_validity( - page_validity: Option<&Bitmap>, - filter: Option<&Filter>, - validity: &mut MutableBitmap, - values_len: usize, -) { - match (page_validity, filter) { - (None, None) => validity.extend_constant(values_len, true), - (None, Some(f)) => validity.extend_constant(f.num_rows(), true), - (Some(page_validity), None) => validity.extend_from_bitmap(page_validity), - (Some(page_validity), Some(Filter::Range(rng))) => { - let page_validity = page_validity.clone(); - validity.extend_from_bitmap(&page_validity.clone().sliced(rng.start, rng.len())) - }, - (Some(page_validity), Some(Filter::Mask(mask))) => { - validity.extend_from_bitmap(&filter_boolean_kernel(page_validity, mask)) - }, - } -} - -pub(crate) fn constrain_page_validity( - values_len: usize, - page_validity: Option<&Bitmap>, - 
filter: Option<&Filter>, -) -> Option { - let num_unfiltered_rows = match (filter.as_ref(), page_validity) { - (None, None) => values_len, - (None, Some(pv)) => { - debug_assert!(pv.len() >= values_len); - pv.len() - }, - (Some(f), v) => { - if cfg!(debug_assertions) { - if let Some(v) = v { - assert!(v.len() >= f.max_offset()); - } - } - - f.max_offset() - }, - }; - - page_validity.map(|pv| { - if pv.len() > num_unfiltered_rows { - pv.clone().sliced(0, num_unfiltered_rows) - } else { - pv.clone() - } - }) -} - -#[inline(never)] -pub fn decode_dict_dispatch( - mut values: HybridRleDecoder<'_>, - dict: &[B], - is_optional: bool, - page_validity: Option<&Bitmap>, - filter: Option, - validity: &mut MutableBitmap, - target: &mut Vec, -) -> ParquetResult<()> { - if cfg!(debug_assertions) && is_optional { - assert_eq!(target.len(), validity.len()); - } - - if is_optional { - append_validity(page_validity, filter.as_ref(), validity, values.len()); - } - - let page_validity = constrain_page_validity(values.len(), page_validity, filter.as_ref()); - - match (filter, page_validity) { - (None, None) => decode_required_dict(values, dict, target), - (Some(Filter::Range(rng)), None) if rng.start == 0 => { - values.limit_to(rng.end); - decode_required_dict(values, dict, target) - }, - (None, Some(page_validity)) => decode_optional_dict(values, dict, page_validity, target), - (Some(Filter::Range(rng)), Some(page_validity)) if rng.start == 0 => { - decode_optional_dict(values, dict, page_validity, target) - }, - (Some(Filter::Mask(filter)), None) => { - decode_masked_required_dict(values, dict, filter, target) - }, - (Some(Filter::Mask(filter)), Some(page_validity)) => { - decode_masked_optional_dict(values, dict, filter, page_validity, target) - }, - (Some(Filter::Range(rng)), None) => { - decode_masked_required_dict(values, dict, filter_from_range(rng.clone()), target) - }, - (Some(Filter::Range(rng)), Some(page_validity)) => decode_masked_optional_dict( - values, - dict, - filter_from_range(rng.clone()), - page_validity, - target, - ), - }?; - - if cfg!(debug_assertions) && is_optional { - assert_eq!(target.len(), validity.len()); - } - - Ok(()) -} - -#[cold] -fn oob_dict_idx() -> ParquetError { - ParquetError::oos("Dictionary Index is out-of-bounds") -} - -#[inline(always)] -fn verify_dict_indices(indices: &[u32; 32], dict_size: usize) -> ParquetResult<()> { - let mut is_valid = true; - for &idx in indices { - is_valid &= (idx as usize) < dict_size; - } - - if is_valid { - return Ok(()); - } - - Err(oob_dict_idx()) -} - -#[inline(never)] -pub fn decode_required_dict( - mut values: HybridRleDecoder<'_>, - dict: &[B], - target: &mut Vec, -) -> ParquetResult<()> { - if dict.is_empty() && values.len() > 0 { - return Err(oob_dict_idx()); - } - - let start_length = target.len(); - let end_length = start_length + values.len(); - - target.reserve(values.len()); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - - while values.len() > 0 { - let chunk = values.next_chunk()?.unwrap(); - - match chunk { - HybridRleChunk::Rle(value, length) => { - if length == 0 { - continue; - } - - let target_slice; - // SAFETY: - // 1. `target_ptr..target_ptr + values.len()` is allocated - // 2. 
`length <= limit` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, length); - target_ptr = target_ptr.add(length); - } - - let Some(&value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; - - target_slice.fill(value); - }, - HybridRleChunk::Bitpacked(mut decoder) => { - let mut chunked = decoder.chunked(); - for chunk in chunked.by_ref() { - verify_dict_indices(&chunk, dict.len())?; - - for (i, &idx) in chunk.iter().enumerate() { - unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; - } - unsafe { - target_ptr = target_ptr.add(32); - } - } - - if let Some((chunk, chunk_size)) = chunked.remainder() { - let highest_idx = chunk[..chunk_size].iter().copied().max().unwrap(); - if highest_idx as usize >= dict.len() { - return Err(oob_dict_idx()); - } - - for (i, &idx) in chunk[..chunk_size].iter().enumerate() { - unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; - } - unsafe { - target_ptr = target_ptr.add(chunk_size); - } - } - }, - } - } - - unsafe { - target.set_len(end_length); - } - - Ok(()) -} - -#[inline(never)] -pub fn decode_optional_dict( - mut values: HybridRleDecoder<'_>, - dict: &[B], - validity: Bitmap, - target: &mut Vec, -) -> ParquetResult<()> { - let num_valid_values = validity.set_bits(); - - // Dispatch to the required kernel if all rows are valid anyway. - if num_valid_values == validity.len() { - values.limit_to(validity.len()); - return decode_required_dict(values, dict, target); - } - - if dict.is_empty() && num_valid_values > 0 { - return Err(oob_dict_idx()); - } - - assert!(num_valid_values <= values.len()); - let start_length = target.len(); - let end_length = start_length + validity.len(); - - target.reserve(validity.len()); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - - values.limit_to(num_valid_values); - let mut validity = BitMask::from_bitmap(&validity); - let mut values_buffer = [0u32; 128]; - let values_buffer = &mut values_buffer; - - for chunk in values.into_chunk_iter() { - match chunk? { - HybridRleChunk::Rle(value, size) => { - if size == 0 { - continue; - } - - // If we know that we have `size` times `value` that we can append, but there might - // be nulls in between those values. - // - // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is - // done with `num_bits_before_nth_one` on the validity mask. - // 2. Fill `num_rows` values into the target buffer. - // 3. Advance the validity mask by `num_rows` values. - - let num_chunk_rows = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); - - (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_rows) }; - - let Some(&value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; - - let target_slice; - // SAFETY: - // Given `validity_iter` before the `advance_by_bits` - // - // 1. `target_ptr..target_ptr + validity_iter.bits_left()` is allocated - // 2. 
`num_chunk_rows <= validity_iter.bits_left()` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); - target_ptr = target_ptr.add(num_chunk_rows); - } - - target_slice.fill(value); - }, - HybridRleChunk::Bitpacked(mut decoder) => { - let mut chunked = decoder.chunked(); - - let mut buffer_part_idx = 0; - let mut values_offset = 0; - let mut num_buffered: usize = 0; - - { - let mut num_done = 0; - let mut validity_iter = validity.fast_iter_u56(); - - 'outer: for v in validity_iter.by_ref() { - while num_buffered < v.count_ones() as usize { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let Some(num_added) = chunked.next_into(buffer_part) else { - break 'outer; - }; - - verify_dict_indices(buffer_part, dict.len())?; - - num_buffered += num_added; - - buffer_part_idx += 1; - buffer_part_idx %= 4; - } - - let mut num_read = 0; - - for i in 0..56 { - let idx = values_buffer[(values_offset + num_read) % 128]; - - // SAFETY: - // 1. `values_buffer` starts out as only zeros, which we know is in the - // dictionary following the original `dict.is_empty` check. - // 2. Each time we write to `values_buffer`, it is followed by a - // `verify_dict_indices`. - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { target_ptr.add(i).write(value) }; - num_read += ((v >> i) & 1) as usize; - } - - values_offset += num_read; - values_offset %= 128; - num_buffered -= num_read; - unsafe { - target_ptr = target_ptr.add(56); - } - num_done += 56; - } - - (_, validity) = unsafe { validity.split_at_unchecked(num_done) }; - } - - let num_decoder_remaining = num_buffered + chunked.decoder.len(); - let decoder_limit = validity - .nth_set_bit_idx(num_decoder_remaining, 0) - .unwrap_or(validity.len()); - - let current_validity; - (current_validity, validity) = - unsafe { validity.split_at_unchecked(decoder_limit) }; - let (v, _) = current_validity.fast_iter_u56().remainder(); - - while num_buffered < v.count_ones() as usize { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let num_added = chunked.next_into(buffer_part).unwrap(); - - verify_dict_indices(buffer_part, dict.len())?; - - num_buffered += num_added; - - buffer_part_idx += 1; - buffer_part_idx %= 4; - } - - let mut num_read = 0; - - for i in 0..decoder_limit { - let idx = values_buffer[(values_offset + num_read) % 128]; - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { *target_ptr.add(i) = value }; - num_read += ((v >> i) & 1) as usize; - } - - unsafe { - target_ptr = target_ptr.add(decoder_limit); - } - }, - } - } - - if cfg!(debug_assertions) { - assert_eq!(validity.set_bits(), 0); - } - - let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, validity.len()) }; - target_slice.fill(B::zeroed()); - unsafe { - target.set_len(end_length); - } - - Ok(()) -} - -#[inline(never)] -pub fn decode_masked_optional_dict( - mut values: HybridRleDecoder<'_>, - dict: &[B], - filter: Bitmap, - validity: Bitmap, - target: &mut Vec, -) -> ParquetResult<()> { - let num_rows = filter.set_bits(); - let num_valid_values = validity.set_bits(); - - // Dispatch to the non-filter kernel if all rows are needed anyway. - if num_rows == filter.len() { - return decode_optional_dict(values, dict, validity, target); - } - - // Dispatch to the required kernel if all rows are valid anyway. 
- if num_valid_values == validity.len() { - return decode_masked_required_dict(values, dict, filter, target); - } - - if dict.is_empty() && num_valid_values > 0 { - return Err(oob_dict_idx()); - } - - debug_assert_eq!(filter.len(), validity.len()); - assert!(num_valid_values <= values.len()); - let start_length = target.len(); - - target.reserve(num_rows); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - - let mut filter = BitMask::from_bitmap(&filter); - let mut validity = BitMask::from_bitmap(&validity); - - values.limit_to(num_valid_values); - let mut values_buffer = [0u32; 128]; - let values_buffer = &mut values_buffer; - - let mut num_rows_left = num_rows; - - for chunk in values.into_chunk_iter() { - // Early stop if we have no more rows to load. - if num_rows_left == 0 { - break; - } - - match chunk? { - HybridRleChunk::Rle(value, size) => { - if size == 0 { - continue; - } - - // If we know that we have `size` times `value` that we can append, but there might - // be nulls in between those values. - // - // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is - // done with `num_bits_before_nth_one` on the validity mask. - // 2. Fill `num_rows` values into the target buffer. - // 3. Advance the validity mask by `num_rows` values. - - let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); - - let current_filter; - (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_values) }; - (current_filter, filter) = unsafe { filter.split_at_unchecked(num_chunk_values) }; - - let num_chunk_rows = current_filter.set_bits(); - - if num_chunk_rows > 0 { - let target_slice; - // SAFETY: - // Given `filter_iter` before the `advance_by_bits`. - // - // 1. `target_ptr..target_ptr + filter_iter.count_ones()` is allocated - // 2. `num_chunk_rows < filter_iter.count_ones()` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); - target_ptr = target_ptr.add(num_chunk_rows); - } - - let Some(value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; - - target_slice.fill(*value); - num_rows_left -= num_chunk_rows; - } - }, - HybridRleChunk::Bitpacked(mut decoder) => { - // For bitpacked we do the following: - // 1. See how many rows are encoded by this `decoder`. - // 2. Go through the filter and validity 56 bits at a time and: - // 0. If filter bits are 0, skip the chunk entirely. - // 1. Buffer enough values so that we can branchlessly decode with the filter - // and validity. - // 2. Decode with filter and validity. - // 3. Decode remainder. - - let size = decoder.len(); - let mut chunked = decoder.chunked(); - - let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); - - let mut buffer_part_idx = 0; - let mut values_offset = 0; - let mut num_buffered: usize = 0; - let mut skip_values = 0; - - let current_filter; - let current_validity; - - (current_filter, filter) = unsafe { filter.split_at_unchecked(num_chunk_values) }; - (current_validity, validity) = - unsafe { validity.split_at_unchecked(num_chunk_values) }; - - let mut iter = |mut f: u64, mut v: u64| { - // Skip chunk if we don't any values from here. - if f == 0 { - skip_values += v.count_ones() as usize; - return ParquetResult::Ok(()); - } - - // Skip over already buffered items. 
- let num_buffered_skipped = skip_values.min(num_buffered); - values_offset += num_buffered_skipped; - num_buffered -= num_buffered_skipped; - skip_values -= num_buffered_skipped; - - // If we skipped plenty already, just skip decoding those chunks instead of - // decoding them and throwing them away. - chunked.decoder.skip_chunks(skip_values / 32); - // The leftovers we have to decode but we can also just skip. - skip_values %= 32; - - while num_buffered < v.count_ones() as usize { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let num_added = chunked.next_into(buffer_part).unwrap(); - - verify_dict_indices(buffer_part, dict.len())?; - - let skip_chunk_values = skip_values.min(num_added); - - values_offset += skip_chunk_values; - num_buffered += num_added - skip_chunk_values; - skip_values -= skip_chunk_values; - - buffer_part_idx += 1; - buffer_part_idx %= 4; - } - - let mut num_read = 0; - let mut num_written = 0; - - while f != 0 { - let offset = f.trailing_zeros(); - - num_read += (v & (1u64 << offset).wrapping_sub(1)).count_ones() as usize; - v >>= offset; - - let idx = values_buffer[(values_offset + num_read) % 128]; - // SAFETY: - // 1. `values_buffer` starts out as only zeros, which we know is in the - // dictionary following the original `dict.is_empty` check. - // 2. Each time we write to `values_buffer`, it is followed by a - // `verify_dict_indices`. - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { target_ptr.add(num_written).write(value) }; - - num_written += 1; - num_read += (v & 1) as usize; - - f >>= offset + 1; // Clear least significant bit. - v >>= 1; - } - - num_read += v.count_ones() as usize; - - values_offset += num_read; - values_offset %= 128; - num_buffered -= num_read; - unsafe { - target_ptr = target_ptr.add(num_written); - } - num_rows_left -= num_written; - - ParquetResult::Ok(()) - }; - - let mut f_iter = current_filter.fast_iter_u56(); - let mut v_iter = current_validity.fast_iter_u56(); - - for (f, v) in f_iter.by_ref().zip(v_iter.by_ref()) { - iter(f, v)?; - } - - let (f, fl) = f_iter.remainder(); - let (v, vl) = v_iter.remainder(); - - assert_eq!(fl, vl); - - iter(f, v)?; - }, - } - } - - if cfg!(debug_assertions) { - assert_eq!(validity.set_bits(), 0); - } - - let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, num_rows_left) }; - target_slice.fill(B::zeroed()); - unsafe { - target.set_len(start_length + num_rows); - } - - Ok(()) -} - -#[inline(never)] -pub fn decode_masked_required_dict( - mut values: HybridRleDecoder<'_>, - dict: &[B], - filter: Bitmap, - target: &mut Vec, -) -> ParquetResult<()> { - let num_rows = filter.set_bits(); - - // Dispatch to the non-filter kernel if all rows are needed anyway. - if num_rows == filter.len() { - values.limit_to(filter.len()); - return decode_required_dict(values, dict, target); - } - - if dict.is_empty() && !filter.is_empty() { - return Err(oob_dict_idx()); - } - - let start_length = target.len(); - - target.reserve(num_rows); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - - let mut filter = BitMask::from_bitmap(&filter); - - values.limit_to(filter.len()); - let mut values_buffer = [0u32; 128]; - let values_buffer = &mut values_buffer; - - let mut num_rows_left = num_rows; - - for chunk in values.into_chunk_iter() { - if num_rows_left == 0 { - break; - } - - match chunk? 
{ - HybridRleChunk::Rle(value, size) => { - if size == 0 { - continue; - } - - let size = size.min(filter.len()); - - // If we know that we have `size` times `value` that we can append, but there might - // be nulls in between those values. - // - // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is - // done with `num_bits_before_nth_one` on the validity mask. - // 2. Fill `num_rows` values into the target buffer. - // 3. Advance the validity mask by `num_rows` values. - - let current_filter; - - (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; - let num_chunk_rows = current_filter.set_bits(); - - if num_chunk_rows > 0 { - let target_slice; - // SAFETY: - // Given `filter_iter` before the `advance_by_bits`. - // - // 1. `target_ptr..target_ptr + filter_iter.count_ones()` is allocated - // 2. `num_chunk_rows < filter_iter.count_ones()` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); - target_ptr = target_ptr.add(num_chunk_rows); - } - - let Some(value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; - - target_slice.fill(*value); - num_rows_left -= num_chunk_rows; - } - }, - HybridRleChunk::Bitpacked(mut decoder) => { - let size = decoder.len().min(filter.len()); - let mut chunked = decoder.chunked(); - - let mut buffer_part_idx = 0; - let mut values_offset = 0; - let mut num_buffered: usize = 0; - let mut skip_values = 0; - - let current_filter; - - (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; - - let mut iter = |mut f: u64, len: usize| { - debug_assert!(len <= 64); - - // Skip chunk if we don't any values from here. - if f == 0 { - skip_values += len; - return ParquetResult::Ok(()); - } - - // Skip over already buffered items. - let num_buffered_skipped = skip_values.min(num_buffered); - values_offset += num_buffered_skipped; - num_buffered -= num_buffered_skipped; - skip_values -= num_buffered_skipped; - - // If we skipped plenty already, just skip decoding those chunks instead of - // decoding them and throwing them away. - chunked.decoder.skip_chunks(skip_values / 32); - // The leftovers we have to decode but we can also just skip. - skip_values %= 32; - - while num_buffered < len { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let num_added = chunked.next_into(buffer_part).unwrap(); - - verify_dict_indices(buffer_part, dict.len())?; - - let skip_chunk_values = skip_values.min(num_added); - - values_offset += skip_chunk_values; - num_buffered += num_added - skip_chunk_values; - skip_values -= skip_chunk_values; - - buffer_part_idx += 1; - buffer_part_idx %= 4; - } - - let mut num_read = 0; - let mut num_written = 0; - - while f != 0 { - let offset = f.trailing_zeros() as usize; - - num_read += offset; - - let idx = values_buffer[(values_offset + num_read) % 128]; - // SAFETY: - // 1. `values_buffer` starts out as only zeros, which we know is in the - // dictionary following the original `dict.is_empty` check. - // 2. Each time we write to `values_buffer`, it is followed by a - // `verify_dict_indices`. - let value = *unsafe { dict.get_unchecked(idx as usize) }; - unsafe { target_ptr.add(num_written).write(value) }; - - num_written += 1; - num_read += 1; - - f >>= offset + 1; // Clear least significant bit. 
-                    }
-
-                    values_offset += len;
-                    values_offset %= 128;
-                    num_buffered -= len;
-                    unsafe {
-                        target_ptr = target_ptr.add(num_written);
-                    }
-                    num_rows_left -= num_written;
-
-                    ParquetResult::Ok(())
-                };
-
-                let mut f_iter = current_filter.fast_iter_u56();
-
-                for f in f_iter.by_ref() {
-                    iter(f, 56)?;
-                }
-
-                let (f, fl) = f_iter.remainder();
-
-                iter(f, fl)?;
-            },
-        }
-    }
-
-    unsafe {
-        target.set_len(start_length + num_rows);
-    }
-
-    Ok(())
-}
diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs
index 960c11d75c12..191d132a52ae 100644
--- a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs
+++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs
@@ -1,5 +1,4 @@
 pub(crate) mod array_chunks;
-pub(crate) mod dict_encoded;
 pub(crate) mod filter;
 
 use std::ops::Range;
diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs
index 72e7b82cc4ad..b77b32a05ca1 100644
--- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs
+++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs
@@ -43,6 +43,16 @@ impl<'a> Iterator for HybridRleChunkIter<'a> {
     }
 }
 
+impl HybridRleChunk<'_> {
+    #[inline]
+    pub fn len(&self) -> usize {
+        match self {
+            HybridRleChunk::Rle(_, size) => *size,
+            HybridRleChunk::Bitpacked(decoder) => decoder.len(),
+        }
+    }
+}
+
 impl<'a> HybridRleDecoder<'a> {
     /// Returns a new [`HybridRleDecoder`]
     pub fn new(data: &'a [u8], num_bits: u32, num_values: usize) -> Self {
@@ -122,6 +132,54 @@
         }))
     }
 
+    /// Advance the decoder over its next chunk, returning the number of values that chunk
+    /// encodes, or `None` if the decoder is exhausted.
+    pub fn next_chunk_length(&mut self) -> ParquetResult<Option<usize>> {
+        if self.len() == 0 {
+            return Ok(None);
+        }
+
+        if self.num_bits == 0 {
+            let num_values = self.num_values;
+            self.num_values = 0;
+            return Ok(Some(num_values));
+        }
+
+        if self.data.is_empty() {
+            return Ok(None);
+        }
+
+        let (indicator, consumed) = uleb128::decode(self.data);
+        self.data = unsafe { self.data.get_unchecked(consumed..) };
+
+        Ok(Some(if indicator & 1 == 1 {
+            // is bitpacking
+            let bytes = (indicator as usize >> 1) * self.num_bits;
+            let bytes = std::cmp::min(bytes, self.data.len());
+            let Some((packed, remaining)) = self.data.split_at_checked(bytes) else {
+                return Err(ParquetError::oos("Not enough bytes for bitpacked data"));
+            };
+            self.data = remaining;
+
+            let length = std::cmp::min(packed.len() * 8 / self.num_bits, self.num_values);
+            self.num_values -= length;
+
+            length
+        } else {
+            // is rle
+            let run_length = indicator as usize >> 1;
+            // repeated-value := value that is repeated, using a fixed-width of round-up-to-next-byte(bit-width)
+            let rle_bytes = self.num_bits.div_ceil(8);
+            let Some(remaining) = self.data.get(rle_bytes..)
else { + return Err(ParquetError::oos("Not enough bytes for RLE encoded data")); + }; + self.data = remaining; + + let length = std::cmp::min(run_length, self.num_values); + self.num_values -= length; + + length + })) + } + pub fn limit_to(&mut self, length: usize) { self.num_values = self.num_values.min(length); } diff --git a/crates/polars-python/src/dataframe/io.rs b/crates/polars-python/src/dataframe/io.rs index 9b34eb7e8ae9..d32a2c11ba8a 100644 --- a/crates/polars-python/src/dataframe/io.rs +++ b/crates/polars-python/src/dataframe/io.rs @@ -425,7 +425,7 @@ impl PyDataFrame { let buf = get_file_like(py_f, true)?; py.allow_threads(|| { - ParquetWriter::new(buf) + ParquetWriter::new(BufWriter::new(buf)) .with_compression(compression) .with_statistics(statistics.0) .with_row_group_size(row_group_size) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 773862591f0e..8ea7d2152bc0 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -4,6 +4,7 @@ import io from datetime import datetime, time, timezone from decimal import Decimal +from itertools import chain from typing import IO, TYPE_CHECKING, Any, Callable, Literal, cast import fsspec @@ -2319,3 +2320,124 @@ def test_nested_dicts(content: list[float | None]) -> None: df.write_parquet(f, use_pyarrow=True) f.seek(0) assert_frame_equal(df, pl.read_parquet(f)) + + +@pytest.mark.parametrize( + "leading_nulls", + [ + [], + [None] * 7, + ], +) +@pytest.mark.parametrize( + "trailing_nulls", + [ + [], + [None] * 7, + ], +) +@pytest.mark.parametrize( + "first_chunk", + # Create both RLE and Bitpacked chunks + [ + [1] * 57, + [1 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +@pytest.mark.parametrize( + "second_chunk", + # Create both RLE and Bitpacked chunks + [ + [2] * 57, + [2 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +def test_dict_slices( + leading_nulls: list[None], + trailing_nulls: list[None], + first_chunk: list[None | int], + second_chunk: list[None | int], +) -> None: + df = pl.Series( + "a", leading_nulls + first_chunk + second_chunk + trailing_nulls, pl.Int64 + ).to_frame() + + f = io.BytesIO() + df.write_parquet(f) + + for offset in chain([0, 1, 2], range(3, df.height, 3)): + for length in chain([df.height, 1, 2], range(3, df.height - offset, 3)): + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).slice(offset, length).collect(), + df.slice(offset, length), + ) + + +@pytest.mark.parametrize( + "mask", + [ + [i % 13 < 3 and i % 17 > 3 for i in range(57 * 2)], + [False] * 23 + [True] * 68 + [False] * 23, + [False] * 23 + [True] * 24 + [False] * 20 + [True] * 24 + [False] * 23, + [True] + [False] * 22 + [True] * 24 + [False] * 20 + [True] * 24 + [False] * 23, + [False] * 23 + [True] * 24 + [False] * 20 + [True] * 24 + [False] * 22 + [True], + [True] + + [False] * 22 + + [True] * 24 + + [False] * 20 + + [True] * 24 + + [False] * 22 + + [True], + [False] * 56 + [True] * 58, + [False] * 57 + [True] * 57, + [False] * 58 + [True] * 56, + [True] * 56 + [False] * 58, + [True] * 57 + [False] * 57, + [True] * 58 + [False] * 56, + ], +) +@pytest.mark.parametrize( + "first_chunk", + # Create both RLE and Bitpacked chunks + [ + [1] * 57, + [1 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 
3 else None for i in range(57)], + ], +) +@pytest.mark.parametrize( + "second_chunk", + # Create both RLE and Bitpacked chunks + [ + [2] * 57, + [2 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +def test_dict_masked( + mask: list[bool], + first_chunk: list[None | int], + second_chunk: list[None | int], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", first_chunk + second_chunk, pl.Int64), + pl.Series("f", mask, pl.Boolean), + ] + ) + + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f, parallel="prefiltered").filter(pl.col.f).collect(), + df.filter(pl.col.f), + )