From 7ecf06381ef6877c73a5df0423c2c95e76c6fa4d Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 5 Nov 2024 14:22:07 +0100 Subject: [PATCH 01/11] move dictionary to separate folder --- .../src/arrow/read/deserialize/binview.rs | 6 +- .../src/arrow/read/deserialize/boolean.rs | 2 +- .../deserialize/dictionary_encoded/mod.rs | 174 ++++ .../dictionary_encoded/optional.rs | 186 ++++ .../optional_masked_dense.rs | 230 +++++ .../dictionary_encoded/required.rs | 79 ++ .../required_masked_dense.rs | 192 ++++ .../read/deserialize/fixed_size_binary.rs | 6 +- .../src/arrow/read/deserialize/mod.rs | 1 + .../arrow/read/deserialize/primitive/float.rs | 3 +- .../read/deserialize/primitive/integer.rs | 3 +- .../arrow/read/deserialize/primitive/plain.rs | 2 +- .../read/deserialize/utils/dict_encoded.rs | 818 ------------------ .../src/arrow/read/deserialize/utils/mod.rs | 1 - .../src/parquet/encoding/hybrid_rle/mod.rs | 10 + 15 files changed, 884 insertions(+), 829 deletions(-) create mode 100644 crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs create mode 100644 crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs create mode 100644 crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs create mode 100644 crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs create mode 100644 crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs delete mode 100644 crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs index f9b553bea397..a73e5e784ab5 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs @@ -8,11 +8,11 @@ use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::buffer::Buffer; use arrow::datatypes::{ArrowDataType, PhysicalType}; -use super::utils::dict_encoded::{append_validity, constrain_page_validity}; +use super::dictionary_encoded::{append_validity, constrain_page_validity}; use super::utils::{ dict_indices_decoder, filter_from_range, freeze_validity, unspecialized_decode, }; -use super::Filter; +use super::{dictionary_encoded, Filter}; use crate::parquet::encoding::{delta_byte_array, delta_length_byte_array, hybrid_rle, Encoding}; use crate::parquet::error::{ParquetError, ParquetResult}; use crate::parquet::page::{split_buffer, DataPage, DictPage}; @@ -521,7 +521,7 @@ impl utils::Decoder for BinViewDecoder { let start_length = decoded.0.views().len(); - utils::dict_encoded::decode_dict( + dictionary_encoded::decode_dict( indexes.clone(), dict, state.is_optional, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs index 11019b3ab614..56dda77cf37a 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs @@ -5,7 +5,7 @@ use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::datatypes::ArrowDataType; use polars_compute::filter::filter_boolean_kernel; -use super::utils::dict_encoded::{append_validity, constrain_page_validity}; +use super::dictionary_encoded::{append_validity, constrain_page_validity}; use super::utils::{ self, decode_hybrid_rle_into_bitmap, filter_from_range, freeze_validity, Decoder, ExactSize, }; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs new file mode 100644 index 000000000000..e24503a31826 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs @@ -0,0 +1,174 @@ +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::types::{AlignedBytes, NativeType}; +use polars_compute::filter::filter_boolean_kernel; + +use super::utils::filter_from_range; +use super::ParquetError; +use crate::parquet::encoding::hybrid_rle::HybridRleDecoder; +use crate::parquet::error::ParquetResult; +use crate::read::Filter; + +mod optional; +mod optional_masked_dense; +mod required; +mod required_masked_dense; + +pub fn decode_dict( + values: HybridRleDecoder<'_>, + dict: &[T], + is_optional: bool, + page_validity: Option<&Bitmap>, + filter: Option, + validity: &mut MutableBitmap, + target: &mut Vec, +) -> ParquetResult<()> { + decode_dict_dispatch( + values, + bytemuck::cast_slice(dict), + is_optional, + page_validity, + filter, + validity, + ::cast_vec_ref_mut(target), + ) +} + +#[inline(never)] +pub fn decode_dict_dispatch( + mut values: HybridRleDecoder<'_>, + dict: &[B], + is_optional: bool, + page_validity: Option<&Bitmap>, + filter: Option, + validity: &mut MutableBitmap, + target: &mut Vec, +) -> ParquetResult<()> { + if cfg!(debug_assertions) && is_optional { + assert_eq!(target.len(), validity.len()); + } + + if is_optional { + append_validity(page_validity, filter.as_ref(), validity, values.len()); + } + + let page_validity = constrain_page_validity(values.len(), page_validity, filter.as_ref()); + + match (filter, page_validity) { + (None, None) => required::decode(values, dict, target), + (Some(Filter::Range(rng)), None) if rng.start == 0 => { + values.limit_to(rng.end); + required::decode(values, dict, target) + }, + (None, Some(page_validity)) => optional::decode(values, dict, page_validity, target), + (Some(Filter::Range(rng)), Some(page_validity)) if rng.start == 0 => { + optional::decode(values, dict, page_validity, target) + }, + (Some(Filter::Mask(filter)), None) => { + required_masked_dense::decode(values, dict, filter, target) + }, + (Some(Filter::Mask(filter)), Some(page_validity)) => { + optional_masked_dense::decode(values, dict, filter, page_validity, target) + }, + (Some(Filter::Range(rng)), None) => { + required_masked_dense::decode(values, dict, filter_from_range(rng.clone()), target) + }, + (Some(Filter::Range(rng)), Some(page_validity)) => optional_masked_dense::decode( + values, + dict, + filter_from_range(rng.clone()), + page_validity, + target, + ), + }?; + + if cfg!(debug_assertions) && is_optional { + assert_eq!(target.len(), validity.len()); + } + + Ok(()) +} + +pub(crate) fn append_validity( + page_validity: Option<&Bitmap>, + filter: Option<&Filter>, + validity: &mut MutableBitmap, + values_len: usize, +) { + match (page_validity, filter) { + (None, None) => validity.extend_constant(values_len, true), + (None, Some(f)) => validity.extend_constant(f.num_rows(), true), + (Some(page_validity), None) => validity.extend_from_bitmap(page_validity), + (Some(page_validity), Some(Filter::Range(rng))) => { + let page_validity = page_validity.clone(); + validity.extend_from_bitmap(&page_validity.clone().sliced(rng.start, rng.len())) + }, + (Some(page_validity), Some(Filter::Mask(mask))) => { + validity.extend_from_bitmap(&filter_boolean_kernel(page_validity, mask)) + }, + } +} + +pub(crate) fn constrain_page_validity( + values_len: usize, + page_validity: Option<&Bitmap>, + filter: Option<&Filter>, +) -> Option { + let num_unfiltered_rows = match (filter.as_ref(), page_validity) { + (None, None) => values_len, + (None, Some(pv)) => { + debug_assert!(pv.len() >= values_len); + pv.len() + }, + (Some(f), v) => { + if cfg!(debug_assertions) { + if let Some(v) = v { + assert!(v.len() >= f.max_offset()); + } + } + + f.max_offset() + }, + }; + + page_validity.map(|pv| { + if pv.len() > num_unfiltered_rows { + pv.clone().sliced(0, num_unfiltered_rows) + } else { + pv.clone() + } + }) +} + + +#[cold] +fn oob_dict_idx() -> ParquetError { + ParquetError::oos("Dictionary Index is out-of-bounds") +} + +#[inline(always)] +fn verify_dict_indices(indices: &[u32; 32], dict_size: usize) -> ParquetResult<()> { + let mut is_valid = true; + for &idx in indices { + is_valid &= (idx as usize) < dict_size; + } + + if is_valid { + return Ok(()); + } + + Err(oob_dict_idx()) +} + +#[inline(always)] +fn verify_dict_indices_slice(indices: &[u32], dict_size: usize) -> ParquetResult<()> { + let mut is_valid = true; + for &idx in indices { + is_valid &= (idx as usize) < dict_size; + } + + if is_valid { + return Ok(()); + } + + Err(oob_dict_idx()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs new file mode 100644 index 000000000000..23c4fd23f0db --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -0,0 +1,186 @@ +use arrow::bitmap::bitmask::BitMask; +use arrow::bitmap::Bitmap; +use arrow::types::AlignedBytes; + +use super::{oob_dict_idx, verify_dict_indices}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; + +#[inline(never)] +pub fn decode( + mut values: HybridRleDecoder<'_>, + dict: &[B], + validity: Bitmap, + target: &mut Vec, +) -> ParquetResult<()> { + let num_valid_values = validity.set_bits(); + + // Dispatch to the required kernel if all rows are valid anyway. + if num_valid_values == validity.len() { + values.limit_to(validity.len()); + return super::required::decode(values, dict, target); + } + + if dict.is_empty() && num_valid_values > 0 { + return Err(oob_dict_idx()); + } + + assert!(num_valid_values <= values.len()); + let start_length = target.len(); + let end_length = start_length + validity.len(); + + target.reserve(validity.len()); + let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + + values.limit_to(num_valid_values); + let mut validity = BitMask::from_bitmap(&validity); + let mut values_buffer = [0u32; 128]; + let values_buffer = &mut values_buffer; + + for chunk in values.into_chunk_iter() { + match chunk? { + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + + // If we know that we have `size` times `value` that we can append, but there might + // be nulls in between those values. + // + // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is + // done with `num_bits_before_nth_one` on the validity mask. + // 2. Fill `num_rows` values into the target buffer. + // 3. Advance the validity mask by `num_rows` values. + + let num_chunk_rows = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + + (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_rows) }; + + let Some(&value) = dict.get(value as usize) else { + return Err(oob_dict_idx()); + }; + + let target_slice; + // SAFETY: + // Given `validity_iter` before the `advance_by_bits` + // + // 1. `target_ptr..target_ptr + validity_iter.bits_left()` is allocated + // 2. `num_chunk_rows <= validity_iter.bits_left()` + unsafe { + target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); + target_ptr = target_ptr.add(num_chunk_rows); + } + + target_slice.fill(value); + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let mut chunked = decoder.chunked(); + + let mut buffer_part_idx = 0; + let mut values_offset = 0; + let mut num_buffered: usize = 0; + + { + let mut num_done = 0; + let mut validity_iter = validity.fast_iter_u56(); + + 'outer: for v in validity_iter.by_ref() { + while num_buffered < v.count_ones() as usize { + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let Some(num_added) = chunked.next_into(buffer_part) else { + break 'outer; + }; + + verify_dict_indices(buffer_part, dict.len())?; + + num_buffered += num_added; + + buffer_part_idx += 1; + buffer_part_idx %= 4; + } + + let mut num_read = 0; + + for i in 0..56 { + let idx = values_buffer[(values_offset + num_read) % 128]; + + // SAFETY: + // 1. `values_buffer` starts out as only zeros, which we know is in the + // dictionary following the original `dict.is_empty` check. + // 2. Each time we write to `values_buffer`, it is followed by a + // `verify_dict_indices`. + let value = unsafe { dict.get_unchecked(idx as usize) }; + let value = *value; + unsafe { target_ptr.add(i).write(value) }; + num_read += ((v >> i) & 1) as usize; + } + + values_offset += num_read; + values_offset %= 128; + num_buffered -= num_read; + unsafe { + target_ptr = target_ptr.add(56); + } + num_done += 56; + } + + (_, validity) = unsafe { validity.split_at_unchecked(num_done) }; + } + + let num_decoder_remaining = num_buffered + chunked.decoder.len(); + let decoder_limit = validity + .nth_set_bit_idx(num_decoder_remaining, 0) + .unwrap_or(validity.len()); + + let current_validity; + (current_validity, validity) = + unsafe { validity.split_at_unchecked(decoder_limit) }; + let (v, _) = current_validity.fast_iter_u56().remainder(); + + while num_buffered < v.count_ones() as usize { + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let num_added = chunked.next_into(buffer_part).unwrap(); + + verify_dict_indices(buffer_part, dict.len())?; + + num_buffered += num_added; + + buffer_part_idx += 1; + buffer_part_idx %= 4; + } + + let mut num_read = 0; + + for i in 0..decoder_limit { + let idx = values_buffer[(values_offset + num_read) % 128]; + let value = unsafe { dict.get_unchecked(idx as usize) }; + let value = *value; + unsafe { *target_ptr.add(i) = value }; + num_read += ((v >> i) & 1) as usize; + } + + unsafe { + target_ptr = target_ptr.add(decoder_limit); + } + }, + } + } + + if cfg!(debug_assertions) { + assert_eq!(validity.set_bits(), 0); + } + + let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, validity.len()) }; + target_slice.fill(B::zeroed()); + unsafe { + target.set_len(end_length); + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs new file mode 100644 index 000000000000..9ad9ee6ac22c --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs @@ -0,0 +1,230 @@ +use arrow::bitmap::bitmask::BitMask; +use arrow::bitmap::Bitmap; +use arrow::types::AlignedBytes; + +use super::{oob_dict_idx, verify_dict_indices}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; + +#[inline(never)] +pub fn decode( + mut values: HybridRleDecoder<'_>, + dict: &[B], + filter: Bitmap, + validity: Bitmap, + target: &mut Vec, +) -> ParquetResult<()> { + let num_rows = filter.set_bits(); + let num_valid_values = validity.set_bits(); + + // Dispatch to the non-filter kernel if all rows are needed anyway. + if num_rows == filter.len() { + return super::optional::decode(values, dict, validity, target); + } + + // Dispatch to the required kernel if all rows are valid anyway. + if num_valid_values == validity.len() { + return super::required_masked_dense::decode(values, dict, filter, target); + } + + if dict.is_empty() && num_valid_values > 0 { + return Err(oob_dict_idx()); + } + + debug_assert_eq!(filter.len(), validity.len()); + assert!(num_valid_values <= values.len()); + let start_length = target.len(); + + target.reserve(num_rows); + let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + + let mut filter = BitMask::from_bitmap(&filter); + let mut validity = BitMask::from_bitmap(&validity); + + values.limit_to(num_valid_values); + let mut values_buffer = [0u32; 128]; + let values_buffer = &mut values_buffer; + + let mut num_rows_left = num_rows; + + for chunk in values.into_chunk_iter() { + // Early stop if we have no more rows to load. + if num_rows_left == 0 { + break; + } + + match chunk? { + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + + // If we know that we have `size` times `value` that we can append, but there might + // be nulls in between those values. + // + // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is + // done with `num_bits_before_nth_one` on the validity mask. + // 2. Fill `num_rows` values into the target buffer. + // 3. Advance the validity mask by `num_rows` values. + + let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + + let current_filter; + (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_values) }; + (current_filter, filter) = unsafe { filter.split_at_unchecked(num_chunk_values) }; + + let num_chunk_rows = current_filter.set_bits(); + + if num_chunk_rows > 0 { + let target_slice; + // SAFETY: + // Given `filter_iter` before the `advance_by_bits`. + // + // 1. `target_ptr..target_ptr + filter_iter.count_ones()` is allocated + // 2. `num_chunk_rows < filter_iter.count_ones()` + unsafe { + target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); + target_ptr = target_ptr.add(num_chunk_rows); + } + + let Some(value) = dict.get(value as usize) else { + return Err(oob_dict_idx()); + }; + + target_slice.fill(*value); + num_rows_left -= num_chunk_rows; + } + }, + HybridRleChunk::Bitpacked(mut decoder) => { + // For bitpacked we do the following: + // 1. See how many rows are encoded by this `decoder`. + // 2. Go through the filter and validity 56 bits at a time and: + // 0. If filter bits are 0, skip the chunk entirely. + // 1. Buffer enough values so that we can branchlessly decode with the filter + // and validity. + // 2. Decode with filter and validity. + // 3. Decode remainder. + + let size = decoder.len(); + let mut chunked = decoder.chunked(); + + let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + + let mut buffer_part_idx = 0; + let mut values_offset = 0; + let mut num_buffered: usize = 0; + let mut skip_values = 0; + + let current_filter; + let current_validity; + + (current_filter, filter) = unsafe { filter.split_at_unchecked(num_chunk_values) }; + (current_validity, validity) = + unsafe { validity.split_at_unchecked(num_chunk_values) }; + + let mut iter = |mut f: u64, mut v: u64| { + // Skip chunk if we don't any values from here. + if f == 0 { + skip_values += v.count_ones() as usize; + return ParquetResult::Ok(()); + } + + // Skip over already buffered items. + let num_buffered_skipped = skip_values.min(num_buffered); + values_offset += num_buffered_skipped; + num_buffered -= num_buffered_skipped; + skip_values -= num_buffered_skipped; + + // If we skipped plenty already, just skip decoding those chunks instead of + // decoding them and throwing them away. + chunked.decoder.skip_chunks(skip_values / 32); + // The leftovers we have to decode but we can also just skip. + skip_values %= 32; + + while num_buffered < v.count_ones() as usize { + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let num_added = chunked.next_into(buffer_part).unwrap(); + + verify_dict_indices(buffer_part, dict.len())?; + + let skip_chunk_values = skip_values.min(num_added); + + values_offset += skip_chunk_values; + num_buffered += num_added - skip_chunk_values; + skip_values -= skip_chunk_values; + + buffer_part_idx += 1; + buffer_part_idx %= 4; + } + + let mut num_read = 0; + let mut num_written = 0; + + while f != 0 { + let offset = f.trailing_zeros(); + + num_read += (v & (1u64 << offset).wrapping_sub(1)).count_ones() as usize; + v >>= offset; + + let idx = values_buffer[(values_offset + num_read) % 128]; + // SAFETY: + // 1. `values_buffer` starts out as only zeros, which we know is in the + // dictionary following the original `dict.is_empty` check. + // 2. Each time we write to `values_buffer`, it is followed by a + // `verify_dict_indices`. + let value = unsafe { dict.get_unchecked(idx as usize) }; + let value = *value; + unsafe { target_ptr.add(num_written).write(value) }; + + num_written += 1; + num_read += (v & 1) as usize; + + f >>= offset + 1; // Clear least significant bit. + v >>= 1; + } + + num_read += v.count_ones() as usize; + + values_offset += num_read; + values_offset %= 128; + num_buffered -= num_read; + unsafe { + target_ptr = target_ptr.add(num_written); + } + num_rows_left -= num_written; + + ParquetResult::Ok(()) + }; + + let mut f_iter = current_filter.fast_iter_u56(); + let mut v_iter = current_validity.fast_iter_u56(); + + for (f, v) in f_iter.by_ref().zip(v_iter.by_ref()) { + iter(f, v)?; + } + + let (f, fl) = f_iter.remainder(); + let (v, vl) = v_iter.remainder(); + + assert_eq!(fl, vl); + + iter(f, v)?; + }, + } + } + + if cfg!(debug_assertions) { + assert_eq!(validity.set_bits(), 0); + } + + let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, num_rows_left) }; + target_slice.fill(B::zeroed()); + unsafe { + target.set_len(start_length + num_rows); + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs new file mode 100644 index 000000000000..a76b4dc370a6 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -0,0 +1,79 @@ +use arrow::types::AlignedBytes; + +use super::{oob_dict_idx, verify_dict_indices, verify_dict_indices_slice}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; + +#[inline(never)] +pub fn decode( + mut values: HybridRleDecoder<'_>, + dict: &[B], + target: &mut Vec, +) -> ParquetResult<()> { + if dict.is_empty() && values.len() > 0 { + return Err(oob_dict_idx()); + } + + let start_length = target.len(); + let end_length = start_length + values.len(); + + target.reserve(values.len()); + let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + + while values.len() > 0 { + let chunk = values.next_chunk()?.unwrap(); + + match chunk { + HybridRleChunk::Rle(value, length) => { + if length == 0 { + continue; + } + + let target_slice; + // SAFETY: + // 1. `target_ptr..target_ptr + values.len()` is allocated + // 2. `length <= limit` + unsafe { + target_slice = std::slice::from_raw_parts_mut(target_ptr, length); + target_ptr = target_ptr.add(length); + } + + let Some(&value) = dict.get(value as usize) else { + return Err(oob_dict_idx()); + }; + + target_slice.fill(value); + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let mut chunked = decoder.chunked(); + for chunk in chunked.by_ref() { + verify_dict_indices(&chunk, dict.len())?; + + for (i, &idx) in chunk.iter().enumerate() { + unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; + } + unsafe { + target_ptr = target_ptr.add(32); + } + } + + if let Some((chunk, chunk_size)) = chunked.remainder() { + verify_dict_indices_slice(&chunk[..chunk_size], dict.len())?; + + for (i, &idx) in chunk[..chunk_size].iter().enumerate() { + unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; + } + unsafe { + target_ptr = target_ptr.add(chunk_size); + } + } + }, + } + } + + unsafe { + target.set_len(end_length); + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs new file mode 100644 index 000000000000..da4385023732 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs @@ -0,0 +1,192 @@ +use arrow::bitmap::bitmask::BitMask; +use arrow::bitmap::Bitmap; +use arrow::types::AlignedBytes; + +use super::{oob_dict_idx, verify_dict_indices}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; + +#[inline(never)] +pub fn decode( + mut values: HybridRleDecoder<'_>, + dict: &[B], + filter: Bitmap, + target: &mut Vec, +) -> ParquetResult<()> { + let num_rows = filter.set_bits(); + + // Dispatch to the non-filter kernel if all rows are needed anyway. + if num_rows == filter.len() { + values.limit_to(filter.len()); + return super::required::decode(values, dict, target); + } + + if dict.is_empty() && !filter.is_empty() { + return Err(oob_dict_idx()); + } + + let start_length = target.len(); + + target.reserve(num_rows); + let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + + let mut filter = BitMask::from_bitmap(&filter); + + values.limit_to(filter.len()); + let mut values_buffer = [0u32; 128]; + let values_buffer = &mut values_buffer; + + let mut num_rows_left = num_rows; + + for chunk in values.into_chunk_iter() { + if num_rows_left == 0 { + break; + } + + match chunk? { + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + + let size = size.min(filter.len()); + + // If we know that we have `size` times `value` that we can append, but there might + // be nulls in between those values. + // + // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is + // done with `num_bits_before_nth_one` on the validity mask. + // 2. Fill `num_rows` values into the target buffer. + // 3. Advance the validity mask by `num_rows` values. + + let current_filter; + + (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; + let num_chunk_rows = current_filter.set_bits(); + + if num_chunk_rows > 0 { + let target_slice; + // SAFETY: + // Given `filter_iter` before the `advance_by_bits`. + // + // 1. `target_ptr..target_ptr + filter_iter.count_ones()` is allocated + // 2. `num_chunk_rows < filter_iter.count_ones()` + unsafe { + target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); + target_ptr = target_ptr.add(num_chunk_rows); + } + + let Some(value) = dict.get(value as usize) else { + return Err(oob_dict_idx()); + }; + + target_slice.fill(*value); + num_rows_left -= num_chunk_rows; + } + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let size = decoder.len().min(filter.len()); + let mut chunked = decoder.chunked(); + + let mut buffer_part_idx = 0; + let mut values_offset = 0; + let mut num_buffered: usize = 0; + let mut skip_values = 0; + + let current_filter; + + (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; + + let mut iter = |mut f: u64, len: usize| { + debug_assert!(len <= 64); + + // Skip chunk if we don't any values from here. + if f == 0 { + skip_values += len; + return ParquetResult::Ok(()); + } + + // Skip over already buffered items. + let num_buffered_skipped = skip_values.min(num_buffered); + values_offset += num_buffered_skipped; + num_buffered -= num_buffered_skipped; + skip_values -= num_buffered_skipped; + + // If we skipped plenty already, just skip decoding those chunks instead of + // decoding them and throwing them away. + chunked.decoder.skip_chunks(skip_values / 32); + // The leftovers we have to decode but we can also just skip. + skip_values %= 32; + + while num_buffered < len { + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let num_added = chunked.next_into(buffer_part).unwrap(); + + verify_dict_indices(buffer_part, dict.len())?; + + let skip_chunk_values = skip_values.min(num_added); + + values_offset += skip_chunk_values; + num_buffered += num_added - skip_chunk_values; + skip_values -= skip_chunk_values; + + buffer_part_idx += 1; + buffer_part_idx %= 4; + } + + let mut num_read = 0; + let mut num_written = 0; + + while f != 0 { + let offset = f.trailing_zeros() as usize; + + num_read += offset; + + let idx = values_buffer[(values_offset + num_read) % 128]; + // SAFETY: + // 1. `values_buffer` starts out as only zeros, which we know is in the + // dictionary following the original `dict.is_empty` check. + // 2. Each time we write to `values_buffer`, it is followed by a + // `verify_dict_indices`. + let value = *unsafe { dict.get_unchecked(idx as usize) }; + unsafe { target_ptr.add(num_written).write(value) }; + + num_written += 1; + num_read += 1; + + f >>= offset + 1; // Clear least significant bit. + } + + values_offset += len; + values_offset %= 128; + num_buffered -= len; + unsafe { + target_ptr = target_ptr.add(num_written); + } + num_rows_left -= num_written; + + ParquetResult::Ok(()) + }; + + let mut f_iter = current_filter.fast_iter_u56(); + + for f in f_iter.by_ref() { + iter(f, 56)?; + } + + let (f, fl) = f_iter.remainder(); + + iter(f, fl)?; + }, + } + } + + unsafe { + target.set_len(start_length + num_rows); + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs index 7ae39b3366ad..9cce1ee4c60c 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs @@ -11,7 +11,7 @@ use arrow::types::{ }; use super::utils::array_chunks::ArrayChunks; -use super::utils::dict_encoded::append_validity; +use super::dictionary_encoded::append_validity; use super::utils::{dict_indices_decoder, freeze_validity, Decoder}; use super::Filter; use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; @@ -19,7 +19,7 @@ use crate::parquet::encoding::{hybrid_rle, Encoding}; use crate::parquet::error::{ParquetError, ParquetResult}; use crate::parquet::page::{split_buffer, DataPage, DictPage}; use crate::read::deserialize::utils; -use crate::read::deserialize::utils::dict_encoded::constrain_page_validity; +use crate::read::deserialize::dictionary_encoded::constrain_page_validity; #[allow(clippy::large_enum_variant)] #[derive(Debug)] @@ -271,7 +271,7 @@ fn decode_fsb_dict( macro_rules! decode_static_size { ($dict:ident, $target:ident) => {{ - super::utils::dict_encoded::decode_dict_dispatch( + super::dictionary_encoded::decode_dict_dispatch( values, $dict, is_optional, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs index 3bc1beb30973..77089f1da486 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs @@ -3,6 +3,7 @@ mod binview; mod boolean; mod dictionary; +mod dictionary_encoded; mod fixed_size_binary; mod nested; mod nested_utils; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs index 225738b0c1fd..e5961e4baffe 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs @@ -12,6 +12,7 @@ use crate::parquet::encoding::{byte_stream_split, hybrid_rle, Encoding}; use crate::parquet::error::ParquetResult; use crate::parquet::page::{split_buffer, DataPage, DictPage}; use crate::parquet::types::{decode, NativeType as ParquetNativeType}; +use crate::read::deserialize::dictionary_encoded; use crate::read::deserialize::utils::{ dict_indices_decoder, freeze_validity, unspecialized_decode, }; @@ -170,7 +171,7 @@ where &mut decoded.0, self.0.decoder, ), - StateTranslation::Dictionary(ref mut indexes) => utils::dict_encoded::decode_dict( + StateTranslation::Dictionary(ref mut indexes) => dictionary_encoded::decode_dict( indexes.clone(), state.dict.unwrap(), state.is_optional, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs index 087fc1c447d5..5eb4b438741b 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs @@ -12,6 +12,7 @@ use crate::parquet::encoding::{byte_stream_split, delta_bitpacked, hybrid_rle, E use crate::parquet::error::ParquetResult; use crate::parquet::page::{split_buffer, DataPage, DictPage}; use crate::parquet::types::{decode, NativeType as ParquetNativeType}; +use crate::read::deserialize::dictionary_encoded; use crate::read::deserialize::utils::{ dict_indices_decoder, freeze_validity, unspecialized_decode, }; @@ -201,7 +202,7 @@ where &mut decoded.0, self.0.decoder, ), - StateTranslation::Dictionary(ref mut indexes) => utils::dict_encoded::decode_dict( + StateTranslation::Dictionary(ref mut indexes) => dictionary_encoded::decode_dict( indexes.clone(), state.dict.unwrap(), state.is_optional, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs index 909e68f5af37..d37f40a3337e 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs @@ -6,7 +6,7 @@ use super::DecoderFunction; use crate::parquet::error::ParquetResult; use crate::parquet::types::NativeType as ParquetNativeType; use crate::read::deserialize::utils::array_chunks::ArrayChunks; -use crate::read::deserialize::utils::dict_encoded::{append_validity, constrain_page_validity}; +use crate::read::deserialize::dictionary_encoded::{append_validity, constrain_page_validity}; use crate::read::{Filter, ParquetError}; #[allow(clippy::too_many_arguments)] diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs deleted file mode 100644 index 879bfd85615e..000000000000 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs +++ /dev/null @@ -1,818 +0,0 @@ -use arrow::bitmap::bitmask::BitMask; -use arrow::bitmap::{Bitmap, MutableBitmap}; -use arrow::types::{AlignedBytes, NativeType}; -use polars_compute::filter::filter_boolean_kernel; - -use super::filter_from_range; -use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; -use crate::parquet::error::ParquetResult; -use crate::read::{Filter, ParquetError}; - -pub fn decode_dict( - values: HybridRleDecoder<'_>, - dict: &[T], - is_optional: bool, - page_validity: Option<&Bitmap>, - filter: Option, - validity: &mut MutableBitmap, - target: &mut Vec, -) -> ParquetResult<()> { - decode_dict_dispatch( - values, - bytemuck::cast_slice(dict), - is_optional, - page_validity, - filter, - validity, - ::cast_vec_ref_mut(target), - ) -} - -pub(crate) fn append_validity( - page_validity: Option<&Bitmap>, - filter: Option<&Filter>, - validity: &mut MutableBitmap, - values_len: usize, -) { - match (page_validity, filter) { - (None, None) => validity.extend_constant(values_len, true), - (None, Some(f)) => validity.extend_constant(f.num_rows(), true), - (Some(page_validity), None) => validity.extend_from_bitmap(page_validity), - (Some(page_validity), Some(Filter::Range(rng))) => { - let page_validity = page_validity.clone(); - validity.extend_from_bitmap(&page_validity.clone().sliced(rng.start, rng.len())) - }, - (Some(page_validity), Some(Filter::Mask(mask))) => { - validity.extend_from_bitmap(&filter_boolean_kernel(page_validity, mask)) - }, - } -} - -pub(crate) fn constrain_page_validity( - values_len: usize, - page_validity: Option<&Bitmap>, - filter: Option<&Filter>, -) -> Option { - let num_unfiltered_rows = match (filter.as_ref(), page_validity) { - (None, None) => values_len, - (None, Some(pv)) => { - debug_assert!(pv.len() >= values_len); - pv.len() - }, - (Some(f), v) => { - if cfg!(debug_assertions) { - if let Some(v) = v { - assert!(v.len() >= f.max_offset()); - } - } - - f.max_offset() - }, - }; - - page_validity.map(|pv| { - if pv.len() > num_unfiltered_rows { - pv.clone().sliced(0, num_unfiltered_rows) - } else { - pv.clone() - } - }) -} - -#[inline(never)] -pub fn decode_dict_dispatch( - mut values: HybridRleDecoder<'_>, - dict: &[B], - is_optional: bool, - page_validity: Option<&Bitmap>, - filter: Option, - validity: &mut MutableBitmap, - target: &mut Vec, -) -> ParquetResult<()> { - if cfg!(debug_assertions) && is_optional { - assert_eq!(target.len(), validity.len()); - } - - if is_optional { - append_validity(page_validity, filter.as_ref(), validity, values.len()); - } - - let page_validity = constrain_page_validity(values.len(), page_validity, filter.as_ref()); - - match (filter, page_validity) { - (None, None) => decode_required_dict(values, dict, target), - (Some(Filter::Range(rng)), None) if rng.start == 0 => { - values.limit_to(rng.end); - decode_required_dict(values, dict, target) - }, - (None, Some(page_validity)) => decode_optional_dict(values, dict, page_validity, target), - (Some(Filter::Range(rng)), Some(page_validity)) if rng.start == 0 => { - decode_optional_dict(values, dict, page_validity, target) - }, - (Some(Filter::Mask(filter)), None) => { - decode_masked_required_dict(values, dict, filter, target) - }, - (Some(Filter::Mask(filter)), Some(page_validity)) => { - decode_masked_optional_dict(values, dict, filter, page_validity, target) - }, - (Some(Filter::Range(rng)), None) => { - decode_masked_required_dict(values, dict, filter_from_range(rng.clone()), target) - }, - (Some(Filter::Range(rng)), Some(page_validity)) => decode_masked_optional_dict( - values, - dict, - filter_from_range(rng.clone()), - page_validity, - target, - ), - }?; - - if cfg!(debug_assertions) && is_optional { - assert_eq!(target.len(), validity.len()); - } - - Ok(()) -} - -#[cold] -fn oob_dict_idx() -> ParquetError { - ParquetError::oos("Dictionary Index is out-of-bounds") -} - -#[inline(always)] -fn verify_dict_indices(indices: &[u32; 32], dict_size: usize) -> ParquetResult<()> { - let mut is_valid = true; - for &idx in indices { - is_valid &= (idx as usize) < dict_size; - } - - if is_valid { - return Ok(()); - } - - Err(oob_dict_idx()) -} - -#[inline(never)] -pub fn decode_required_dict( - mut values: HybridRleDecoder<'_>, - dict: &[B], - target: &mut Vec, -) -> ParquetResult<()> { - if dict.is_empty() && values.len() > 0 { - return Err(oob_dict_idx()); - } - - let start_length = target.len(); - let end_length = start_length + values.len(); - - target.reserve(values.len()); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - - while values.len() > 0 { - let chunk = values.next_chunk()?.unwrap(); - - match chunk { - HybridRleChunk::Rle(value, length) => { - if length == 0 { - continue; - } - - let target_slice; - // SAFETY: - // 1. `target_ptr..target_ptr + values.len()` is allocated - // 2. `length <= limit` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, length); - target_ptr = target_ptr.add(length); - } - - let Some(&value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; - - target_slice.fill(value); - }, - HybridRleChunk::Bitpacked(mut decoder) => { - let mut chunked = decoder.chunked(); - for chunk in chunked.by_ref() { - verify_dict_indices(&chunk, dict.len())?; - - for (i, &idx) in chunk.iter().enumerate() { - unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; - } - unsafe { - target_ptr = target_ptr.add(32); - } - } - - if let Some((chunk, chunk_size)) = chunked.remainder() { - let highest_idx = chunk[..chunk_size].iter().copied().max().unwrap(); - if highest_idx as usize >= dict.len() { - return Err(oob_dict_idx()); - } - - for (i, &idx) in chunk[..chunk_size].iter().enumerate() { - unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; - } - unsafe { - target_ptr = target_ptr.add(chunk_size); - } - } - }, - } - } - - unsafe { - target.set_len(end_length); - } - - Ok(()) -} - -#[inline(never)] -pub fn decode_optional_dict( - mut values: HybridRleDecoder<'_>, - dict: &[B], - validity: Bitmap, - target: &mut Vec, -) -> ParquetResult<()> { - let num_valid_values = validity.set_bits(); - - // Dispatch to the required kernel if all rows are valid anyway. - if num_valid_values == validity.len() { - values.limit_to(validity.len()); - return decode_required_dict(values, dict, target); - } - - if dict.is_empty() && num_valid_values > 0 { - return Err(oob_dict_idx()); - } - - assert!(num_valid_values <= values.len()); - let start_length = target.len(); - let end_length = start_length + validity.len(); - - target.reserve(validity.len()); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - - values.limit_to(num_valid_values); - let mut validity = BitMask::from_bitmap(&validity); - let mut values_buffer = [0u32; 128]; - let values_buffer = &mut values_buffer; - - for chunk in values.into_chunk_iter() { - match chunk? { - HybridRleChunk::Rle(value, size) => { - if size == 0 { - continue; - } - - // If we know that we have `size` times `value` that we can append, but there might - // be nulls in between those values. - // - // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is - // done with `num_bits_before_nth_one` on the validity mask. - // 2. Fill `num_rows` values into the target buffer. - // 3. Advance the validity mask by `num_rows` values. - - let num_chunk_rows = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); - - (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_rows) }; - - let Some(&value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; - - let target_slice; - // SAFETY: - // Given `validity_iter` before the `advance_by_bits` - // - // 1. `target_ptr..target_ptr + validity_iter.bits_left()` is allocated - // 2. `num_chunk_rows <= validity_iter.bits_left()` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); - target_ptr = target_ptr.add(num_chunk_rows); - } - - target_slice.fill(value); - }, - HybridRleChunk::Bitpacked(mut decoder) => { - let mut chunked = decoder.chunked(); - - let mut buffer_part_idx = 0; - let mut values_offset = 0; - let mut num_buffered: usize = 0; - - { - let mut num_done = 0; - let mut validity_iter = validity.fast_iter_u56(); - - 'outer: for v in validity_iter.by_ref() { - while num_buffered < v.count_ones() as usize { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let Some(num_added) = chunked.next_into(buffer_part) else { - break 'outer; - }; - - verify_dict_indices(buffer_part, dict.len())?; - - num_buffered += num_added; - - buffer_part_idx += 1; - buffer_part_idx %= 4; - } - - let mut num_read = 0; - - for i in 0..56 { - let idx = values_buffer[(values_offset + num_read) % 128]; - - // SAFETY: - // 1. `values_buffer` starts out as only zeros, which we know is in the - // dictionary following the original `dict.is_empty` check. - // 2. Each time we write to `values_buffer`, it is followed by a - // `verify_dict_indices`. - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { target_ptr.add(i).write(value) }; - num_read += ((v >> i) & 1) as usize; - } - - values_offset += num_read; - values_offset %= 128; - num_buffered -= num_read; - unsafe { - target_ptr = target_ptr.add(56); - } - num_done += 56; - } - - (_, validity) = unsafe { validity.split_at_unchecked(num_done) }; - } - - let num_decoder_remaining = num_buffered + chunked.decoder.len(); - let decoder_limit = validity - .nth_set_bit_idx(num_decoder_remaining, 0) - .unwrap_or(validity.len()); - - let current_validity; - (current_validity, validity) = - unsafe { validity.split_at_unchecked(decoder_limit) }; - let (v, _) = current_validity.fast_iter_u56().remainder(); - - while num_buffered < v.count_ones() as usize { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let num_added = chunked.next_into(buffer_part).unwrap(); - - verify_dict_indices(buffer_part, dict.len())?; - - num_buffered += num_added; - - buffer_part_idx += 1; - buffer_part_idx %= 4; - } - - let mut num_read = 0; - - for i in 0..decoder_limit { - let idx = values_buffer[(values_offset + num_read) % 128]; - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { *target_ptr.add(i) = value }; - num_read += ((v >> i) & 1) as usize; - } - - unsafe { - target_ptr = target_ptr.add(decoder_limit); - } - }, - } - } - - if cfg!(debug_assertions) { - assert_eq!(validity.set_bits(), 0); - } - - let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, validity.len()) }; - target_slice.fill(B::zeroed()); - unsafe { - target.set_len(end_length); - } - - Ok(()) -} - -#[inline(never)] -pub fn decode_masked_optional_dict( - mut values: HybridRleDecoder<'_>, - dict: &[B], - filter: Bitmap, - validity: Bitmap, - target: &mut Vec, -) -> ParquetResult<()> { - let num_rows = filter.set_bits(); - let num_valid_values = validity.set_bits(); - - // Dispatch to the non-filter kernel if all rows are needed anyway. - if num_rows == filter.len() { - return decode_optional_dict(values, dict, validity, target); - } - - // Dispatch to the required kernel if all rows are valid anyway. - if num_valid_values == validity.len() { - return decode_masked_required_dict(values, dict, filter, target); - } - - if dict.is_empty() && num_valid_values > 0 { - return Err(oob_dict_idx()); - } - - debug_assert_eq!(filter.len(), validity.len()); - assert!(num_valid_values <= values.len()); - let start_length = target.len(); - - target.reserve(num_rows); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - - let mut filter = BitMask::from_bitmap(&filter); - let mut validity = BitMask::from_bitmap(&validity); - - values.limit_to(num_valid_values); - let mut values_buffer = [0u32; 128]; - let values_buffer = &mut values_buffer; - - let mut num_rows_left = num_rows; - - for chunk in values.into_chunk_iter() { - // Early stop if we have no more rows to load. - if num_rows_left == 0 { - break; - } - - match chunk? { - HybridRleChunk::Rle(value, size) => { - if size == 0 { - continue; - } - - // If we know that we have `size` times `value` that we can append, but there might - // be nulls in between those values. - // - // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is - // done with `num_bits_before_nth_one` on the validity mask. - // 2. Fill `num_rows` values into the target buffer. - // 3. Advance the validity mask by `num_rows` values. - - let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); - - let current_filter; - (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_values) }; - (current_filter, filter) = unsafe { filter.split_at_unchecked(num_chunk_values) }; - - let num_chunk_rows = current_filter.set_bits(); - - if num_chunk_rows > 0 { - let target_slice; - // SAFETY: - // Given `filter_iter` before the `advance_by_bits`. - // - // 1. `target_ptr..target_ptr + filter_iter.count_ones()` is allocated - // 2. `num_chunk_rows < filter_iter.count_ones()` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); - target_ptr = target_ptr.add(num_chunk_rows); - } - - let Some(value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; - - target_slice.fill(*value); - num_rows_left -= num_chunk_rows; - } - }, - HybridRleChunk::Bitpacked(mut decoder) => { - // For bitpacked we do the following: - // 1. See how many rows are encoded by this `decoder`. - // 2. Go through the filter and validity 56 bits at a time and: - // 0. If filter bits are 0, skip the chunk entirely. - // 1. Buffer enough values so that we can branchlessly decode with the filter - // and validity. - // 2. Decode with filter and validity. - // 3. Decode remainder. - - let size = decoder.len(); - let mut chunked = decoder.chunked(); - - let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); - - let mut buffer_part_idx = 0; - let mut values_offset = 0; - let mut num_buffered: usize = 0; - let mut skip_values = 0; - - let current_filter; - let current_validity; - - (current_filter, filter) = unsafe { filter.split_at_unchecked(num_chunk_values) }; - (current_validity, validity) = - unsafe { validity.split_at_unchecked(num_chunk_values) }; - - let mut iter = |mut f: u64, mut v: u64| { - // Skip chunk if we don't any values from here. - if f == 0 { - skip_values += v.count_ones() as usize; - return ParquetResult::Ok(()); - } - - // Skip over already buffered items. - let num_buffered_skipped = skip_values.min(num_buffered); - values_offset += num_buffered_skipped; - num_buffered -= num_buffered_skipped; - skip_values -= num_buffered_skipped; - - // If we skipped plenty already, just skip decoding those chunks instead of - // decoding them and throwing them away. - chunked.decoder.skip_chunks(skip_values / 32); - // The leftovers we have to decode but we can also just skip. - skip_values %= 32; - - while num_buffered < v.count_ones() as usize { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let num_added = chunked.next_into(buffer_part).unwrap(); - - verify_dict_indices(buffer_part, dict.len())?; - - let skip_chunk_values = skip_values.min(num_added); - - values_offset += skip_chunk_values; - num_buffered += num_added - skip_chunk_values; - skip_values -= skip_chunk_values; - - buffer_part_idx += 1; - buffer_part_idx %= 4; - } - - let mut num_read = 0; - let mut num_written = 0; - - while f != 0 { - let offset = f.trailing_zeros(); - - num_read += (v & (1u64 << offset).wrapping_sub(1)).count_ones() as usize; - v >>= offset; - - let idx = values_buffer[(values_offset + num_read) % 128]; - // SAFETY: - // 1. `values_buffer` starts out as only zeros, which we know is in the - // dictionary following the original `dict.is_empty` check. - // 2. Each time we write to `values_buffer`, it is followed by a - // `verify_dict_indices`. - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { target_ptr.add(num_written).write(value) }; - - num_written += 1; - num_read += (v & 1) as usize; - - f >>= offset + 1; // Clear least significant bit. - v >>= 1; - } - - num_read += v.count_ones() as usize; - - values_offset += num_read; - values_offset %= 128; - num_buffered -= num_read; - unsafe { - target_ptr = target_ptr.add(num_written); - } - num_rows_left -= num_written; - - ParquetResult::Ok(()) - }; - - let mut f_iter = current_filter.fast_iter_u56(); - let mut v_iter = current_validity.fast_iter_u56(); - - for (f, v) in f_iter.by_ref().zip(v_iter.by_ref()) { - iter(f, v)?; - } - - let (f, fl) = f_iter.remainder(); - let (v, vl) = v_iter.remainder(); - - assert_eq!(fl, vl); - - iter(f, v)?; - }, - } - } - - if cfg!(debug_assertions) { - assert_eq!(validity.set_bits(), 0); - } - - let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, num_rows_left) }; - target_slice.fill(B::zeroed()); - unsafe { - target.set_len(start_length + num_rows); - } - - Ok(()) -} - -#[inline(never)] -pub fn decode_masked_required_dict( - mut values: HybridRleDecoder<'_>, - dict: &[B], - filter: Bitmap, - target: &mut Vec, -) -> ParquetResult<()> { - let num_rows = filter.set_bits(); - - // Dispatch to the non-filter kernel if all rows are needed anyway. - if num_rows == filter.len() { - values.limit_to(filter.len()); - return decode_required_dict(values, dict, target); - } - - if dict.is_empty() && !filter.is_empty() { - return Err(oob_dict_idx()); - } - - let start_length = target.len(); - - target.reserve(num_rows); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - - let mut filter = BitMask::from_bitmap(&filter); - - values.limit_to(filter.len()); - let mut values_buffer = [0u32; 128]; - let values_buffer = &mut values_buffer; - - let mut num_rows_left = num_rows; - - for chunk in values.into_chunk_iter() { - if num_rows_left == 0 { - break; - } - - match chunk? { - HybridRleChunk::Rle(value, size) => { - if size == 0 { - continue; - } - - let size = size.min(filter.len()); - - // If we know that we have `size` times `value` that we can append, but there might - // be nulls in between those values. - // - // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is - // done with `num_bits_before_nth_one` on the validity mask. - // 2. Fill `num_rows` values into the target buffer. - // 3. Advance the validity mask by `num_rows` values. - - let current_filter; - - (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; - let num_chunk_rows = current_filter.set_bits(); - - if num_chunk_rows > 0 { - let target_slice; - // SAFETY: - // Given `filter_iter` before the `advance_by_bits`. - // - // 1. `target_ptr..target_ptr + filter_iter.count_ones()` is allocated - // 2. `num_chunk_rows < filter_iter.count_ones()` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); - target_ptr = target_ptr.add(num_chunk_rows); - } - - let Some(value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; - - target_slice.fill(*value); - num_rows_left -= num_chunk_rows; - } - }, - HybridRleChunk::Bitpacked(mut decoder) => { - let size = decoder.len().min(filter.len()); - let mut chunked = decoder.chunked(); - - let mut buffer_part_idx = 0; - let mut values_offset = 0; - let mut num_buffered: usize = 0; - let mut skip_values = 0; - - let current_filter; - - (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; - - let mut iter = |mut f: u64, len: usize| { - debug_assert!(len <= 64); - - // Skip chunk if we don't any values from here. - if f == 0 { - skip_values += len; - return ParquetResult::Ok(()); - } - - // Skip over already buffered items. - let num_buffered_skipped = skip_values.min(num_buffered); - values_offset += num_buffered_skipped; - num_buffered -= num_buffered_skipped; - skip_values -= num_buffered_skipped; - - // If we skipped plenty already, just skip decoding those chunks instead of - // decoding them and throwing them away. - chunked.decoder.skip_chunks(skip_values / 32); - // The leftovers we have to decode but we can also just skip. - skip_values %= 32; - - while num_buffered < len { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let num_added = chunked.next_into(buffer_part).unwrap(); - - verify_dict_indices(buffer_part, dict.len())?; - - let skip_chunk_values = skip_values.min(num_added); - - values_offset += skip_chunk_values; - num_buffered += num_added - skip_chunk_values; - skip_values -= skip_chunk_values; - - buffer_part_idx += 1; - buffer_part_idx %= 4; - } - - let mut num_read = 0; - let mut num_written = 0; - - while f != 0 { - let offset = f.trailing_zeros() as usize; - - num_read += offset; - - let idx = values_buffer[(values_offset + num_read) % 128]; - // SAFETY: - // 1. `values_buffer` starts out as only zeros, which we know is in the - // dictionary following the original `dict.is_empty` check. - // 2. Each time we write to `values_buffer`, it is followed by a - // `verify_dict_indices`. - let value = *unsafe { dict.get_unchecked(idx as usize) }; - unsafe { target_ptr.add(num_written).write(value) }; - - num_written += 1; - num_read += 1; - - f >>= offset + 1; // Clear least significant bit. - } - - values_offset += len; - values_offset %= 128; - num_buffered -= len; - unsafe { - target_ptr = target_ptr.add(num_written); - } - num_rows_left -= num_written; - - ParquetResult::Ok(()) - }; - - let mut f_iter = current_filter.fast_iter_u56(); - - for f in f_iter.by_ref() { - iter(f, 56)?; - } - - let (f, fl) = f_iter.remainder(); - - iter(f, fl)?; - }, - } - } - - unsafe { - target.set_len(start_length + num_rows); - } - - Ok(()) -} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs index 960c11d75c12..191d132a52ae 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs @@ -1,5 +1,4 @@ pub(crate) mod array_chunks; -pub(crate) mod dict_encoded; pub(crate) mod filter; use std::ops::Range; diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs index 72e7b82cc4ad..827538d14dbc 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs @@ -43,6 +43,16 @@ impl<'a> Iterator for HybridRleChunkIter<'a> { } } +impl HybridRleChunk<'_> { + #[inline] + pub fn len(&self) -> usize { + match self { + HybridRleChunk::Rle(_, size) => *size, + HybridRleChunk::Bitpacked(decoder) => decoder.len(), + } + } +} + impl<'a> HybridRleDecoder<'a> { /// Returns a new [`HybridRleDecoder`] pub fn new(data: &'a [u8], num_bits: u32, num_values: usize) -> Self { From c52ae2c85c42c5ba53162804bacec6b247947774 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 5 Nov 2024 14:43:44 +0100 Subject: [PATCH 02/11] add skips to dictionary encoding --- .../deserialize/dictionary_encoded/mod.rs | 23 ++----- .../dictionary_encoded/optional.rs | 62 +++++++++++++------ .../optional_masked_dense.rs | 2 +- .../dictionary_encoded/required.rs | 44 ++++++++++--- .../required_masked_dense.rs | 2 +- crates/polars-python/src/dataframe/io.rs | 2 +- 6 files changed, 87 insertions(+), 48 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs index e24503a31826..1e8f6ec0c5c9 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs @@ -2,7 +2,6 @@ use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::types::{AlignedBytes, NativeType}; use polars_compute::filter::filter_boolean_kernel; -use super::utils::filter_from_range; use super::ParquetError; use crate::parquet::encoding::hybrid_rle::HybridRleDecoder; use crate::parquet::error::ParquetResult; @@ -54,14 +53,14 @@ pub fn decode_dict_dispatch( let page_validity = constrain_page_validity(values.len(), page_validity, filter.as_ref()); match (filter, page_validity) { - (None, None) => required::decode(values, dict, target), - (Some(Filter::Range(rng)), None) if rng.start == 0 => { + (None, None) => required::decode(values, dict, target, 0), + (Some(Filter::Range(rng)), None) => { values.limit_to(rng.end); - required::decode(values, dict, target) + required::decode(values, dict, target, rng.start) }, - (None, Some(page_validity)) => optional::decode(values, dict, page_validity, target), - (Some(Filter::Range(rng)), Some(page_validity)) if rng.start == 0 => { - optional::decode(values, dict, page_validity, target) + (None, Some(page_validity)) => optional::decode(values, dict, page_validity, target, 0), + (Some(Filter::Range(rng)), Some(page_validity)) => { + optional::decode(values, dict, page_validity, target, rng.start) }, (Some(Filter::Mask(filter)), None) => { required_masked_dense::decode(values, dict, filter, target) @@ -69,16 +68,6 @@ pub fn decode_dict_dispatch( (Some(Filter::Mask(filter)), Some(page_validity)) => { optional_masked_dense::decode(values, dict, filter, page_validity, target) }, - (Some(Filter::Range(rng)), None) => { - required_masked_dense::decode(values, dict, filter_from_range(rng.clone()), target) - }, - (Some(Filter::Range(rng)), Some(page_validity)) => optional_masked_dense::decode( - values, - dict, - filter_from_range(rng.clone()), - page_validity, - target, - ), }?; if cfg!(debug_assertions) && is_optional { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs index 23c4fd23f0db..a91a06e86b26 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -12,38 +12,48 @@ pub fn decode( dict: &[B], validity: Bitmap, target: &mut Vec, + mut num_skipped_rows: usize, ) -> ParquetResult<()> { - let num_valid_values = validity.set_bits(); - // Dispatch to the required kernel if all rows are valid anyway. - if num_valid_values == validity.len() { + if validity.set_bits() == validity.len() { values.limit_to(validity.len()); - return super::required::decode(values, dict, target); + return super::required::decode(values, dict, target, num_skipped_rows); } - - if dict.is_empty() && num_valid_values > 0 { + if dict.is_empty() && validity.set_bits() > 0 { return Err(oob_dict_idx()); } - assert!(num_valid_values <= values.len()); + let mut num_skipped_values = validity.clone().sliced(0, num_skipped_rows).set_bits(); + + assert!(validity.set_bits() <= values.len()); let start_length = target.len(); - let end_length = start_length + validity.len(); + let end_length = start_length + validity.len() - num_skipped_rows; - target.reserve(validity.len()); + target.reserve(validity.len() - num_skipped_rows); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - values.limit_to(num_valid_values); + values.limit_to(validity.set_bits() - num_skipped_values); let mut validity = BitMask::from_bitmap(&validity); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - for chunk in values.into_chunk_iter() { - match chunk? { - HybridRleChunk::Rle(value, size) => { - if size == 0 { - continue; - } + while let Some(chunk) = values.next_chunk()? { + let chunk_len = chunk.len(); + + if chunk_len <= num_skipped_values { + num_skipped_values -= chunk_len; + if chunk_len > 0 { + let offset = validity + .nth_set_bit_idx(chunk_len - 1, 0) + .unwrap_or(validity.len()); + num_skipped_rows -= offset; + validity = validity.sliced(offset, validity.len() - offset); + } + continue; + } + match chunk { + HybridRleChunk::Rle(value, size) => { // If we know that we have `size` times `value` that we can append, but there might // be nulls in between those values. // @@ -52,7 +62,9 @@ pub fn decode( // 2. Fill `num_rows` values into the target buffer. // 3. Advance the validity mask by `num_rows` values. - let num_chunk_rows = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + let num_chunk_rows = validity + .nth_set_bit_idx(size, num_skipped_rows) + .unwrap_or(validity.len()); (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_rows) }; @@ -74,6 +86,12 @@ pub fn decode( target_slice.fill(value); }, HybridRleChunk::Bitpacked(mut decoder) => { + if num_skipped_values > 0 { + validity = validity.sliced(num_skipped_rows, validity.len() - num_skipped_rows); + decoder.skip_chunks(num_skipped_values / 32); + num_skipped_values %= 32; + } + let mut chunked = decoder.chunked(); let mut buffer_part_idx = 0; @@ -85,7 +103,7 @@ pub fn decode( let mut validity_iter = validity.fast_iter_u56(); 'outer: for v in validity_iter.by_ref() { - while num_buffered < v.count_ones() as usize { + while num_buffered - num_skipped_values < v.count_ones() as usize { let buffer_part = <&mut [u32; 32]>::try_from( &mut values_buffer[buffer_part_idx * 32..][..32], ) @@ -96,7 +114,10 @@ pub fn decode( verify_dict_indices(buffer_part, dict.len())?; - num_buffered += num_added; + let num_added_skipped = num_added.min(num_skipped_values); + num_skipped_values -= num_added_skipped; + + num_buffered += num_added - num_added_skipped; buffer_part_idx += 1; buffer_part_idx %= 4; @@ -170,6 +191,9 @@ pub fn decode( } }, } + + num_skipped_rows = 0; + num_skipped_values = 0; } if cfg!(debug_assertions) { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs index 9ad9ee6ac22c..37364d58b6b6 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs @@ -19,7 +19,7 @@ pub fn decode( // Dispatch to the non-filter kernel if all rows are needed anyway. if num_rows == filter.len() { - return super::optional::decode(values, dict, validity, target); + return super::optional::decode(values, dict, validity, target, 0); } // Dispatch to the required kernel if all rows are valid anyway. diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs index a76b4dc370a6..4407a3fad852 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -9,32 +9,40 @@ pub fn decode( mut values: HybridRleDecoder<'_>, dict: &[B], target: &mut Vec, + mut num_skipped_rows: usize, ) -> ParquetResult<()> { + debug_assert!(num_skipped_rows <= values.len()); + + if num_skipped_rows == values.len() { + return Ok(()); + } if dict.is_empty() && values.len() > 0 { return Err(oob_dict_idx()); } let start_length = target.len(); - let end_length = start_length + values.len(); + let end_length = start_length + values.len() - num_skipped_rows; - target.reserve(values.len()); + target.reserve(values.len() - num_skipped_rows); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - while values.len() > 0 { - let chunk = values.next_chunk()?.unwrap(); + while let Some(chunk) = values.next_chunk()? { + let chunk_len = chunk.len(); + + if chunk_len <= num_skipped_rows { + num_skipped_rows -= chunk_len; + continue; + } match chunk { HybridRleChunk::Rle(value, length) => { - if length == 0 { - continue; - } - let target_slice; // SAFETY: // 1. `target_ptr..target_ptr + values.len()` is allocated // 2. `length <= limit` unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, length); + target_slice = + std::slice::from_raw_parts_mut(target_ptr, length - num_skipped_rows); target_ptr = target_ptr.add(length); } @@ -45,6 +53,22 @@ pub fn decode( target_slice.fill(value); }, HybridRleChunk::Bitpacked(mut decoder) => { + if num_skipped_rows > 0 { + decoder.skip_chunks(num_skipped_rows / 32); + num_skipped_rows %= 32; + if let Some((chunk, chunk_size)) = decoder.chunked().next_inexact() { + let chunk = &chunk[num_skipped_rows..chunk_size]; + verify_dict_indices_slice(chunk, dict.len())?; + + for (i, &idx) in chunk.iter().enumerate() { + unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; + } + unsafe { + target_ptr = target_ptr.add(chunk_size); + } + } + } + let mut chunked = decoder.chunked(); for chunk in chunked.by_ref() { verify_dict_indices(&chunk, dict.len())?; @@ -69,6 +93,8 @@ pub fn decode( } }, } + + num_skipped_rows = 0; } unsafe { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs index da4385023732..85a88f500cf5 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs @@ -18,7 +18,7 @@ pub fn decode( // Dispatch to the non-filter kernel if all rows are needed anyway. if num_rows == filter.len() { values.limit_to(filter.len()); - return super::required::decode(values, dict, target); + return super::required::decode(values, dict, target, 0); } if dict.is_empty() && !filter.is_empty() { diff --git a/crates/polars-python/src/dataframe/io.rs b/crates/polars-python/src/dataframe/io.rs index 9b34eb7e8ae9..d32a2c11ba8a 100644 --- a/crates/polars-python/src/dataframe/io.rs +++ b/crates/polars-python/src/dataframe/io.rs @@ -425,7 +425,7 @@ impl PyDataFrame { let buf = get_file_like(py_f, true)?; py.allow_threads(|| { - ParquetWriter::new(buf) + ParquetWriter::new(BufWriter::new(buf)) .with_compression(compression) .with_statistics(statistics.0) .with_row_group_size(row_group_size) From 6ac6b6e63e7e05e79f26e95e1f4017ddad5265ce Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 5 Nov 2024 14:54:08 +0100 Subject: [PATCH 03/11] trim leading and trailing nulls for dictionary encoding --- .../dictionary_encoded/optional.rs | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs index a91a06e86b26..d1fbfcbe37cc 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -10,14 +10,30 @@ use crate::parquet::error::ParquetResult; pub fn decode( mut values: HybridRleDecoder<'_>, dict: &[B], - validity: Bitmap, + mut validity: Bitmap, target: &mut Vec, mut num_skipped_rows: usize, ) -> ParquetResult<()> { + target.reserve(validity.len() - num_skipped_rows); + + // Remove any leading and trailing nulls. This has two benefits: + // 1. It increases the chance of dispatching to the faster kernel (e.g. for sorted data) + // 2. It reduces the amount of iterations in the main loop and replaces it with `memset`s + let leading_nulls = validity.take_leading_zeros(); + let trailing_nulls = validity.take_trailing_zeros(); + + target.resize( + target.len() + leading_nulls.saturating_sub(num_skipped_rows), + B::zeroed(), + ); + num_skipped_rows = num_skipped_rows.saturating_sub(leading_nulls); + // Dispatch to the required kernel if all rows are valid anyway. if validity.set_bits() == validity.len() { values.limit_to(validity.len()); - return super::required::decode(values, dict, target, num_skipped_rows); + super::required::decode(values, dict, target, num_skipped_rows)?; + target.resize(target.len() + trailing_nulls, B::zeroed()); + return Ok(()); } if dict.is_empty() && validity.set_bits() > 0 { return Err(oob_dict_idx()); @@ -29,7 +45,6 @@ pub fn decode( let start_length = target.len(); let end_length = start_length + validity.len() - num_skipped_rows; - target.reserve(validity.len() - num_skipped_rows); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; values.limit_to(validity.set_bits() - num_skipped_values); @@ -196,15 +211,10 @@ pub fn decode( num_skipped_values = 0; } - if cfg!(debug_assertions) { - assert_eq!(validity.set_bits(), 0); - } - - let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, validity.len()) }; - target_slice.fill(B::zeroed()); unsafe { target.set_len(end_length); } + target.resize(target.len() + trailing_nulls, B::zeroed()); Ok(()) } From c078d35993219fa0744509451294c2bf7c4428e3 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 5 Nov 2024 15:14:41 +0100 Subject: [PATCH 04/11] trim leading and trailing filtered elements required_masked --- .../required_masked_dense.rs | 45 ++++++++--------- .../src/parquet/encoding/hybrid_rle/mod.rs | 48 +++++++++++++++++++ 2 files changed, 71 insertions(+), 22 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs index 85a88f500cf5..5934014f2349 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs @@ -10,15 +10,19 @@ use crate::parquet::error::ParquetResult; pub fn decode( mut values: HybridRleDecoder<'_>, dict: &[B], - filter: Bitmap, + mut filter: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { + let leading_zeros = filter.take_leading_zeros(); + filter.take_trailing_zeros(); + + values.limit_to(leading_zeros + filter.len()); + let num_rows = filter.set_bits(); // Dispatch to the non-filter kernel if all rows are needed anyway. if num_rows == filter.len() { - values.limit_to(filter.len()); - return super::required::decode(values, dict, target, 0); + return super::required::decode(values, dict, target, leading_zeros); } if dict.is_empty() && !filter.is_empty() { @@ -32,25 +36,20 @@ pub fn decode( let mut filter = BitMask::from_bitmap(&filter); - values.limit_to(filter.len()); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - let mut num_rows_left = num_rows; + let mut num_skipped_rows = leading_zeros; - for chunk in values.into_chunk_iter() { - if num_rows_left == 0 { - break; + while let Some(chunk) = values.next_chunk()? { + let chunk_len = chunk.len(); + if chunk_len <= num_skipped_rows { + num_skipped_rows -= chunk_len; + continue; } - match chunk? { + match chunk { HybridRleChunk::Rle(value, size) => { - if size == 0 { - continue; - } - - let size = size.min(filter.len()); - // If we know that we have `size` times `value` that we can append, but there might // be nulls in between those values. // @@ -62,7 +61,9 @@ pub fn decode( let current_filter; (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; - let num_chunk_rows = current_filter.set_bits(); + let num_chunk_rows = current_filter + .sliced(num_skipped_rows, current_filter.len() - num_skipped_rows) + .set_bits(); if num_chunk_rows > 0 { let target_slice; @@ -81,17 +82,16 @@ pub fn decode( }; target_slice.fill(*value); - num_rows_left -= num_chunk_rows; } }, HybridRleChunk::Bitpacked(mut decoder) => { - let size = decoder.len().min(filter.len()); - let mut chunked = decoder.chunked(); - let mut buffer_part_idx = 0; let mut values_offset = 0; let mut num_buffered: usize = 0; - let mut skip_values = 0; + let mut skip_values = num_skipped_rows; + + let size = decoder.len(); + let mut chunked = decoder.chunked(); let current_filter; @@ -166,7 +166,6 @@ pub fn decode( unsafe { target_ptr = target_ptr.add(num_written); } - num_rows_left -= num_written; ParquetResult::Ok(()) }; @@ -182,6 +181,8 @@ pub fn decode( iter(f, fl)?; }, } + + num_skipped_rows = 0; } unsafe { diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs index 827538d14dbc..b77b32a05ca1 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs @@ -132,6 +132,54 @@ impl<'a> HybridRleDecoder<'a> { })) } + pub fn next_chunk_length(&mut self) -> ParquetResult> { + if self.len() == 0 { + return Ok(None); + } + + if self.num_bits == 0 { + let num_values = self.num_values; + self.num_values = 0; + return Ok(Some(num_values)); + } + + if self.data.is_empty() { + return Ok(None); + } + + let (indicator, consumed) = uleb128::decode(self.data); + self.data = unsafe { self.data.get_unchecked(consumed..) }; + + Ok(Some(if indicator & 1 == 1 { + // is bitpacking + let bytes = (indicator as usize >> 1) * self.num_bits; + let bytes = std::cmp::min(bytes, self.data.len()); + let Some((packed, remaining)) = self.data.split_at_checked(bytes) else { + return Err(ParquetError::oos("Not enough bytes for bitpacked data")); + }; + self.data = remaining; + + let length = std::cmp::min(packed.len() * 8 / self.num_bits, self.num_values); + self.num_values -= length; + + length + } else { + // is rle + let run_length = indicator as usize >> 1; + // repeated-value := value that is repeated, using a fixed-width of round-up-to-next-byte(bit-width) + let rle_bytes = self.num_bits.div_ceil(8); + let Some(remaining) = self.data.get(rle_bytes..) else { + return Err(ParquetError::oos("Not enough bytes for RLE encoded data")); + }; + self.data = remaining; + + let length = std::cmp::min(run_length, self.num_values); + self.num_values -= length; + + length + })) + } + pub fn limit_to(&mut self, length: usize) { self.num_values = self.num_values.min(length); } From c9251c4ac96653283607d07b1f1df887697d4a4b Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 5 Nov 2024 15:28:25 +0100 Subject: [PATCH 05/11] divide skip into two loops --- .../dictionary_encoded/optional.rs | 39 +++++++++++++------ .../dictionary_encoded/required.rs | 23 ++++++++--- .../required_masked_dense.rs | 21 ++++++++-- 3 files changed, 63 insertions(+), 20 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs index d1fbfcbe37cc..82a388a89d29 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -52,20 +52,37 @@ pub fn decode( let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - while let Some(chunk) = values.next_chunk()? { - let chunk_len = chunk.len(); + // Skip over any whole HybridRleChunks + if num_skipped_values > 0 { + let mut total_num_skipped_values = 0; + + loop { + let mut values_clone = values.clone(); + let Some(chunk_len) = values_clone.next_chunk_length()? else { + break; + }; + + if chunk_len < num_skipped_values { + break; + } - if chunk_len <= num_skipped_values { + values = values_clone; num_skipped_values -= chunk_len; - if chunk_len > 0 { - let offset = validity - .nth_set_bit_idx(chunk_len - 1, 0) - .unwrap_or(validity.len()); - num_skipped_rows -= offset; - validity = validity.sliced(offset, validity.len() - offset); - } - continue; + total_num_skipped_values += chunk_len; + } + + if total_num_skipped_values > 0 { + let offset = validity + .nth_set_bit_idx(total_num_skipped_values - 1, 0) + .unwrap_or(validity.len()); + num_skipped_rows -= offset; + validity = validity.sliced(offset, validity.len() - offset); } + } + + + while let Some(chunk) = values.next_chunk()? { + debug_assert!(chunk.len() < num_skipped_rows); match chunk { HybridRleChunk::Rle(value, size) => { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs index 4407a3fad852..121d934231ed 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -26,13 +26,26 @@ pub fn decode( target.reserve(values.len() - num_skipped_rows); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - while let Some(chunk) = values.next_chunk()? { - let chunk_len = chunk.len(); - - if chunk_len <= num_skipped_rows { + // Skip over any whole HybridRleChunks if possible + if num_skipped_rows > 0 { + loop { + let mut values_clone = values.clone(); + let Some(chunk_len) = values_clone.next_chunk_length()? else { + break; + }; + + if chunk_len < num_skipped_rows { + break; + } + + values = values_clone; num_skipped_rows -= chunk_len; - continue; } + } + + + while let Some(chunk) = values.next_chunk()? { + debug_assert!(chunk.len() < num_skipped_rows); match chunk { HybridRleChunk::Rle(value, length) => { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs index 5934014f2349..2543b86a6ba8 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs @@ -41,12 +41,25 @@ pub fn decode( let mut num_skipped_rows = leading_zeros; - while let Some(chunk) = values.next_chunk()? { - let chunk_len = chunk.len(); - if chunk_len <= num_skipped_rows { + // Skip over any whole HybridRleChunks + if num_skipped_rows > 0 { + loop { + let mut values_clone = values.clone(); + let Some(chunk_len) = values_clone.next_chunk_length()? else { + break; + }; + + if chunk_len < num_skipped_rows { + break; + } + + values = values_clone; num_skipped_rows -= chunk_len; - continue; } + } + + while let Some(chunk) = values.next_chunk()? { + debug_assert!(chunk.len() < num_skipped_rows); match chunk { HybridRleChunk::Rle(value, size) => { From 9d5e658e5f5929889b54a866190c86e017d3d672 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 5 Nov 2024 15:30:38 +0100 Subject: [PATCH 06/11] num_skipped_rows -> num_rows_to_skip --- .../dictionary_encoded/optional.rs | 46 +++++++++---------- .../dictionary_encoded/required.rs | 30 ++++++------ .../required_masked_dense.rs | 16 +++---- 3 files changed, 46 insertions(+), 46 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs index 82a388a89d29..04b53105f975 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -12,9 +12,9 @@ pub fn decode( dict: &[B], mut validity: Bitmap, target: &mut Vec, - mut num_skipped_rows: usize, + mut num_rows_to_skip: usize, ) -> ParquetResult<()> { - target.reserve(validity.len() - num_skipped_rows); + target.reserve(validity.len() - num_rows_to_skip); // Remove any leading and trailing nulls. This has two benefits: // 1. It increases the chance of dispatching to the faster kernel (e.g. for sorted data) @@ -23,15 +23,15 @@ pub fn decode( let trailing_nulls = validity.take_trailing_zeros(); target.resize( - target.len() + leading_nulls.saturating_sub(num_skipped_rows), + target.len() + leading_nulls.saturating_sub(num_rows_to_skip), B::zeroed(), ); - num_skipped_rows = num_skipped_rows.saturating_sub(leading_nulls); + num_rows_to_skip = num_rows_to_skip.saturating_sub(leading_nulls); // Dispatch to the required kernel if all rows are valid anyway. if validity.set_bits() == validity.len() { values.limit_to(validity.len()); - super::required::decode(values, dict, target, num_skipped_rows)?; + super::required::decode(values, dict, target, num_rows_to_skip)?; target.resize(target.len() + trailing_nulls, B::zeroed()); return Ok(()); } @@ -39,21 +39,21 @@ pub fn decode( return Err(oob_dict_idx()); } - let mut num_skipped_values = validity.clone().sliced(0, num_skipped_rows).set_bits(); + let mut num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits(); assert!(validity.set_bits() <= values.len()); let start_length = target.len(); - let end_length = start_length + validity.len() - num_skipped_rows; + let end_length = start_length + validity.len() - num_rows_to_skip; let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - values.limit_to(validity.set_bits() - num_skipped_values); + values.limit_to(validity.set_bits() - num_values_to_skip); let mut validity = BitMask::from_bitmap(&validity); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; // Skip over any whole HybridRleChunks - if num_skipped_values > 0 { + if num_values_to_skip > 0 { let mut total_num_skipped_values = 0; loop { @@ -62,12 +62,12 @@ pub fn decode( break; }; - if chunk_len < num_skipped_values { + if chunk_len < num_values_to_skip { break; } values = values_clone; - num_skipped_values -= chunk_len; + num_values_to_skip -= chunk_len; total_num_skipped_values += chunk_len; } @@ -75,14 +75,14 @@ pub fn decode( let offset = validity .nth_set_bit_idx(total_num_skipped_values - 1, 0) .unwrap_or(validity.len()); - num_skipped_rows -= offset; + num_rows_to_skip -= offset; validity = validity.sliced(offset, validity.len() - offset); } } while let Some(chunk) = values.next_chunk()? { - debug_assert!(chunk.len() < num_skipped_rows); + debug_assert!(chunk.len() < num_rows_to_skip); match chunk { HybridRleChunk::Rle(value, size) => { @@ -95,7 +95,7 @@ pub fn decode( // 3. Advance the validity mask by `num_rows` values. let num_chunk_rows = validity - .nth_set_bit_idx(size, num_skipped_rows) + .nth_set_bit_idx(size, num_rows_to_skip) .unwrap_or(validity.len()); (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_rows) }; @@ -118,10 +118,10 @@ pub fn decode( target_slice.fill(value); }, HybridRleChunk::Bitpacked(mut decoder) => { - if num_skipped_values > 0 { - validity = validity.sliced(num_skipped_rows, validity.len() - num_skipped_rows); - decoder.skip_chunks(num_skipped_values / 32); - num_skipped_values %= 32; + if num_values_to_skip > 0 { + validity = validity.sliced(num_rows_to_skip, validity.len() - num_rows_to_skip); + decoder.skip_chunks(num_values_to_skip / 32); + num_values_to_skip %= 32; } let mut chunked = decoder.chunked(); @@ -135,7 +135,7 @@ pub fn decode( let mut validity_iter = validity.fast_iter_u56(); 'outer: for v in validity_iter.by_ref() { - while num_buffered - num_skipped_values < v.count_ones() as usize { + while num_buffered - num_values_to_skip < v.count_ones() as usize { let buffer_part = <&mut [u32; 32]>::try_from( &mut values_buffer[buffer_part_idx * 32..][..32], ) @@ -146,8 +146,8 @@ pub fn decode( verify_dict_indices(buffer_part, dict.len())?; - let num_added_skipped = num_added.min(num_skipped_values); - num_skipped_values -= num_added_skipped; + let num_added_skipped = num_added.min(num_values_to_skip); + num_values_to_skip -= num_added_skipped; num_buffered += num_added - num_added_skipped; @@ -224,8 +224,8 @@ pub fn decode( }, } - num_skipped_rows = 0; - num_skipped_values = 0; + num_rows_to_skip = 0; + num_values_to_skip = 0; } unsafe { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs index 121d934231ed..06b20b6ca775 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -9,11 +9,11 @@ pub fn decode( mut values: HybridRleDecoder<'_>, dict: &[B], target: &mut Vec, - mut num_skipped_rows: usize, + mut num_rows_to_skip: usize, ) -> ParquetResult<()> { - debug_assert!(num_skipped_rows <= values.len()); + debug_assert!(num_rows_to_skip <= values.len()); - if num_skipped_rows == values.len() { + if num_rows_to_skip == values.len() { return Ok(()); } if dict.is_empty() && values.len() > 0 { @@ -21,31 +21,31 @@ pub fn decode( } let start_length = target.len(); - let end_length = start_length + values.len() - num_skipped_rows; + let end_length = start_length + values.len() - num_rows_to_skip; - target.reserve(values.len() - num_skipped_rows); + target.reserve(values.len() - num_rows_to_skip); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; // Skip over any whole HybridRleChunks if possible - if num_skipped_rows > 0 { + if num_rows_to_skip > 0 { loop { let mut values_clone = values.clone(); let Some(chunk_len) = values_clone.next_chunk_length()? else { break; }; - if chunk_len < num_skipped_rows { + if chunk_len < num_rows_to_skip { break; } values = values_clone; - num_skipped_rows -= chunk_len; + num_rows_to_skip -= chunk_len; } } while let Some(chunk) = values.next_chunk()? { - debug_assert!(chunk.len() < num_skipped_rows); + debug_assert!(chunk.len() < num_rows_to_skip); match chunk { HybridRleChunk::Rle(value, length) => { @@ -55,7 +55,7 @@ pub fn decode( // 2. `length <= limit` unsafe { target_slice = - std::slice::from_raw_parts_mut(target_ptr, length - num_skipped_rows); + std::slice::from_raw_parts_mut(target_ptr, length - num_rows_to_skip); target_ptr = target_ptr.add(length); } @@ -66,11 +66,11 @@ pub fn decode( target_slice.fill(value); }, HybridRleChunk::Bitpacked(mut decoder) => { - if num_skipped_rows > 0 { - decoder.skip_chunks(num_skipped_rows / 32); - num_skipped_rows %= 32; + if num_rows_to_skip > 0 { + decoder.skip_chunks(num_rows_to_skip / 32); + num_rows_to_skip %= 32; if let Some((chunk, chunk_size)) = decoder.chunked().next_inexact() { - let chunk = &chunk[num_skipped_rows..chunk_size]; + let chunk = &chunk[num_rows_to_skip..chunk_size]; verify_dict_indices_slice(chunk, dict.len())?; for (i, &idx) in chunk.iter().enumerate() { @@ -107,7 +107,7 @@ pub fn decode( }, } - num_skipped_rows = 0; + num_rows_to_skip = 0; } unsafe { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs index 2543b86a6ba8..56cdd20b1c3f 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs @@ -39,27 +39,27 @@ pub fn decode( let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - let mut num_skipped_rows = leading_zeros; + let mut num_rows_to_skip = leading_zeros; // Skip over any whole HybridRleChunks - if num_skipped_rows > 0 { + if num_rows_to_skip > 0 { loop { let mut values_clone = values.clone(); let Some(chunk_len) = values_clone.next_chunk_length()? else { break; }; - if chunk_len < num_skipped_rows { + if chunk_len < num_rows_to_skip { break; } values = values_clone; - num_skipped_rows -= chunk_len; + num_rows_to_skip -= chunk_len; } } while let Some(chunk) = values.next_chunk()? { - debug_assert!(chunk.len() < num_skipped_rows); + debug_assert!(chunk.len() < num_rows_to_skip); match chunk { HybridRleChunk::Rle(value, size) => { @@ -75,7 +75,7 @@ pub fn decode( (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; let num_chunk_rows = current_filter - .sliced(num_skipped_rows, current_filter.len() - num_skipped_rows) + .sliced(num_rows_to_skip, current_filter.len() - num_rows_to_skip) .set_bits(); if num_chunk_rows > 0 { @@ -101,7 +101,7 @@ pub fn decode( let mut buffer_part_idx = 0; let mut values_offset = 0; let mut num_buffered: usize = 0; - let mut skip_values = num_skipped_rows; + let mut skip_values = num_rows_to_skip; let size = decoder.len(); let mut chunked = decoder.chunked(); @@ -195,7 +195,7 @@ pub fn decode( }, } - num_skipped_rows = 0; + num_rows_to_skip = 0; } unsafe { From d582b0e82cec9fb973c482135f262e9008e7abab Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 5 Nov 2024 15:46:09 +0100 Subject: [PATCH 07/11] lowering in optional_masked --- .../deserialize/dictionary_encoded/mod.rs | 2 +- .../dictionary_encoded/optional.rs | 4 +- .../optional_masked_dense.rs | 86 ++++++++++++++----- .../dictionary_encoded/required.rs | 4 +- .../required_masked_dense.rs | 7 +- 5 files changed, 74 insertions(+), 29 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs index 1e8f6ec0c5c9..8ae6680e088b 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs @@ -63,7 +63,7 @@ pub fn decode_dict_dispatch( optional::decode(values, dict, page_validity, target, rng.start) }, (Some(Filter::Mask(filter)), None) => { - required_masked_dense::decode(values, dict, filter, target) + required_masked_dense::decode(values, dict, filter, target, 0) }, (Some(Filter::Mask(filter)), Some(page_validity)) => { optional_masked_dense::decode(values, dict, filter, page_validity, target) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs index 04b53105f975..5ef48f621e62 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -62,7 +62,7 @@ pub fn decode( break; }; - if chunk_len < num_values_to_skip { + if num_values_to_skip < chunk_len { break; } @@ -82,7 +82,7 @@ pub fn decode( while let Some(chunk) = values.next_chunk()? { - debug_assert!(chunk.len() < num_rows_to_skip); + debug_assert!(num_rows_to_skip < chunk.len()); match chunk { HybridRleChunk::Rle(value, size) => { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs index 37364d58b6b6..a55aa262b482 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs @@ -1,3 +1,4 @@ +use arrow::array::Splitable; use arrow::bitmap::bitmask::BitMask; use arrow::bitmap::Bitmap; use arrow::types::AlignedBytes; @@ -10,29 +11,43 @@ use crate::parquet::error::ParquetResult; pub fn decode( mut values: HybridRleDecoder<'_>, dict: &[B], - filter: Bitmap, - validity: Bitmap, + mut filter: Bitmap, + mut validity: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { + let leading_filtered = filter.take_leading_zeros(); + filter.take_trailing_zeros(); + let num_rows = filter.set_bits(); - let num_valid_values = validity.set_bits(); + + let leading_validity; + (leading_validity, validity) = validity.split_at(leading_filtered); + + let mut num_rows_to_skip = leading_filtered; + let mut num_values_to_skip = leading_validity.set_bits(); + + validity = validity.sliced(0, filter.len()); // Dispatch to the non-filter kernel if all rows are needed anyway. if num_rows == filter.len() { - return super::optional::decode(values, dict, validity, target, 0); + return super::optional::decode(values, dict, validity, target, num_rows_to_skip); } - // Dispatch to the required kernel if all rows are valid anyway. - if num_valid_values == validity.len() { - return super::required_masked_dense::decode(values, dict, filter, target); + if validity.set_bits() == validity.len() { + return super::required_masked_dense::decode( + values, + dict, + filter, + target, + num_values_to_skip, + ); } - - if dict.is_empty() && num_valid_values > 0 { + if dict.is_empty() && validity.set_bits() > 0 { return Err(oob_dict_idx()); } debug_assert_eq!(filter.len(), validity.len()); - assert!(num_valid_values <= values.len()); + assert!(validity.set_bits() <= values.len()); let start_length = target.len(); target.reserve(num_rows); @@ -41,19 +56,40 @@ pub fn decode( let mut filter = BitMask::from_bitmap(&filter); let mut validity = BitMask::from_bitmap(&validity); - values.limit_to(num_valid_values); + values.limit_to(num_values_to_skip + validity.set_bits()); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - let mut num_rows_left = num_rows; + // Skip over any whole HybridRleChunks + if num_values_to_skip > 0 { + let mut total_num_skipped_values = 0; + + loop { + let mut values_clone = values.clone(); + let Some(chunk_len) = values_clone.next_chunk_length()? else { + break; + }; - for chunk in values.into_chunk_iter() { - // Early stop if we have no more rows to load. - if num_rows_left == 0 { - break; + if num_values_to_skip < chunk_len { + break; + } + + values = values_clone; + num_values_to_skip -= chunk_len; + total_num_skipped_values += chunk_len; + } + + if total_num_skipped_values > 0 { + let offset = validity + .nth_set_bit_idx(total_num_skipped_values - 1, 0) + .unwrap_or(validity.len()); + num_rows_to_skip -= offset; + validity = validity.sliced(offset, validity.len() - offset); } + } - match chunk? { + while let Some(chunk) = values.next_chunk()? { + match chunk { HybridRleChunk::Rle(value, size) => { if size == 0 { continue; @@ -67,7 +103,9 @@ pub fn decode( // 2. Fill `num_rows` values into the target buffer. // 3. Advance the validity mask by `num_rows` values. - let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + let num_chunk_values = validity + .nth_set_bit_idx(size, num_rows_to_skip) + .unwrap_or(validity.len()); let current_filter; (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_values) }; @@ -92,10 +130,15 @@ pub fn decode( }; target_slice.fill(*value); - num_rows_left -= num_chunk_rows; } }, HybridRleChunk::Bitpacked(mut decoder) => { + if num_values_to_skip > 0 { + validity = validity.sliced(num_rows_to_skip, validity.len() - num_rows_to_skip); + decoder.skip_chunks(num_values_to_skip / 32); + num_values_to_skip %= 32; + } + // For bitpacked we do the following: // 1. See how many rows are encoded by this `decoder`. // 2. Go through the filter and validity 56 bits at a time and: @@ -194,7 +237,6 @@ pub fn decode( unsafe { target_ptr = target_ptr.add(num_written); } - num_rows_left -= num_written; ParquetResult::Ok(()) }; @@ -214,13 +256,15 @@ pub fn decode( iter(f, v)?; }, } + + num_rows_to_skip = 0; } if cfg!(debug_assertions) { assert_eq!(validity.set_bits(), 0); } - let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, num_rows_left) }; + let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, validity.len()) }; target_slice.fill(B::zeroed()); unsafe { target.set_len(start_length + num_rows); diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs index 06b20b6ca775..a7c5a516c438 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -34,7 +34,7 @@ pub fn decode( break; }; - if chunk_len < num_rows_to_skip { + if num_rows_to_skip < chunk_len { break; } @@ -45,7 +45,7 @@ pub fn decode( while let Some(chunk) = values.next_chunk()? { - debug_assert!(chunk.len() < num_rows_to_skip); + debug_assert!(num_rows_to_skip < chunk.len()); match chunk { HybridRleChunk::Rle(value, length) => { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs index 56cdd20b1c3f..9165928c3963 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs @@ -12,6 +12,7 @@ pub fn decode( dict: &[B], mut filter: Bitmap, target: &mut Vec, + mut num_rows_to_skip: usize, ) -> ParquetResult<()> { let leading_zeros = filter.take_leading_zeros(); filter.take_trailing_zeros(); @@ -39,7 +40,7 @@ pub fn decode( let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - let mut num_rows_to_skip = leading_zeros; + num_rows_to_skip += leading_zeros; // Skip over any whole HybridRleChunks if num_rows_to_skip > 0 { @@ -49,7 +50,7 @@ pub fn decode( break; }; - if chunk_len < num_rows_to_skip { + if num_rows_to_skip < chunk_len { break; } @@ -59,7 +60,7 @@ pub fn decode( } while let Some(chunk) = values.next_chunk()? { - debug_assert!(chunk.len() < num_rows_to_skip); + debug_assert!(num_rows_to_skip < chunk.len()); match chunk { HybridRleChunk::Rle(value, size) => { From 1b0e9c99e452637d8d56d2f139f0313443b20608 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 5 Nov 2024 17:13:44 +0100 Subject: [PATCH 08/11] buggy but full version --- .../dictionary_encoded/optional.rs | 167 ++++++++---------- .../optional_masked_dense.rs | 15 +- .../dictionary_encoded/required.rs | 2 +- .../required_masked_dense.rs | 6 +- 4 files changed, 87 insertions(+), 103 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs index 5ef48f621e62..a5459d9a62b3 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -14,6 +14,8 @@ pub fn decode( target: &mut Vec, mut num_rows_to_skip: usize, ) -> ParquetResult<()> { + debug_assert!(num_rows_to_skip <= validity.len()); + target.reserve(validity.len() - num_rows_to_skip); // Remove any leading and trailing nulls. This has two benefits: @@ -22,17 +24,23 @@ pub fn decode( let leading_nulls = validity.take_leading_zeros(); let trailing_nulls = validity.take_trailing_zeros(); + let skipped_leading_nulls = leading_nulls.min(num_rows_to_skip); target.resize( - target.len() + leading_nulls.saturating_sub(num_rows_to_skip), + target.len() + leading_nulls - skipped_leading_nulls, B::zeroed(), ); - num_rows_to_skip = num_rows_to_skip.saturating_sub(leading_nulls); + num_rows_to_skip -= skipped_leading_nulls; // Dispatch to the required kernel if all rows are valid anyway. if validity.set_bits() == validity.len() { values.limit_to(validity.len()); - super::required::decode(values, dict, target, num_rows_to_skip)?; - target.resize(target.len() + trailing_nulls, B::zeroed()); + let skipped_values = values.len().min(num_rows_to_skip); + super::required::decode(values, dict, target, skipped_values)?; + num_rows_to_skip -= skipped_values; + target.resize( + target.len() + trailing_nulls - num_rows_to_skip, + B::zeroed(), + ); return Ok(()); } if dict.is_empty() && validity.set_bits() > 0 { @@ -42,12 +50,12 @@ pub fn decode( let mut num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits(); assert!(validity.set_bits() <= values.len()); + let start_length = target.len(); let end_length = start_length + validity.len() - num_rows_to_skip; - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - values.limit_to(validity.set_bits() - num_values_to_skip); + values.limit_to(validity.set_bits()); let mut validity = BitMask::from_bitmap(&validity); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; @@ -80,9 +88,8 @@ pub fn decode( } } - while let Some(chunk) = values.next_chunk()? { - debug_assert!(num_rows_to_skip < chunk.len()); + debug_assert!(num_values_to_skip < chunk.len()); match chunk { HybridRleChunk::Rle(value, size) => { @@ -94,9 +101,8 @@ pub fn decode( // 2. Fill `num_rows` values into the target buffer. // 3. Advance the validity mask by `num_rows` values. - let num_chunk_rows = validity - .nth_set_bit_idx(size, num_rows_to_skip) - .unwrap_or(validity.len()); + let num_chunk_rows = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + let num_chunk_rows = num_chunk_rows - num_rows_to_skip; (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_rows) }; @@ -118,109 +124,83 @@ pub fn decode( target_slice.fill(value); }, HybridRleChunk::Bitpacked(mut decoder) => { - if num_values_to_skip > 0 { - validity = validity.sliced(num_rows_to_skip, validity.len() - num_rows_to_skip); + if num_values_to_skip >= 32 { decoder.skip_chunks(num_values_to_skip / 32); + let num_rows_skipped = validity + .nth_set_bit_idx(num_values_to_skip - num_values_to_skip % 32 - 1, 0) + .map_or(validity.len(), |v| v + 1); + validity = validity.sliced(num_rows_skipped, validity.len() - num_rows_skipped); num_values_to_skip %= 32; } + let num_rows_for_decoder = validity + .nth_set_bit_idx(decoder.len(), 0) + .unwrap_or(validity.len()); + let decoder_validity; + (decoder_validity, validity) = validity.split_at(num_rows_for_decoder); + let mut chunked = decoder.chunked(); let mut buffer_part_idx = 0; let mut values_offset = 0; let mut num_buffered: usize = 0; - { - let mut num_done = 0; - let mut validity_iter = validity.fast_iter_u56(); + let mut iter = |v: u64, n: usize| { + while num_buffered < v.count_ones() as usize + num_values_to_skip { + buffer_part_idx %= 4; + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let Some(num_added) = chunked.next_into(buffer_part) else { + return Ok(()); + }; - 'outer: for v in validity_iter.by_ref() { - while num_buffered - num_values_to_skip < v.count_ones() as usize { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let Some(num_added) = chunked.next_into(buffer_part) else { - break 'outer; - }; + verify_dict_indices(buffer_part, dict.len())?; - verify_dict_indices(buffer_part, dict.len())?; + let num_added_skipped = num_added.min(num_values_to_skip); + num_values_to_skip -= num_added_skipped; - let num_added_skipped = num_added.min(num_values_to_skip); - num_values_to_skip -= num_added_skipped; - - num_buffered += num_added - num_added_skipped; - - buffer_part_idx += 1; - buffer_part_idx %= 4; - } - - let mut num_read = 0; - - for i in 0..56 { - let idx = values_buffer[(values_offset + num_read) % 128]; - - // SAFETY: - // 1. `values_buffer` starts out as only zeros, which we know is in the - // dictionary following the original `dict.is_empty` check. - // 2. Each time we write to `values_buffer`, it is followed by a - // `verify_dict_indices`. - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { target_ptr.add(i).write(value) }; - num_read += ((v >> i) & 1) as usize; - } - - values_offset += num_read; + values_offset += num_added_skipped; values_offset %= 128; - num_buffered -= num_read; - unsafe { - target_ptr = target_ptr.add(56); - } - num_done += 56; - } + num_buffered += num_added - num_added_skipped; - (_, validity) = unsafe { validity.split_at_unchecked(num_done) }; - } - - let num_decoder_remaining = num_buffered + chunked.decoder.len(); - let decoder_limit = validity - .nth_set_bit_idx(num_decoder_remaining, 0) - .unwrap_or(validity.len()); - - let current_validity; - (current_validity, validity) = - unsafe { validity.split_at_unchecked(decoder_limit) }; - let (v, _) = current_validity.fast_iter_u56().remainder(); + buffer_part_idx += 1; + } - while num_buffered < v.count_ones() as usize { - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let num_added = chunked.next_into(buffer_part).unwrap(); + let mut num_read = 0; - verify_dict_indices(buffer_part, dict.len())?; + for i in 0..n { + let idx = values_buffer[(values_offset + num_read) % 128]; - num_buffered += num_added; + // SAFETY: + // 1. `values_buffer` starts out as only zeros, which we know is in the + // dictionary following the original `dict.is_empty` check. + // 2. Each time we write to `values_buffer`, it is followed by a + // `verify_dict_indices`. + let value = unsafe { dict.get_unchecked(idx as usize) }; + let value = *value; + unsafe { target_ptr.add(i).write(value) }; + num_read += ((v >> i) & 1) as usize; + } - buffer_part_idx += 1; - buffer_part_idx %= 4; - } + values_offset += num_read; + values_offset %= 128; + num_buffered -= num_read; + unsafe { + target_ptr = target_ptr.add(n); + } - let mut num_read = 0; + ParquetResult::Ok(()) + }; - for i in 0..decoder_limit { - let idx = values_buffer[(values_offset + num_read) % 128]; - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { *target_ptr.add(i) = value }; - num_read += ((v >> i) & 1) as usize; + let mut validity_iter = decoder_validity.fast_iter_u56(); + for v in validity_iter.by_ref() { + iter(v, 56)?; } - unsafe { - target_ptr = target_ptr.add(decoder_limit); - } + let (v, vl) = validity_iter.remainder(); + iter(v, vl)?; }, } @@ -231,7 +211,10 @@ pub fn decode( unsafe { target.set_len(end_length); } - target.resize(target.len() + trailing_nulls, B::zeroed()); + target.resize( + target.len() + trailing_nulls - num_rows_to_skip, + B::zeroed(), + ); Ok(()) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs index a55aa262b482..597ad8547524 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs @@ -15,6 +15,8 @@ pub fn decode( mut validity: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { + debug_assert_eq!(filter.len(), validity.len()); + let leading_filtered = filter.take_leading_zeros(); filter.take_trailing_zeros(); @@ -89,6 +91,8 @@ pub fn decode( } while let Some(chunk) = values.next_chunk()? { + debug_assert!(num_values_to_skip < chunk.len()); + match chunk { HybridRleChunk::Rle(value, size) => { if size == 0 { @@ -133,12 +137,6 @@ pub fn decode( } }, HybridRleChunk::Bitpacked(mut decoder) => { - if num_values_to_skip > 0 { - validity = validity.sliced(num_rows_to_skip, validity.len() - num_rows_to_skip); - decoder.skip_chunks(num_values_to_skip / 32); - num_values_to_skip %= 32; - } - // For bitpacked we do the following: // 1. See how many rows are encoded by this `decoder`. // 2. Go through the filter and validity 56 bits at a time and: @@ -156,7 +154,7 @@ pub fn decode( let mut buffer_part_idx = 0; let mut values_offset = 0; let mut num_buffered: usize = 0; - let mut skip_values = 0; + let mut skip_values = num_values_to_skip; let current_filter; let current_validity; @@ -258,13 +256,14 @@ pub fn decode( } num_rows_to_skip = 0; + num_values_to_skip = 0; } if cfg!(debug_assertions) { assert_eq!(validity.set_bits(), 0); } - let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, validity.len()) }; + let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, validity.len() - num_rows_to_skip) }; target_slice.fill(B::zeroed()); unsafe { target.set_len(start_length + num_rows); diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs index a7c5a516c438..81f91350f53e 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -77,7 +77,7 @@ pub fn decode( unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; } unsafe { - target_ptr = target_ptr.add(chunk_size); + target_ptr = target_ptr.add(chunk_size - num_rows_to_skip); } } } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs index 9165928c3963..5b224176dc0c 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs @@ -12,8 +12,10 @@ pub fn decode( dict: &[B], mut filter: Bitmap, target: &mut Vec, - mut num_rows_to_skip: usize, + num_values_to_skip: usize, ) -> ParquetResult<()> { + debug_assert!(values.len() >= filter.len()); + let leading_zeros = filter.take_leading_zeros(); filter.take_trailing_zeros(); @@ -40,7 +42,7 @@ pub fn decode( let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - num_rows_to_skip += leading_zeros; + let mut num_rows_to_skip = leading_zeros + num_values_to_skip; // Skip over any whole HybridRleChunks if num_rows_to_skip > 0 { From 5f6876f76635b2b8437af59dbfb521775ee748d6 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Wed, 6 Nov 2024 09:18:41 +0100 Subject: [PATCH 09/11] working optional and required kernels --- crates/polars-arrow/src/bitmap/bitmask.rs | 18 +- .../deserialize/dictionary_encoded/mod.rs | 5 + .../dictionary_encoded/optional.rs | 307 +++++++++--------- .../dictionary_encoded/required.rs | 69 ++-- 4 files changed, 199 insertions(+), 200 deletions(-) diff --git a/crates/polars-arrow/src/bitmap/bitmask.rs b/crates/polars-arrow/src/bitmap/bitmask.rs index 2e4e45195266..637ae06a8e6b 100644 --- a/crates/polars-arrow/src/bitmap/bitmask.rs +++ b/crates/polars-arrow/src/bitmap/bitmask.rs @@ -4,7 +4,7 @@ use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount}; use polars_utils::slice::load_padded_le_u64; use super::iterator::FastU56BitmapIter; -use super::utils::{count_zeros, BitmapIter}; +use super::utils::{count_zeros, fmt, BitmapIter}; use crate::bitmap::Bitmap; /// Returns the nth set bit in w, if n+1 bits are set. The indexing is @@ -79,6 +79,15 @@ pub struct BitMask<'a> { len: usize, } +impl std::fmt::Debug for BitMask<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { bytes, offset, len } = self; + let offset_num_bytes = offset / 8; + let offset_in_byte = offset % 8; + fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f) + } +} + impl<'a> BitMask<'a> { pub fn from_bitmap(bitmap: &'a Bitmap) -> Self { let (bytes, offset, len) = bitmap.as_slice(); @@ -92,6 +101,13 @@ impl<'a> BitMask<'a> { self.len } + #[inline] + pub fn advance_by(&mut self, idx: usize) { + assert!(idx <= self.len); + self.offset += idx; + self.len -= idx; + } + #[inline] pub fn split_at(&self, idx: usize) -> (Self, Self) { assert!(idx <= self.len); diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs index 8ae6680e088b..11a334913120 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs @@ -134,6 +134,11 @@ fn oob_dict_idx() -> ParquetError { ParquetError::oos("Dictionary Index is out-of-bounds") } +#[cold] +fn no_more_bitpacked_values() -> ParquetError { + ParquetError::oos("Bitpacked Hybrid-RLE ran out before all values were served") +} + #[inline(always)] fn verify_dict_indices(indices: &[u32; 32], dict_size: usize) -> ParquetResult<()> { let mut is_valid = true; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs index a5459d9a62b3..a95a83fc0db0 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -5,6 +5,7 @@ use arrow::types::AlignedBytes; use super::{oob_dict_idx, verify_dict_indices}; use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; use crate::parquet::error::ParquetResult; +use crate::read::deserialize::dictionary_encoded::no_more_bitpacked_values; #[inline(never)] pub fn decode( @@ -16,7 +17,10 @@ pub fn decode( ) -> ParquetResult<()> { debug_assert!(num_rows_to_skip <= validity.len()); - target.reserve(validity.len() - num_rows_to_skip); + let num_rows = validity.len() - num_rows_to_skip; + let end_length = target.len() + num_rows; + + target.reserve(num_rows); // Remove any leading and trailing nulls. This has two benefits: // 1. It increases the chance of dispatching to the faster kernel (e.g. for sorted data) @@ -24,6 +28,16 @@ pub fn decode( let leading_nulls = validity.take_leading_zeros(); let trailing_nulls = validity.take_trailing_zeros(); + // Special case: we don't even have to decode any values, since all non-skipped values are + // trailing nulls. + if num_rows_to_skip >= leading_nulls + validity.len() { + target.resize(end_length, B::zeroed()); + return Ok(()); + } + + values.limit_to(validity.set_bits()); + + // Add the leading nulls let skipped_leading_nulls = leading_nulls.min(num_rows_to_skip); target.resize( target.len() + leading_nulls - skipped_leading_nulls, @@ -31,190 +45,173 @@ pub fn decode( ); num_rows_to_skip -= skipped_leading_nulls; - // Dispatch to the required kernel if all rows are valid anyway. if validity.set_bits() == validity.len() { - values.limit_to(validity.len()); - let skipped_values = values.len().min(num_rows_to_skip); - super::required::decode(values, dict, target, skipped_values)?; - num_rows_to_skip -= skipped_values; - target.resize( - target.len() + trailing_nulls - num_rows_to_skip, - B::zeroed(), - ); - return Ok(()); - } - if dict.is_empty() && validity.set_bits() > 0 { - return Err(oob_dict_idx()); - } + // Dispatch to the required kernel if all rows are valid anyway. + super::required::decode(values, dict, target, num_rows_to_skip)?; + } else { + if dict.is_empty() && validity.set_bits() > 0 { + return Err(oob_dict_idx()); + } - let mut num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits(); + let mut num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits(); - assert!(validity.set_bits() <= values.len()); + let mut validity = BitMask::from_bitmap(&validity); + let mut values_buffer = [0u32; 128]; + let values_buffer = &mut values_buffer; - let start_length = target.len(); - let end_length = start_length + validity.len() - num_rows_to_skip; - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + // Skip over any whole HybridRleChunks + if num_values_to_skip > 0 { + let mut total_num_skipped_values = 0; - values.limit_to(validity.set_bits()); - let mut validity = BitMask::from_bitmap(&validity); - let mut values_buffer = [0u32; 128]; - let values_buffer = &mut values_buffer; - - // Skip over any whole HybridRleChunks - if num_values_to_skip > 0 { - let mut total_num_skipped_values = 0; - - loop { - let mut values_clone = values.clone(); - let Some(chunk_len) = values_clone.next_chunk_length()? else { - break; - }; - - if num_values_to_skip < chunk_len { - break; - } + loop { + let mut values_clone = values.clone(); + let Some(chunk_len) = values_clone.next_chunk_length()? else { + break; + }; - values = values_clone; - num_values_to_skip -= chunk_len; - total_num_skipped_values += chunk_len; - } + if num_values_to_skip < chunk_len { + break; + } - if total_num_skipped_values > 0 { - let offset = validity - .nth_set_bit_idx(total_num_skipped_values - 1, 0) - .unwrap_or(validity.len()); - num_rows_to_skip -= offset; - validity = validity.sliced(offset, validity.len() - offset); - } - } + values = values_clone; + num_values_to_skip -= chunk_len; + total_num_skipped_values += chunk_len; + } - while let Some(chunk) = values.next_chunk()? { - debug_assert!(num_values_to_skip < chunk.len()); + if total_num_skipped_values > 0 { + let offset = validity + .nth_set_bit_idx(total_num_skipped_values - 1, 0) + .unwrap_or(validity.len()); + num_rows_to_skip -= offset; + validity = validity.sliced(offset, validity.len() - offset); + } + } - match chunk { - HybridRleChunk::Rle(value, size) => { - // If we know that we have `size` times `value` that we can append, but there might - // be nulls in between those values. - // - // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is - // done with `num_bits_before_nth_one` on the validity mask. - // 2. Fill `num_rows` values into the target buffer. - // 3. Advance the validity mask by `num_rows` values. + while let Some(chunk) = values.next_chunk()? { + debug_assert!(num_values_to_skip < chunk.len()); - let num_chunk_rows = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); - let num_chunk_rows = num_chunk_rows - num_rows_to_skip; + match chunk { + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } - (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_rows) }; + // If we know that we have `size` times `value` that we can append, but there might + // be nulls in between those values. + // + // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is + // done with `num_bits_before_nth_one` on the validity mask. + // 2. Fill `num_rows` values into the target buffer. + // 3. Advance the validity mask by `num_rows` values. + // + let Some(&value) = dict.get(value as usize) else { + return Err(oob_dict_idx()); + }; + + let num_chunk_rows = + validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + validity.advance_by(num_chunk_rows); + + target.resize(target.len() + num_chunk_rows - num_rows_to_skip, value); + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let num_rows_for_decoder = validity + .nth_set_bit_idx(decoder.len(), 0) + .unwrap_or(validity.len()); + + let mut chunked = decoder.chunked(); + + let mut buffer_part_idx = 0; + let mut values_offset = 0; + let mut num_buffered: usize = 0; + + let mut decoder_validity; + (decoder_validity, validity) = validity.split_at(num_rows_for_decoder); + + if num_values_to_skip > 0 { + decoder_validity.advance_by(num_rows_to_skip); + + chunked.decoder.skip_chunks(num_values_to_skip / 32); + num_values_to_skip %= 32; + + if num_values_to_skip > 0 { + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let Some(num_added) = chunked.next_into(buffer_part) else { + return Err(no_more_bitpacked_values()); + }; + + debug_assert!(num_values_to_skip <= num_added); + + verify_dict_indices(buffer_part, dict.len())?; + + values_offset += num_values_to_skip; + num_buffered += num_added - num_values_to_skip; + buffer_part_idx += 1; + } + } - let Some(&value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; + let mut iter = |v: u64, n: usize| { + while num_buffered < v.count_ones() as usize { + buffer_part_idx %= 4; - let target_slice; - // SAFETY: - // Given `validity_iter` before the `advance_by_bits` - // - // 1. `target_ptr..target_ptr + validity_iter.bits_left()` is allocated - // 2. `num_chunk_rows <= validity_iter.bits_left()` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); - target_ptr = target_ptr.add(num_chunk_rows); - } + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let Some(num_added) = chunked.next_into(buffer_part) else { + return Err(no_more_bitpacked_values()); + }; - target_slice.fill(value); - }, - HybridRleChunk::Bitpacked(mut decoder) => { - if num_values_to_skip >= 32 { - decoder.skip_chunks(num_values_to_skip / 32); - let num_rows_skipped = validity - .nth_set_bit_idx(num_values_to_skip - num_values_to_skip % 32 - 1, 0) - .map_or(validity.len(), |v| v + 1); - validity = validity.sliced(num_rows_skipped, validity.len() - num_rows_skipped); - num_values_to_skip %= 32; - } + verify_dict_indices(buffer_part, dict.len())?; - let num_rows_for_decoder = validity - .nth_set_bit_idx(decoder.len(), 0) - .unwrap_or(validity.len()); - let decoder_validity; - (decoder_validity, validity) = validity.split_at(num_rows_for_decoder); + num_buffered += num_added; - let mut chunked = decoder.chunked(); + buffer_part_idx += 1; + } - let mut buffer_part_idx = 0; - let mut values_offset = 0; - let mut num_buffered: usize = 0; + let mut num_read = 0; - let mut iter = |v: u64, n: usize| { - while num_buffered < v.count_ones() as usize + num_values_to_skip { - buffer_part_idx %= 4; - let buffer_part = <&mut [u32; 32]>::try_from( - &mut values_buffer[buffer_part_idx * 32..][..32], - ) - .unwrap(); - let Some(num_added) = chunked.next_into(buffer_part) else { - return Ok(()); - }; + target.extend((0..n).map(|i| { + let idx = values_buffer[(values_offset + num_read) % 128]; + num_read += ((v >> i) & 1) as usize; - verify_dict_indices(buffer_part, dict.len())?; + // SAFETY: + // 1. `values_buffer` starts out as only zeros, which we know is in the + // dictionary following the original `dict.is_empty` check. + // 2. Each time we write to `values_buffer`, it is followed by a + // `verify_dict_indices`. + *unsafe { dict.get_unchecked(idx as usize) } + })); - let num_added_skipped = num_added.min(num_values_to_skip); - num_values_to_skip -= num_added_skipped; - values_offset += num_added_skipped; + values_offset += num_read; values_offset %= 128; - num_buffered += num_added - num_added_skipped; - - buffer_part_idx += 1; - } - - let mut num_read = 0; - - for i in 0..n { - let idx = values_buffer[(values_offset + num_read) % 128]; + num_buffered -= num_read; - // SAFETY: - // 1. `values_buffer` starts out as only zeros, which we know is in the - // dictionary following the original `dict.is_empty` check. - // 2. Each time we write to `values_buffer`, it is followed by a - // `verify_dict_indices`. - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { target_ptr.add(i).write(value) }; - num_read += ((v >> i) & 1) as usize; - } + ParquetResult::Ok(()) + }; - values_offset += num_read; - values_offset %= 128; - num_buffered -= num_read; - unsafe { - target_ptr = target_ptr.add(n); + let mut v_iter = decoder_validity.fast_iter_u56(); + for v in v_iter.by_ref() { + iter(v, 56)?; } - ParquetResult::Ok(()) - }; - - let mut validity_iter = decoder_validity.fast_iter_u56(); - for v in validity_iter.by_ref() { - iter(v, 56)?; - } + let (v, vl) = v_iter.remainder(); + iter(v, vl)?; + }, + } - let (v, vl) = validity_iter.remainder(); - iter(v, vl)?; - }, + num_rows_to_skip = 0; + num_values_to_skip = 0; } - - num_rows_to_skip = 0; - num_values_to_skip = 0; } - unsafe { - target.set_len(end_length); - } - target.resize( - target.len() + trailing_nulls - num_rows_to_skip, - B::zeroed(), - ); + // Add back the trailing nulls + debug_assert_eq!(target.len(), end_length - trailing_nulls); + target.resize(end_length, B::zeroed()); Ok(()) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs index 81f91350f53e..2e56ff7cd386 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -13,19 +13,18 @@ pub fn decode( ) -> ParquetResult<()> { debug_assert!(num_rows_to_skip <= values.len()); - if num_rows_to_skip == values.len() { + let num_rows = values.len() - num_rows_to_skip; + let end_length = target.len() + num_rows; + + if num_rows == 0 { return Ok(()); } - if dict.is_empty() && values.len() > 0 { + target.reserve(num_rows); + + if dict.is_empty() { return Err(oob_dict_idx()); } - let start_length = target.len(); - let end_length = start_length + values.len() - num_rows_to_skip; - - target.reserve(values.len() - num_rows_to_skip); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - // Skip over any whole HybridRleChunks if possible if num_rows_to_skip > 0 { loop { @@ -43,66 +42,50 @@ pub fn decode( } } - while let Some(chunk) = values.next_chunk()? { debug_assert!(num_rows_to_skip < chunk.len()); match chunk { - HybridRleChunk::Rle(value, length) => { - let target_slice; - // SAFETY: - // 1. `target_ptr..target_ptr + values.len()` is allocated - // 2. `length <= limit` - unsafe { - target_slice = - std::slice::from_raw_parts_mut(target_ptr, length - num_rows_to_skip); - target_ptr = target_ptr.add(length); + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; } - let Some(&value) = dict.get(value as usize) else { return Err(oob_dict_idx()); }; - - target_slice.fill(value); + target.resize(target.len() + size - num_rows_to_skip, value); }, HybridRleChunk::Bitpacked(mut decoder) => { if num_rows_to_skip > 0 { decoder.skip_chunks(num_rows_to_skip / 32); num_rows_to_skip %= 32; + if let Some((chunk, chunk_size)) = decoder.chunked().next_inexact() { let chunk = &chunk[num_rows_to_skip..chunk_size]; verify_dict_indices_slice(chunk, dict.len())?; - for (i, &idx) in chunk.iter().enumerate() { - unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; - } - unsafe { - target_ptr = target_ptr.add(chunk_size - num_rows_to_skip); - } + target.extend(chunk.iter().map(|&idx| { + // SAFETY: The dict indices were verified before. + *unsafe { dict.get_unchecked(idx as usize) } + })); } } let mut chunked = decoder.chunked(); for chunk in chunked.by_ref() { verify_dict_indices(&chunk, dict.len())?; - - for (i, &idx) in chunk.iter().enumerate() { - unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; - } - unsafe { - target_ptr = target_ptr.add(32); - } + target.extend(chunk.iter().map(|&idx| { + // SAFETY: The dict indices were verified before. + *unsafe { dict.get_unchecked(idx as usize) } + })); } if let Some((chunk, chunk_size)) = chunked.remainder() { verify_dict_indices_slice(&chunk[..chunk_size], dict.len())?; - - for (i, &idx) in chunk[..chunk_size].iter().enumerate() { - unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; - } - unsafe { - target_ptr = target_ptr.add(chunk_size); - } + target.extend(chunk[..chunk_size].iter().map(|&idx| { + // SAFETY: The dict indices were verified before. + *unsafe { dict.get_unchecked(idx as usize) } + })); } }, } @@ -110,9 +93,7 @@ pub fn decode( num_rows_to_skip = 0; } - unsafe { - target.set_len(end_length); - } + debug_assert_eq!(target.len(), end_length); Ok(()) } From deb4bf5e229008625e6b14bf5c5efa390a01bd67 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Wed, 6 Nov 2024 10:33:12 +0100 Subject: [PATCH 10/11] fully working and tested optional and required --- .../deserialize/dictionary_encoded/mod.rs | 67 ++++++++++++++++- .../dictionary_encoded/optional.rs | 75 +++++++------------ .../dictionary_encoded/required.rs | 27 +++---- py-polars/tests/unit/io/test_parquet.py | 57 ++++++++++++++ 4 files changed, 161 insertions(+), 65 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs index 11a334913120..d7bc8a7d0abf 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs @@ -1,3 +1,4 @@ +use arrow::bitmap::bitmask::BitMask; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::types::{AlignedBytes, NativeType}; use polars_compute::filter::filter_boolean_kernel; @@ -128,7 +129,6 @@ pub(crate) fn constrain_page_validity( }) } - #[cold] fn oob_dict_idx() -> ParquetError { ParquetError::oos("Dictionary Index is out-of-bounds") @@ -166,3 +166,68 @@ fn verify_dict_indices_slice(indices: &[u32], dict_size: usize) -> ParquetResult Err(oob_dict_idx()) } + +/// Skip over entire chunks in a [`HybridRleDecoder`] as long as all skipped chunks do not include +/// more than `num_values_to_skip` values. +#[inline(always)] +fn required_skip_whole_chunks( + values: &mut HybridRleDecoder<'_>, + num_values_to_skip: &mut usize, +) -> ParquetResult<()> { + if *num_values_to_skip == 0 { + return Ok(()); + } + + loop { + let mut values_clone = values.clone(); + let Some(chunk_len) = values_clone.next_chunk_length()? else { + break; + }; + if *num_values_to_skip < chunk_len { + break; + } + *values = values_clone; + *num_values_to_skip -= chunk_len; + } + + Ok(()) +} + +/// Skip over entire chunks in a [`HybridRleDecoder`] as long as all skipped chunks do not include +/// more than `num_values_to_skip` values. +#[inline(always)] +fn optional_skip_whole_chunks( + values: &mut HybridRleDecoder<'_>, + validity: &mut BitMask<'_>, + num_rows_to_skip: &mut usize, + num_values_to_skip: &mut usize, +) -> ParquetResult<()> { + if *num_values_to_skip == 0 { + return Ok(()); + } + + let mut total_num_skipped_values = 0; + + loop { + let mut values_clone = values.clone(); + let Some(chunk_len) = values_clone.next_chunk_length()? else { + break; + }; + if *num_values_to_skip < chunk_len { + break; + } + *values = values_clone; + *num_values_to_skip -= chunk_len; + total_num_skipped_values += chunk_len; + } + + if total_num_skipped_values > 0 { + let offset = validity + .nth_set_bit_idx(total_num_skipped_values - 1, 0) + .map_or(validity.len(), |v| v + 1); + *num_rows_to_skip -= offset; + validity.advance_by(offset); + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs index a95a83fc0db0..5b8ad8a5d2e0 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -2,11 +2,13 @@ use arrow::bitmap::bitmask::BitMask; use arrow::bitmap::Bitmap; use arrow::types::AlignedBytes; -use super::{oob_dict_idx, verify_dict_indices}; +use super::{ + no_more_bitpacked_values, oob_dict_idx, optional_skip_whole_chunks, verify_dict_indices, +}; use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; use crate::parquet::error::ParquetResult; -use crate::read::deserialize::dictionary_encoded::no_more_bitpacked_values; +/// Decoding kernel for optional dictionary encoded. #[inline(never)] pub fn decode( mut values: HybridRleDecoder<'_>, @@ -28,8 +30,7 @@ pub fn decode( let leading_nulls = validity.take_leading_zeros(); let trailing_nulls = validity.take_trailing_zeros(); - // Special case: we don't even have to decode any values, since all non-skipped values are - // trailing nulls. + // Special case: all values are skipped, just add the trailing null. if num_rows_to_skip >= leading_nulls + validity.len() { target.resize(end_length, B::zeroed()); return Ok(()); @@ -38,54 +39,37 @@ pub fn decode( values.limit_to(validity.set_bits()); // Add the leading nulls - let skipped_leading_nulls = leading_nulls.min(num_rows_to_skip); - target.resize( - target.len() + leading_nulls - skipped_leading_nulls, - B::zeroed(), - ); - num_rows_to_skip -= skipped_leading_nulls; + if num_rows_to_skip < leading_nulls { + target.resize(target.len() + leading_nulls - num_rows_to_skip, B::zeroed()); + num_rows_to_skip = 0; + } else { + num_rows_to_skip -= leading_nulls; + } if validity.set_bits() == validity.len() { // Dispatch to the required kernel if all rows are valid anyway. super::required::decode(values, dict, target, num_rows_to_skip)?; } else { - if dict.is_empty() && validity.set_bits() > 0 { + if dict.is_empty() { return Err(oob_dict_idx()); } - let mut num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits(); + let mut num_values_to_skip = 0; + if num_rows_to_skip > 0 { + num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits(); + } let mut validity = BitMask::from_bitmap(&validity); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; // Skip over any whole HybridRleChunks - if num_values_to_skip > 0 { - let mut total_num_skipped_values = 0; - - loop { - let mut values_clone = values.clone(); - let Some(chunk_len) = values_clone.next_chunk_length()? else { - break; - }; - - if num_values_to_skip < chunk_len { - break; - } - - values = values_clone; - num_values_to_skip -= chunk_len; - total_num_skipped_values += chunk_len; - } - - if total_num_skipped_values > 0 { - let offset = validity - .nth_set_bit_idx(total_num_skipped_values - 1, 0) - .unwrap_or(validity.len()); - num_rows_to_skip -= offset; - validity = validity.sliced(offset, validity.len() - offset); - } - } + optional_skip_whole_chunks( + &mut values, + &mut validity, + &mut num_rows_to_skip, + &mut num_values_to_skip, + )?; while let Some(chunk) = values.next_chunk()? { debug_assert!(num_values_to_skip < chunk.len()); @@ -96,14 +80,14 @@ pub fn decode( continue; } - // If we know that we have `size` times `value` that we can append, but there might - // be nulls in between those values. + // If we know that we have `size` times `value` that we can append, but there + // might be nulls in between those values. // - // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is - // done with `num_bits_before_nth_one` on the validity mask. + // 1. See how many `num_rows = valid + invalid` values `size` would entail. + // This is done with `nth_set_bit_idx` on the validity mask. // 2. Fill `num_rows` values into the target buffer. // 3. Advance the validity mask by `num_rows` values. - // + let Some(&value) = dict.get(value as usize) else { return Err(oob_dict_idx()); }; @@ -128,7 +112,8 @@ pub fn decode( let mut decoder_validity; (decoder_validity, validity) = validity.split_at(num_rows_for_decoder); - if num_values_to_skip > 0 { + // Skip over any remaining values. + if num_rows_to_skip > 0 { decoder_validity.advance_by(num_rows_to_skip); chunked.decoder.skip_chunks(num_values_to_skip / 32); @@ -144,7 +129,6 @@ pub fn decode( }; debug_assert!(num_values_to_skip <= num_added); - verify_dict_indices(buffer_part, dict.len())?; values_offset += num_values_to_skip; @@ -186,7 +170,6 @@ pub fn decode( *unsafe { dict.get_unchecked(idx as usize) } })); - values_offset += num_read; values_offset %= 128; num_buffered -= num_read; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs index 2e56ff7cd386..443ab3633679 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -1,9 +1,12 @@ use arrow::types::AlignedBytes; -use super::{oob_dict_idx, verify_dict_indices, verify_dict_indices_slice}; +use super::{ + oob_dict_idx, required_skip_whole_chunks, verify_dict_indices, verify_dict_indices_slice, +}; use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; use crate::parquet::error::ParquetResult; +/// Decoding kernel for required dictionary encoded. #[inline(never)] pub fn decode( mut values: HybridRleDecoder<'_>, @@ -19,28 +22,15 @@ pub fn decode( if num_rows == 0 { return Ok(()); } + target.reserve(num_rows); if dict.is_empty() { return Err(oob_dict_idx()); } - // Skip over any whole HybridRleChunks if possible - if num_rows_to_skip > 0 { - loop { - let mut values_clone = values.clone(); - let Some(chunk_len) = values_clone.next_chunk_length()? else { - break; - }; - - if num_rows_to_skip < chunk_len { - break; - } - - values = values_clone; - num_rows_to_skip -= chunk_len; - } - } + // Skip over whole HybridRleChunks + required_skip_whole_chunks(&mut values, &mut num_rows_to_skip)?; while let Some(chunk) = values.next_chunk()? { debug_assert!(num_rows_to_skip < chunk.len()); @@ -50,9 +40,11 @@ pub fn decode( if size == 0 { continue; } + let Some(&value) = dict.get(value as usize) else { return Err(oob_dict_idx()); }; + target.resize(target.len() + size - num_rows_to_skip, value); }, HybridRleChunk::Bitpacked(mut decoder) => { @@ -63,7 +55,6 @@ pub fn decode( if let Some((chunk, chunk_size)) = decoder.chunked().next_inexact() { let chunk = &chunk[num_rows_to_skip..chunk_size]; verify_dict_indices_slice(chunk, dict.len())?; - target.extend(chunk.iter().map(|&idx| { // SAFETY: The dict indices were verified before. *unsafe { dict.get_unchecked(idx as usize) } diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 773862591f0e..f79723b0e58b 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -4,6 +4,7 @@ import io from datetime import datetime, time, timezone from decimal import Decimal +from itertools import chain from typing import IO, TYPE_CHECKING, Any, Callable, Literal, cast import fsspec @@ -2319,3 +2320,59 @@ def test_nested_dicts(content: list[float | None]) -> None: df.write_parquet(f, use_pyarrow=True) f.seek(0) assert_frame_equal(df, pl.read_parquet(f)) + + +@pytest.mark.parametrize( + "leading_nulls", + [ + [], + [None] * 7, + ], +) +@pytest.mark.parametrize( + "trailing_nulls", + [ + [], + [None] * 7, + ], +) +@pytest.mark.parametrize( + "first_chunk", + # Create both RLE and Bitpacked chunks + [ + [1] * 57, + [1 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +@pytest.mark.parametrize( + "second_chunk", + # Create both RLE and Bitpacked chunks + [ + [2] * 57, + [2 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +def test_dict_slices( + leading_nulls: list[None], + trailing_nulls: list[None], + first_chunk: list[None | int], + second_chunk: list[None | int], +) -> None: + df = pl.Series( + "a", leading_nulls + first_chunk + second_chunk + trailing_nulls, pl.Int64 + ).to_frame() + + f = io.BytesIO() + df.write_parquet(f) + + for offset in chain([0, 1, 2], range(3, df.height, 3)): + for length in chain([df.height, 1, 2], range(3, df.height - offset, 3)): + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).slice(offset, length).collect(), + df.slice(offset, length), + ) From d3094f5cbb11ec0b947b5d54f20e90053f7403dd Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Wed, 6 Nov 2024 11:53:29 +0100 Subject: [PATCH 11/11] final version --- .../deserialize/dictionary_encoded/mod.rs | 2 +- .../dictionary_encoded/optional.rs | 2 +- .../optional_masked_dense.rs | 122 +++++------------- .../dictionary_encoded/required.rs | 2 +- .../required_masked_dense.rs | 78 ++++------- .../read/deserialize/fixed_size_binary.rs | 4 +- .../arrow/read/deserialize/primitive/plain.rs | 2 +- py-polars/tests/unit/io/test_parquet.py | 65 ++++++++++ 8 files changed, 128 insertions(+), 149 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs index d7bc8a7d0abf..91ad273ee04b 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs @@ -64,7 +64,7 @@ pub fn decode_dict_dispatch( optional::decode(values, dict, page_validity, target, rng.start) }, (Some(Filter::Mask(filter)), None) => { - required_masked_dense::decode(values, dict, filter, target, 0) + required_masked_dense::decode(values, dict, filter, target) }, (Some(Filter::Mask(filter)), Some(page_validity)) => { optional_masked_dense::decode(values, dict, filter, page_validity, target) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs index 5b8ad8a5d2e0..f0546027549e 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -72,7 +72,7 @@ pub fn decode( )?; while let Some(chunk) = values.next_chunk()? { - debug_assert!(num_values_to_skip < chunk.len()); + debug_assert!(num_values_to_skip < chunk.len() || chunk.len() == 0); match chunk { HybridRleChunk::Rle(value, size) => { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs index 597ad8547524..c779c3b61b0e 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs @@ -1,4 +1,3 @@ -use arrow::array::Splitable; use arrow::bitmap::bitmask::BitMask; use arrow::bitmap::Bitmap; use arrow::types::AlignedBytes; @@ -15,85 +14,52 @@ pub fn decode( mut validity: Bitmap, target: &mut Vec, ) -> ParquetResult<()> { - debug_assert_eq!(filter.len(), validity.len()); - - let leading_filtered = filter.take_leading_zeros(); + // @NOTE: We don't skip leading filtered values, because it is a bit more involved than the + // other kernels. We could probably do it anyway after having tried to dispatch to faster + // kernels, but we lose quite a bit of the potency with that. filter.take_trailing_zeros(); + validity = validity.sliced(0, filter.len()); let num_rows = filter.set_bits(); + let num_valid_values = validity.set_bits(); - let leading_validity; - (leading_validity, validity) = validity.split_at(leading_filtered); - - let mut num_rows_to_skip = leading_filtered; - let mut num_values_to_skip = leading_validity.set_bits(); - - validity = validity.sliced(0, filter.len()); + assert_eq!(filter.len(), validity.len()); + assert!(num_valid_values <= values.len()); // Dispatch to the non-filter kernel if all rows are needed anyway. if num_rows == filter.len() { - return super::optional::decode(values, dict, validity, target, num_rows_to_skip); + return super::optional::decode(values, dict, validity, target, 0); } + // Dispatch to the required kernel if all rows are valid anyway. - if validity.set_bits() == validity.len() { - return super::required_masked_dense::decode( - values, - dict, - filter, - target, - num_values_to_skip, - ); + if num_valid_values == validity.len() { + return super::required_masked_dense::decode(values, dict, filter, target); } - if dict.is_empty() && validity.set_bits() > 0 { + + if dict.is_empty() && num_valid_values > 0 { return Err(oob_dict_idx()); } - debug_assert_eq!(filter.len(), validity.len()); - assert!(validity.set_bits() <= values.len()); - let start_length = target.len(); - target.reserve(num_rows); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + + let end_length = target.len() + num_rows; let mut filter = BitMask::from_bitmap(&filter); let mut validity = BitMask::from_bitmap(&validity); - values.limit_to(num_values_to_skip + validity.set_bits()); + values.limit_to(num_valid_values); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - // Skip over any whole HybridRleChunks - if num_values_to_skip > 0 { - let mut total_num_skipped_values = 0; + let mut num_rows_left = num_rows; - loop { - let mut values_clone = values.clone(); - let Some(chunk_len) = values_clone.next_chunk_length()? else { - break; - }; - - if num_values_to_skip < chunk_len { - break; - } - - values = values_clone; - num_values_to_skip -= chunk_len; - total_num_skipped_values += chunk_len; + for chunk in values.into_chunk_iter() { + // Early stop if we have no more rows to load. + if num_rows_left == 0 { + break; } - if total_num_skipped_values > 0 { - let offset = validity - .nth_set_bit_idx(total_num_skipped_values - 1, 0) - .unwrap_or(validity.len()); - num_rows_to_skip -= offset; - validity = validity.sliced(offset, validity.len() - offset); - } - } - - while let Some(chunk) = values.next_chunk()? { - debug_assert!(num_values_to_skip < chunk.len()); - - match chunk { + match chunk? { HybridRleChunk::Rle(value, size) => { if size == 0 { continue; @@ -107,34 +73,19 @@ pub fn decode( // 2. Fill `num_rows` values into the target buffer. // 3. Advance the validity mask by `num_rows` values. - let num_chunk_values = validity - .nth_set_bit_idx(size, num_rows_to_skip) - .unwrap_or(validity.len()); + let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); let current_filter; - (_, validity) = unsafe { validity.split_at_unchecked(num_chunk_values) }; - (current_filter, filter) = unsafe { filter.split_at_unchecked(num_chunk_values) }; + (current_filter, filter) = filter.split_at(num_chunk_values); + validity.advance_by(num_chunk_values); let num_chunk_rows = current_filter.set_bits(); - if num_chunk_rows > 0 { - let target_slice; - // SAFETY: - // Given `filter_iter` before the `advance_by_bits`. - // - // 1. `target_ptr..target_ptr + filter_iter.count_ones()` is allocated - // 2. `num_chunk_rows < filter_iter.count_ones()` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); - target_ptr = target_ptr.add(num_chunk_rows); - } - - let Some(value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; + let Some(&value) = dict.get(value as usize) else { + return Err(oob_dict_idx()); + }; - target_slice.fill(*value); - } + target.resize(target.len() + num_chunk_rows, value); }, HybridRleChunk::Bitpacked(mut decoder) => { // For bitpacked we do the following: @@ -154,7 +105,7 @@ pub fn decode( let mut buffer_part_idx = 0; let mut values_offset = 0; let mut num_buffered: usize = 0; - let mut skip_values = num_values_to_skip; + let mut skip_values = 0; let current_filter; let current_validity; @@ -203,6 +154,7 @@ pub fn decode( let mut num_read = 0; let mut num_written = 0; + let target_ptr = unsafe { target.as_mut_ptr().add(target.len()) }; while f != 0 { let offset = f.trailing_zeros(); @@ -233,8 +185,9 @@ pub fn decode( values_offset %= 128; num_buffered -= num_read; unsafe { - target_ptr = target_ptr.add(num_written); + target.set_len(target.len() + num_written); } + num_rows_left -= num_written; ParquetResult::Ok(()) }; @@ -254,20 +207,13 @@ pub fn decode( iter(f, v)?; }, } - - num_rows_to_skip = 0; - num_values_to_skip = 0; } if cfg!(debug_assertions) { assert_eq!(validity.set_bits(), 0); } - let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, validity.len() - num_rows_to_skip) }; - target_slice.fill(B::zeroed()); - unsafe { - target.set_len(start_length + num_rows); - } + target.resize(end_length, B::zeroed()); Ok(()) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs index 443ab3633679..a7053944d130 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -33,7 +33,7 @@ pub fn decode( required_skip_whole_chunks(&mut values, &mut num_rows_to_skip)?; while let Some(chunk) = values.next_chunk()? { - debug_assert!(num_rows_to_skip < chunk.len()); + debug_assert!(num_rows_to_skip < chunk.len() || chunk.len() == 0); match chunk { HybridRleChunk::Rle(value, size) => { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs index 5b224176dc0c..04aed3dbfa1e 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs @@ -2,7 +2,7 @@ use arrow::bitmap::bitmask::BitMask; use arrow::bitmap::Bitmap; use arrow::types::AlignedBytes; -use super::{oob_dict_idx, verify_dict_indices}; +use super::{oob_dict_idx, required_skip_whole_chunks, verify_dict_indices}; use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; use crate::parquet::error::ParquetResult; @@ -12,60 +12,44 @@ pub fn decode( dict: &[B], mut filter: Bitmap, target: &mut Vec, - num_values_to_skip: usize, ) -> ParquetResult<()> { - debug_assert!(values.len() >= filter.len()); + assert!(values.len() >= filter.len()); - let leading_zeros = filter.take_leading_zeros(); + let mut num_rows_to_skip = filter.take_leading_zeros(); filter.take_trailing_zeros(); - values.limit_to(leading_zeros + filter.len()); - let num_rows = filter.set_bits(); + values.limit_to(num_rows_to_skip + filter.len()); + // Dispatch to the non-filter kernel if all rows are needed anyway. if num_rows == filter.len() { - return super::required::decode(values, dict, target, leading_zeros); + return super::required::decode(values, dict, target, num_rows_to_skip); } if dict.is_empty() && !filter.is_empty() { return Err(oob_dict_idx()); } - let start_length = target.len(); - target.reserve(num_rows); - let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; let mut filter = BitMask::from_bitmap(&filter); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - let mut num_rows_to_skip = leading_zeros + num_values_to_skip; - - // Skip over any whole HybridRleChunks - if num_rows_to_skip > 0 { - loop { - let mut values_clone = values.clone(); - let Some(chunk_len) = values_clone.next_chunk_length()? else { - break; - }; - - if num_rows_to_skip < chunk_len { - break; - } - - values = values_clone; - num_rows_to_skip -= chunk_len; - } - } + // Skip over whole HybridRleChunks + required_skip_whole_chunks(&mut values, &mut num_rows_to_skip)?; while let Some(chunk) = values.next_chunk()? { - debug_assert!(num_rows_to_skip < chunk.len()); + debug_assert!(num_rows_to_skip < chunk.len() || chunk.len() == 0); match chunk { HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + // If we know that we have `size` times `value` that we can append, but there might // be nulls in between those values. // @@ -76,42 +60,29 @@ pub fn decode( let current_filter; - (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; - let num_chunk_rows = current_filter - .sliced(num_rows_to_skip, current_filter.len() - num_rows_to_skip) - .set_bits(); + (current_filter, filter) = filter.split_at(size - num_rows_to_skip); + let num_chunk_rows = current_filter.set_bits(); if num_chunk_rows > 0 { - let target_slice; - // SAFETY: - // Given `filter_iter` before the `advance_by_bits`. - // - // 1. `target_ptr..target_ptr + filter_iter.count_ones()` is allocated - // 2. `num_chunk_rows < filter_iter.count_ones()` - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_chunk_rows); - target_ptr = target_ptr.add(num_chunk_rows); - } - - let Some(value) = dict.get(value as usize) else { + let Some(&value) = dict.get(value as usize) else { return Err(oob_dict_idx()); }; - target_slice.fill(*value); + target.resize(target.len() + num_chunk_rows, value); } }, HybridRleChunk::Bitpacked(mut decoder) => { + let size = decoder.len().min(filter.len()); + let mut chunked = decoder.chunked(); + let mut buffer_part_idx = 0; let mut values_offset = 0; let mut num_buffered: usize = 0; let mut skip_values = num_rows_to_skip; - let size = decoder.len(); - let mut chunked = decoder.chunked(); - let current_filter; - (current_filter, filter) = unsafe { filter.split_at_unchecked(size) }; + (current_filter, filter) = filter.split_at(size); let mut iter = |mut f: u64, len: usize| { debug_assert!(len <= 64); @@ -155,6 +126,7 @@ pub fn decode( let mut num_read = 0; let mut num_written = 0; + let target_ptr = unsafe { target.as_mut_ptr().add(target.len()) }; while f != 0 { let offset = f.trailing_zeros() as usize; @@ -180,7 +152,7 @@ pub fn decode( values_offset %= 128; num_buffered -= len; unsafe { - target_ptr = target_ptr.add(num_written); + target.set_len(target.len() + num_written); } ParquetResult::Ok(()) @@ -201,9 +173,5 @@ pub fn decode( num_rows_to_skip = 0; } - unsafe { - target.set_len(start_length + num_rows); - } - Ok(()) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs index 9cce1ee4c60c..d968e369b7c0 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs @@ -10,16 +10,16 @@ use arrow::types::{ Bytes4Alignment4, Bytes8Alignment8, }; -use super::utils::array_chunks::ArrayChunks; use super::dictionary_encoded::append_validity; +use super::utils::array_chunks::ArrayChunks; use super::utils::{dict_indices_decoder, freeze_validity, Decoder}; use super::Filter; use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; use crate::parquet::encoding::{hybrid_rle, Encoding}; use crate::parquet::error::{ParquetError, ParquetResult}; use crate::parquet::page::{split_buffer, DataPage, DictPage}; -use crate::read::deserialize::utils; use crate::read::deserialize::dictionary_encoded::constrain_page_validity; +use crate::read::deserialize::utils; #[allow(clippy::large_enum_variant)] #[derive(Debug)] diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs index d37f40a3337e..fd7c8833e3c3 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain.rs @@ -5,8 +5,8 @@ use arrow::types::{AlignedBytes, NativeType}; use super::DecoderFunction; use crate::parquet::error::ParquetResult; use crate::parquet::types::NativeType as ParquetNativeType; -use crate::read::deserialize::utils::array_chunks::ArrayChunks; use crate::read::deserialize::dictionary_encoded::{append_validity, constrain_page_validity}; +use crate::read::deserialize::utils::array_chunks::ArrayChunks; use crate::read::{Filter, ParquetError}; #[allow(clippy::too_many_arguments)] diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index f79723b0e58b..8ea7d2152bc0 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -2376,3 +2376,68 @@ def test_dict_slices( pl.scan_parquet(f).slice(offset, length).collect(), df.slice(offset, length), ) + + +@pytest.mark.parametrize( + "mask", + [ + [i % 13 < 3 and i % 17 > 3 for i in range(57 * 2)], + [False] * 23 + [True] * 68 + [False] * 23, + [False] * 23 + [True] * 24 + [False] * 20 + [True] * 24 + [False] * 23, + [True] + [False] * 22 + [True] * 24 + [False] * 20 + [True] * 24 + [False] * 23, + [False] * 23 + [True] * 24 + [False] * 20 + [True] * 24 + [False] * 22 + [True], + [True] + + [False] * 22 + + [True] * 24 + + [False] * 20 + + [True] * 24 + + [False] * 22 + + [True], + [False] * 56 + [True] * 58, + [False] * 57 + [True] * 57, + [False] * 58 + [True] * 56, + [True] * 56 + [False] * 58, + [True] * 57 + [False] * 57, + [True] * 58 + [False] * 56, + ], +) +@pytest.mark.parametrize( + "first_chunk", + # Create both RLE and Bitpacked chunks + [ + [1] * 57, + [1 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +@pytest.mark.parametrize( + "second_chunk", + # Create both RLE and Bitpacked chunks + [ + [2] * 57, + [2 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +def test_dict_masked( + mask: list[bool], + first_chunk: list[None | int], + second_chunk: list[None | int], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", first_chunk + second_chunk, pl.Int64), + pl.Series("f", mask, pl.Boolean), + ] + ) + + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f, parallel="prefiltered").filter(pl.col.f).collect(), + df.filter(pl.col.f), + )