Skip to content

Commit

Permalink
feat: Add optimized row encoding for Decimals (#20050)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Nov 28, 2024
1 parent 74d059f commit 49b2e7b
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 42 deletions.
5 changes: 4 additions & 1 deletion crates/polars-row/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use arrow::datatypes::ArrowDataType;
use arrow::offset::OffsetsBuffer;

use self::encode::fixed_size;
use self::fixed::decimal;
use self::row::RowEncodingOptions;
use self::variable::utf8::decode_str;
use super::*;
Expand Down Expand Up @@ -151,7 +152,6 @@ fn dtype_and_data_to_encoded_item_len(

D::Union(_, _, _) => todo!(),
D::Map(_, _) => todo!(),
D::Decimal(_, _) => todo!(),
D::Decimal256(_, _) => todo!(),
D::Extension(_, _, _) => todo!(),
D::Unknown => todo!(),
Expand Down Expand Up @@ -326,6 +326,9 @@ unsafe fn decode(
.unwrap()
.to_boxed()
},

D::Decimal(precision, scale) => decimal::decode(rows, opt, *precision, *scale).to_boxed(),

dt => {
if matches!(dt, D::UInt32) {
if let Some(dict) = dict {
Expand Down
17 changes: 15 additions & 2 deletions crates/polars-row/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use arrow::bitmap::Bitmap;
use arrow::datatypes::ArrowDataType;
use arrow::types::Offset;

use crate::fixed::decimal;
use crate::fixed::numeric::FixedLengthEncoding;
use crate::row::{RowEncodingOptions, RowsEncoded};
use crate::widths::RowWidths;
Expand Down Expand Up @@ -554,6 +555,19 @@ unsafe fn encode_flat_array(
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
crate::fixed::boolean::encode_bool(buffer, array.iter(), opt, offsets);
},

// Needs to happen before numeric arm.
D::Decimal(precision, _) => decimal::encode(
buffer,
array
.as_any()
.downcast_ref::<PrimitiveArray<i128>>()
.unwrap(),
opt,
offsets,
*precision,
),

dt if dt.is_numeric() => {
if matches!(dt, D::UInt32) {
if let Some(dict) = dict {
Expand Down Expand Up @@ -607,7 +621,6 @@ unsafe fn encode_flat_array(
D::Dictionary(_, _, _) => todo!(),

D::FixedSizeBinary(_) => todo!(),
D::Decimal(_, _) => todo!(),
D::Decimal256(_, _) => todo!(),

D::Union(_, _, _) => todo!(),
Expand Down Expand Up @@ -842,7 +855,7 @@ pub fn fixed_size(dtype: &ArrowDataType, dict: Option<&RowEncodingCatOrder>) ->
Int16 => i16::ENCODED_LEN,
Int32 => i32::ENCODED_LEN,
Int64 => i64::ENCODED_LEN,
Decimal(_, _) => i128::ENCODED_LEN,
Decimal(precision, _) => decimal::len_from_precision(*precision),
Float32 => f32::ENCODED_LEN,
Float64 => f64::ENCODED_LEN,
Boolean => 1,
Expand Down
248 changes: 248 additions & 0 deletions crates/polars-row/src/fixed/decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
//! Row Encoding for Decimals
//!
//! This is a fixed-size encoding that derives the maximum number of bits a value can occupy from
//! the decimal precision and packs each value into the minimum number of bytes.
use std::mem::MaybeUninit;

use arrow::array::{Array, PrimitiveArray};
use arrow::bitmap::MutableBitmap;
use arrow::datatypes::ArrowDataType;
use polars_utils::slice::Slice2Uninit;

use crate::row::RowEncodingOptions;

/// Runs `$body` with `$width` rebound to a compile-time constant.
///
/// Delegates to `with_arms!` with one arm for every byte width from 1 to 16 (the size of an
/// `i128`). Any other runtime value ends up in the `unreachable!()` arm of `with_arms!`.
macro_rules! with_constant_num_bytes {
    ($width:ident, $body:block) => {
        with_arms!($width, $body, (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
    };
}

/// Returns how many bytes one row-encoded decimal with the given `precision` occupies.
///
/// # Panics
///
/// Panics if `precision` is larger than 38 (checked by `num_bits_from_precision`).
pub fn len_from_precision(precision: usize) -> usize {
    let value_bits = num_bits_from_precision(precision);
    len_from_num_bits(value_bits)
}

/// Returns the number of bits needed to hold the magnitude of a decimal with `precision` digits.
///
/// The result is `ceil(precision * log2(10))`, which is a reduction of
/// `ceil(log2(10^precision))`.
///
/// # Panics
///
/// Panics if `precision` is larger than 38.
fn num_bits_from_precision(precision: usize) -> usize {
    assert!(precision <= 38);

    let bits_per_digit = 10.0f32.log2();
    let fractional_bits = (precision as f32) * bits_per_digit;
    fractional_bits.ceil() as usize
}

/// Returns the number of whole bytes needed for a value of `num_bits` magnitude bits plus the
/// two bookkeeping bits of the encoding.
fn len_from_num_bits(num_bits: usize) -> usize {
    // One extra bit flags whether the value is null.
    const NULLABILITY_BITS: usize = 1;
    // One extra bit carries the sign.
    const SIGN_BITS: usize = 1;

    let total_bits = num_bits + NULLABILITY_BITS + SIGN_BITS;
    total_bits.div_ceil(u8::BITS as usize)
}

/// Row-encodes a decimal array, choosing the kernel based on whether `input` contains nulls.
///
/// Arrays without nulls go through [`encode_slice`] on the raw values; arrays with nulls go
/// through [`encode_iter`] on the optional values.
///
/// # Safety
///
/// The selected kernel writes into `buffer` without bounds checks, so for every row
/// `offsets[i]..offsets[i] + len` must lie inside `buffer`.
pub unsafe fn encode(
    buffer: &mut [MaybeUninit<u8>],
    input: &PrimitiveArray<i128>,
    opt: RowEncodingOptions,
    offsets: &mut [usize],
    precision: usize,
) {
    let has_nulls = input.null_count() > 0;

    if has_nulls {
        let optional_values = input.iter().map(|item| item.copied());
        unsafe { encode_iter(buffer, optional_values, opt, offsets, precision) }
    } else {
        unsafe { encode_slice(buffer, input.values(), opt, offsets, precision) }
    }
}

/// Row-encodes a slice of non-null decimal values using the compact fixed-width layout.
///
/// Each value is written big-endian into `len_from_precision(precision)` bytes at its offset,
/// and the offset is then advanced by that many bytes. When the precision needs 127 bits or
/// more there is nothing to save, so the regular `i128` kernel is used instead.
///
/// # Safety
///
/// Writes are unchecked: for every row `offsets[i]..offsets[i] + len` must lie inside `buffer`.
/// NOTE(review): the `i128` fallback presumably has the same requirement — confirm against
/// `numeric::encode_slice`.
pub unsafe fn encode_slice(
    buffer: &mut [MaybeUninit<u8>],
    input: &[i128],
    opt: RowEncodingOptions,
    offsets: &mut [usize],
    precision: usize,
) {
    let num_bits = num_bits_from_precision(precision);

    // No byte can be saved compared to a full i128, so defer to the generic kernel.
    if num_bits >= 127 {
        super::numeric::encode_slice(buffer, input, opt, offsets);
        return;
    }

    let num_bytes = len_from_num_bits(num_bits);
    let top_byte_shift = (num_bytes - 1) * 8;

    // The sign bit sits directly above the `num_bits` magnitude bits.
    let sign_bit: i128 = 1 << num_bits;
    // Covers the magnitude bits and the sign bit.
    let payload_mask: i128 = (1 << (num_bits + 1)) - 1;
    // Highest bit of the first output byte, set to the opposite of the null sentinel's top bit.
    let valid_bit = ((!opt.null_sentinel() & 0x80) as i128) << top_byte_shift;
    let descending_flip: i128 = if opt.contains(RowEncodingOptions::DESCENDING) {
        payload_mask
    } else {
        0
    };

    with_constant_num_bytes!(num_bytes, {
        for (offset, &raw) in offsets.iter_mut().zip(input) {
            // 1. drop the sign-extension bits above the payload
            // 2. flip the sign bit so the byte order matches the numeric order
            // 3. invert the payload for descending order
            // 4. tag the value as valid
            let encoded = ((raw & payload_mask) ^ sign_bit ^ descending_flip) | valid_bit;
            let be_bytes = encoded.to_be_bytes();
            let start = *offset;

            let dst = unsafe { buffer.get_unchecked_mut(start..start + num_bytes) };
            dst.copy_from_slice(be_bytes[16 - num_bytes..].as_uninit());

            *offset = start + num_bytes;
        }
    });
}

/// Row-encodes optional decimal values using the compact fixed-width layout.
///
/// Valid values are encoded exactly like in [`encode_slice`]. A null is written as the null
/// sentinel in the first byte followed by zero bytes. Every row, null or not, advances its
/// offset by the encoded width. When the precision needs 127 bits or more the regular `i128`
/// kernel is used instead.
///
/// # Safety
///
/// Writes are unchecked: for every row `offsets[i]..offsets[i] + len` must lie inside `buffer`.
/// NOTE(review): the `i128` fallback presumably has the same requirement — confirm against
/// `numeric::encode_iter`.
pub unsafe fn encode_iter(
    buffer: &mut [MaybeUninit<u8>],
    input: impl Iterator<Item = Option<i128>>,
    opt: RowEncodingOptions,
    offsets: &mut [usize],
    precision: usize,
) {
    let num_bits = num_bits_from_precision(precision);

    // No byte can be saved compared to a full i128, so defer to the generic kernel.
    if num_bits >= 127 {
        super::numeric::encode_iter(buffer, input, opt, offsets);
        return;
    }

    let num_bytes = len_from_num_bits(num_bits);
    let top_byte_shift = (num_bytes - 1) * 8;

    // Encoded form of a null: sentinel in the first byte, zeroes after it.
    let null_value = (opt.null_sentinel() as i128) << top_byte_shift;
    // The sign bit sits directly above the `num_bits` magnitude bits.
    let sign_bit: i128 = 1 << num_bits;
    // Covers the magnitude bits and the sign bit.
    let payload_mask: i128 = (1 << (num_bits + 1)) - 1;
    // Highest bit of the first output byte, set to the opposite of the null sentinel's top bit.
    let valid_bit = ((!opt.null_sentinel() & 0x80) as i128) << top_byte_shift;
    let descending_flip: i128 = if opt.contains(RowEncodingOptions::DESCENDING) {
        payload_mask
    } else {
        0
    };

    with_constant_num_bytes!(num_bytes, {
        for (offset, item) in offsets.iter_mut().zip(input) {
            let encoded = match item {
                // Mask off sign extension, flip the sign bit for ordering, invert for
                // descending order and finally tag the value as valid.
                Some(raw) => ((raw & payload_mask) ^ sign_bit ^ descending_flip) | valid_bit,
                None => null_value,
            };
            let be_bytes = encoded.to_be_bytes();
            let start = *offset;

            let dst = unsafe { buffer.get_unchecked_mut(start..start + num_bytes) };
            dst.copy_from_slice(be_bytes[16 - num_bytes..].as_uninit());

            *offset = start + num_bytes;
        }
    });
}

/// Decodes one compactly row-encoded decimal from the front of every row.
///
/// Each row slice is advanced past the bytes that were consumed. The result is a
/// `Decimal(precision, scale)` array; a validity bitmap is only attached when at least one row
/// starts with the null sentinel. When the precision needs 127 bits or more the regular `i128`
/// kernel is used instead.
///
/// # Safety
///
/// Rows are read without bounds checks when probing the first byte and when copying the
/// value, so every row must hold at least `len_from_precision(precision)` bytes.
/// NOTE(review): the `i128` fallback presumably has an analogous requirement — confirm against
/// `numeric::decode_primitive`.
pub unsafe fn decode(
    rows: &mut [&[u8]],
    opt: RowEncodingOptions,
    precision: usize,
    scale: usize,
) -> PrimitiveArray<i128> {
    let num_bits = num_bits_from_precision(precision);
    let dtype = ArrowDataType::Decimal(precision, scale);

    // No byte was saved compared to a full i128, so the generic kernel produced these rows.
    if num_bits >= 127 {
        let (_, values, validity) = super::numeric::decode_primitive(rows, opt).into_inner();
        return PrimitiveArray::new(dtype, values, validity);
    }

    let null_sentinel = opt.null_sentinel();
    let num_bytes = len_from_num_bits(num_bits);

    // The sign bit sits directly above the `num_bits` magnitude bits.
    let sign_bit: i128 = 1 << num_bits;
    // Covers the magnitude bits and the sign bit.
    let payload_mask: i128 = (1 << (num_bits + 1)) - 1;
    let descending_flip: i128 = if opt.contains(RowEncodingOptions::DESCENDING) {
        payload_mask
    } else {
        0
    };
    // Shifting left and then (arithmetically) right by this amount sign-extends from the sign
    // bit and, as a side effect, discards the valid indicator above it.
    let sign_extend_shift = i128::BITS - num_bits as u32 - 1;

    let num_rows = rows.len();
    let mut values: Vec<i128> = Vec::with_capacity(num_rows);

    // First pass: decode the leading run of valid rows without tracking validity.
    with_constant_num_bytes!(num_bytes, {
        let leading_valid = rows
            .iter_mut()
            .take_while(|row| *unsafe { row.get_unchecked(0) } != null_sentinel);

        values.extend(leading_valid.map(|row| {
            // Right-align the big-endian payload inside a zeroed 16-byte buffer.
            let mut be_bytes = [0u8; 16];
            be_bytes[16 - num_bytes..].copy_from_slice(unsafe { row.get_unchecked(..num_bytes) });
            *row = &row[num_bytes..];

            // Undo the descending inversion, then the sign-bit flip.
            let ordered = i128::from_be_bytes(be_bytes) ^ descending_flip ^ sign_bit;
            (ordered << sign_extend_shift) >> sign_extend_shift
        }));
    });

    let num_leading_valid = values.len();
    if num_leading_valid == num_rows {
        return PrimitiveArray::new(dtype, values.into(), None);
    }

    // Second pass: a null was hit, so the remaining rows also record their validity.
    let mut validity = MutableBitmap::with_capacity(num_rows);
    validity.extend_constant(num_leading_valid, true);

    with_constant_num_bytes!(num_bytes, {
        for row in rows[num_leading_valid..].iter_mut() {
            let is_valid = *unsafe { row.get_unchecked(0) } != null_sentinel;
            validity.push(is_valid);

            // Null rows are decoded as well; their value is masked out by the validity bitmap.
            let mut be_bytes = [0u8; 16];
            be_bytes[16 - num_bytes..].copy_from_slice(unsafe { row.get_unchecked(..num_bytes) });
            *row = &row[num_bytes..];

            // Undo the descending inversion, then the sign-bit flip.
            let ordered = i128::from_be_bytes(be_bytes) ^ descending_flip ^ sign_bit;
            values.push((ordered << sign_extend_shift) >> sign_extend_shift);
        }
    });

    PrimitiveArray::new(dtype, values.into(), Some(validity.freeze()))
}
16 changes: 16 additions & 0 deletions crates/polars-row/src/fixed/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/// Turns a runtime `usize` into a compile-time constant for the duration of a block.
///
/// Expands to a `match` on `$width` with one arm per literal in the list. Inside each arm a
/// `const` with the same name as the matched variable is declared with that literal value, so
/// `$body` sees the name as a constant. A value that is not in the list hits `unreachable!()`.
macro_rules! with_arms {
    ($width:ident, $body:block, ($($candidate:literal),+)) => {
        match $width {
            $($candidate => {
                // The constant deliberately reuses the lower-case variable name.
                #[allow(non_upper_case_globals)]
                const $width: usize = $candidate;
                $body
            },)+
            _ => unreachable!(),
        }
    };
}

pub mod boolean;
pub mod decimal;
pub mod numeric;
pub mod packed_u32;
24 changes: 1 addition & 23 deletions crates/polars-row/src/fixed/packed_u32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,7 @@ pub fn len_from_num_bits(num_bits: usize) -> usize {

/// Runs `$block` with `$num_bytes` rebound to a compile-time constant.
///
/// Delegates to `with_arms!` with one arm for each byte width from 1 to 4. Any other runtime
/// value ends up in the `unreachable!()` arm of `with_arms!`.
///
/// The hand-written four-arm `match` that used to live here is gone: keeping it next to the
/// `with_arms!` call would emit two expressions from one transcriber and break every expansion.
macro_rules! with_constant_num_bytes {
    ($num_bytes:ident, $block:block) => {
        with_arms!($num_bytes, $block, (1, 2, 3, 4))
    };
}

Expand Down
15 changes: 0 additions & 15 deletions crates/polars-row/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@ macro_rules! with_match_arrow_primitive_type {(
}
})}

/// Expands `$body` once for the Rust integer type that corresponds to an arrow `IntegerType`.
///
/// The caller passes the key type expression followed by a closure-like `|$T| body`. The `$`
/// token written by the caller is captured as `$_` so that the helper macro defined inside the
/// expansion can declare its own `$T` metavariable. `Int8`, `Int16`, `Int32` and `Int64` map to
/// `i8`, `i16`, `i32` and `i64`; every other variant hits `unreachable!()`.
#[macro_export]
macro_rules! with_match_arrow_integer_type {(
    $key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    // Helper that substitutes the concrete type for `$T` inside the caller's body.
    macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
    use arrow::datatypes::IntegerType::*;
    match $key_type {
        Int8 => __with_ty__! { i8 },
        Int16 => __with_ty__! { i16 },
        Int32 => __with_ty__! { i32 },
        Int64 => __with_ty__! { i64 },
        // Only the signed variants above are handled.
        _ => unreachable!(),
    }
})}

pub(crate) unsafe fn decode_opt_nulls(rows: &[&[u8]], null_sentinel: u8) -> Option<Bitmap> {
let first_null = rows
.iter()
Expand Down
Loading

0 comments on commit 49b2e7b

Please sign in to comment.