-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add optimized row encoding for Decimals (#20050)
- Loading branch information
1 parent
74d059f
commit 49b2e7b
Showing
7 changed files
with
299 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
//! Row Encoding for Enum's and Categorical's | ||
//! | ||
//! This is a fixed-size encoding that takes a number of maximum bits that each value can take and | ||
//! compresses such that a minimum amount of bytes are used for each value. | ||
use std::mem::MaybeUninit; | ||
|
||
use arrow::array::{Array, PrimitiveArray}; | ||
use arrow::bitmap::MutableBitmap; | ||
use arrow::datatypes::ArrowDataType; | ||
use polars_utils::slice::Slice2Uninit; | ||
|
||
use crate::row::RowEncodingOptions; | ||
|
||
macro_rules! with_constant_num_bytes { | ||
($num_bytes:ident, $block:block) => { | ||
with_arms!( | ||
$num_bytes, | ||
$block, | ||
(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16) | ||
) | ||
}; | ||
} | ||
|
||
pub fn len_from_precision(precision: usize) -> usize { | ||
len_from_num_bits(num_bits_from_precision(precision)) | ||
} | ||
|
||
fn num_bits_from_precision(precision: usize) -> usize { | ||
assert!(precision <= 38); | ||
// This may seem random. But this is ceil(s * log2(10)) which is a reduction of | ||
// ceil(log2(10**s)) | ||
((precision as f32) * 10.0f32.log2()).ceil() as usize | ||
} | ||
|
||
fn len_from_num_bits(num_bits: usize) -> usize { | ||
// 1 bit is used to indicate the nullability | ||
// 1 bit is used to indicate the signedness | ||
(num_bits + 2).div_ceil(8) | ||
} | ||
|
||
pub unsafe fn encode( | ||
buffer: &mut [MaybeUninit<u8>], | ||
input: &PrimitiveArray<i128>, | ||
opt: RowEncodingOptions, | ||
offsets: &mut [usize], | ||
precision: usize, | ||
) { | ||
if input.null_count() == 0 { | ||
unsafe { encode_slice(buffer, input.values(), opt, offsets, precision) } | ||
} else { | ||
unsafe { | ||
encode_iter( | ||
buffer, | ||
input.iter().map(|v| v.copied()), | ||
opt, | ||
offsets, | ||
precision, | ||
) | ||
} | ||
} | ||
} | ||
|
||
pub unsafe fn encode_slice( | ||
buffer: &mut [MaybeUninit<u8>], | ||
input: &[i128], | ||
opt: RowEncodingOptions, | ||
offsets: &mut [usize], | ||
precision: usize, | ||
) { | ||
let num_bits = num_bits_from_precision(precision); | ||
|
||
// If the output will not fit in less bytes, just use the normal i128 encoding kernel. | ||
if num_bits >= 127 { | ||
super::numeric::encode_slice(buffer, input, opt, offsets); | ||
return; | ||
} | ||
|
||
let num_bytes = len_from_num_bits(num_bits); | ||
let mask = (1 << (num_bits + 1)) - 1; | ||
let valid_mask = ((!opt.null_sentinel() & 0x80) as i128) << ((num_bytes - 1) * 8); | ||
let sign_mask = 1 << num_bits; | ||
let invert_mask = if opt.contains(RowEncodingOptions::DESCENDING) { | ||
mask | ||
} else { | ||
0 | ||
}; | ||
|
||
with_constant_num_bytes!(num_bytes, { | ||
for (offset, &v) in offsets.iter_mut().zip(input) { | ||
let mut v = v; | ||
|
||
v &= mask; // Mask out higher sign extension bits | ||
v ^= sign_mask; // Flip sign-bit to maintain order | ||
v ^= invert_mask; // Invert for descending | ||
v |= valid_mask; // Add valid indicator | ||
|
||
unsafe { buffer.get_unchecked_mut(*offset..*offset + num_bytes) } | ||
.copy_from_slice(v.to_be_bytes()[16 - num_bytes..].as_uninit()); | ||
*offset += num_bytes; | ||
} | ||
}); | ||
} | ||
|
||
pub unsafe fn encode_iter( | ||
buffer: &mut [MaybeUninit<u8>], | ||
input: impl Iterator<Item = Option<i128>>, | ||
opt: RowEncodingOptions, | ||
offsets: &mut [usize], | ||
precision: usize, | ||
) { | ||
let num_bits = num_bits_from_precision(precision); | ||
// If the output will not fit in less bytes, just use the normal i128 encoding kernel. | ||
if num_bits >= 127 { | ||
super::numeric::encode_iter(buffer, input, opt, offsets); | ||
return; | ||
} | ||
|
||
let num_bytes = len_from_num_bits(num_bits); | ||
let null_value = (opt.null_sentinel() as i128) << ((num_bytes - 1) * 8); | ||
let mask = (1 << (num_bits + 1)) - 1; | ||
let valid_mask = ((!opt.null_sentinel() & 0x80) as i128) << ((num_bytes - 1) * 8); | ||
let sign_mask = 1 << num_bits; | ||
let invert_mask = if opt.contains(RowEncodingOptions::DESCENDING) { | ||
mask | ||
} else { | ||
0 | ||
}; | ||
|
||
with_constant_num_bytes!(num_bytes, { | ||
for (offset, v) in offsets.iter_mut().zip(input) { | ||
match v { | ||
None => { | ||
unsafe { buffer.get_unchecked_mut(*offset..*offset + num_bytes) } | ||
.copy_from_slice(null_value.to_be_bytes()[16 - num_bytes..].as_uninit()); | ||
}, | ||
Some(mut v) => { | ||
v &= mask; // Mask out higher sign extension bits | ||
v ^= sign_mask; // Flip sign-bit to maintain order | ||
v ^= invert_mask; // Invert for descending | ||
v |= valid_mask; // Add valid indicator | ||
|
||
unsafe { buffer.get_unchecked_mut(*offset..*offset + num_bytes) } | ||
.copy_from_slice(v.to_be_bytes()[16 - num_bytes..].as_uninit()); | ||
}, | ||
} | ||
|
||
*offset += num_bytes; | ||
} | ||
}); | ||
} | ||
|
||
pub unsafe fn decode( | ||
rows: &mut [&[u8]], | ||
opt: RowEncodingOptions, | ||
precision: usize, | ||
scale: usize, | ||
) -> PrimitiveArray<i128> { | ||
let num_bits = num_bits_from_precision(precision); | ||
// If the output will not fit in less bytes, just use the normal i128 decoding kernel. | ||
if num_bits >= 127 { | ||
let (_, values, validity) = super::numeric::decode_primitive(rows, opt).into_inner(); | ||
return PrimitiveArray::new(ArrowDataType::Decimal(precision, scale), values, validity); | ||
} | ||
|
||
let mut values = Vec::with_capacity(rows.len()); | ||
let null_sentinel = opt.null_sentinel(); | ||
|
||
let num_bytes = len_from_num_bits(num_bits); | ||
let mask = (1 << (num_bits + 1)) - 1; | ||
let sign_mask = 1 << num_bits; | ||
let invert_mask = if opt.contains(RowEncodingOptions::DESCENDING) { | ||
mask | ||
} else { | ||
0 | ||
}; | ||
|
||
with_constant_num_bytes!(num_bytes, { | ||
values.extend( | ||
rows.iter_mut() | ||
.take_while(|row| *unsafe { row.get_unchecked(0) } != null_sentinel) | ||
.map(|row| { | ||
let mut value = 0i128; | ||
let value_ref: &mut [u8; 16] = bytemuck::cast_mut(&mut value); | ||
value_ref[16 - num_bytes..].copy_from_slice(row.get_unchecked(..num_bytes)); | ||
*row = &row[num_bytes..]; | ||
|
||
if cfg!(target_endian = "little") { | ||
// Big-Endian -> Little-Endian | ||
value = value.swap_bytes(); | ||
} | ||
|
||
value ^= invert_mask; // Invert for descending | ||
value ^= sign_mask; // Flip sign bit to maintain order | ||
|
||
// Sign extend. This also masks out the valid bit. | ||
value <<= i128::BITS - num_bits as u32 - 1; | ||
value >>= i128::BITS - num_bits as u32 - 1; | ||
|
||
value | ||
}), | ||
); | ||
}); | ||
|
||
if values.len() == rows.len() { | ||
return PrimitiveArray::new( | ||
ArrowDataType::Decimal(precision, scale), | ||
values.into(), | ||
None, | ||
); | ||
} | ||
|
||
let mut validity = MutableBitmap::with_capacity(rows.len()); | ||
validity.extend_constant(values.len(), true); | ||
|
||
let start_len = values.len(); | ||
|
||
with_constant_num_bytes!(num_bytes, { | ||
values.extend(rows[start_len..].iter_mut().map(|row| { | ||
validity.push(*unsafe { row.get_unchecked(0) } != null_sentinel); | ||
|
||
let mut value = 0i128; | ||
let value_ref: &mut [u8; 16] = bytemuck::cast_mut(&mut value); | ||
value_ref[16 - num_bytes..].copy_from_slice(row.get_unchecked(..num_bytes)); | ||
*row = &row[num_bytes..]; | ||
|
||
if cfg!(target_endian = "little") { | ||
// Big-Endian -> Little-Endian | ||
value = value.swap_bytes(); | ||
} | ||
|
||
value ^= invert_mask; // Invert for descending | ||
value ^= sign_mask; // Flip sign bit to maintain order | ||
|
||
// Sign extend. This also masks out the valid bit. | ||
value <<= i128::BITS - num_bits as u32 - 1; | ||
value >>= i128::BITS - num_bits as u32 - 1; | ||
|
||
value | ||
})); | ||
}); | ||
|
||
PrimitiveArray::new( | ||
ArrowDataType::Decimal(precision, scale), | ||
values.into(), | ||
Some(validity.freeze()), | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,19 @@ | ||
macro_rules! with_arms { | ||
($num_bytes:ident, $block:block, ($($values:literal),+)) => { | ||
match $num_bytes { | ||
$( | ||
$values => { | ||
#[allow(non_upper_case_globals)] | ||
const $num_bytes: usize = $values; | ||
$block | ||
}, | ||
)+ | ||
_ => unreachable!(), | ||
} | ||
}; | ||
} | ||
|
||
pub mod boolean; | ||
pub mod decimal; | ||
pub mod numeric; | ||
pub mod packed_u32; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.