From 7aa8bddeb09420b2f81a50112603de28aeaf3be7 Mon Sep 17 00:00:00 2001 From: Scott Donnelly Date: Thu, 29 Aug 2024 04:37:48 +0100 Subject: [PATCH] Table Scan: Add Row Group Skipping (#558) * feat(scan): add row group and page index row selection filtering * fix(row selection): off-by-one error * feat: remove row selection to defer to a second PR * feat: better min/max val conversion in RowGroupMetricsEvaluator * test(row_group_filtering): first three tests * test(row_group_filtering): next few tests * test: add more tests for RowGroupMetricsEvaluator * chore: refactor test assertions to silence clippy lints * refactor: consolidate parquet stat min/max parsing in one place --- Cargo.toml | 2 + crates/iceberg/Cargo.toml | 2 + crates/iceberg/src/arrow/reader.rs | 210 +- crates/iceberg/src/arrow/schema.rs | 103 + crates/iceberg/src/expr/visitors/mod.rs | 1 + .../visitors/row_group_metrics_evaluator.rs | 1872 +++++++++++++++++ crates/iceberg/src/scan.rs | 23 +- .../src/writer/file_writer/parquet_writer.rs | 212 +- 8 files changed, 2187 insertions(+), 238 deletions(-) create mode 100644 crates/iceberg/src/expr/visitors/row_group_metrics_evaluator.rs diff --git a/Cargo.toml b/Cargo.toml index b59d4326e..8d04f6799 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,9 +72,11 @@ once_cell = "1" opendal = "0.49" ordered-float = "4" parquet = "52" +paste = "1" pilota = "0.11.2" pretty_assertions = "1.4" port_scanner = "0.1.5" +rand = "0.8" regex = "1.10.5" reqwest = { version = "0.12", default-features = false, features = ["json"] } rust_decimal = "1.31" diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml index 6218e98e5..6166d360d 100644 --- a/crates/iceberg/Cargo.toml +++ b/crates/iceberg/Cargo.toml @@ -66,6 +66,7 @@ once_cell = { workspace = true } opendal = { workspace = true } ordered-float = { workspace = true } parquet = { workspace = true, features = ["async"] } +paste = { workspace = true } reqwest = { workspace = true } rust_decimal = { workspace = true } 
serde = { workspace = true } @@ -84,5 +85,6 @@ ctor = { workspace = true } iceberg-catalog-memory = { workspace = true } iceberg_test_utils = { path = "../test_utils", features = ["tests"] } pretty_assertions = { workspace = true } +rand = { workspace = true } tempfile = { workspace = true } tera = { workspace = true } diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs index ebef735b1..b058c8d25 100644 --- a/crates/iceberg/src/arrow/reader.rs +++ b/crates/iceberg/src/arrow/reader.rs @@ -23,7 +23,7 @@ use std::str::FromStr; use std::sync::Arc; use arrow_arith::boolean::{and, is_not_null, is_null, not, or}; -use arrow_array::{ArrayRef, BooleanArray, RecordBatch}; +use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch}; use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow_schema::{ArrowError, DataType, SchemaRef as ArrowSchemaRef}; use arrow_string::like::starts_with; @@ -32,7 +32,7 @@ use fnv::FnvHashSet; use futures::channel::mpsc::{channel, Sender}; use futures::future::BoxFuture; use futures::{try_join, SinkExt, StreamExt, TryFutureExt, TryStreamExt}; -use parquet::arrow::arrow_reader::{ArrowPredicateFn, RowFilter}; +use parquet::arrow::arrow_reader::{ArrowPredicateFn, ArrowReaderOptions, RowFilter}; use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader}; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask, PARQUET_FIELD_ID_META_KEY}; use parquet::file::metadata::ParquetMetaData; @@ -41,6 +41,7 @@ use parquet::schema::types::{SchemaDescriptor, Type as ParquetType}; use crate::arrow::{arrow_schema_to_schema, get_arrow_datum}; use crate::error::Result; use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor}; +use crate::expr::visitors::row_group_metrics_evaluator::RowGroupMetricsEvaluator; use crate::expr::{BoundPredicate, BoundReference}; use crate::io::{FileIO, FileMetadata, FileRead}; use crate::runtime::spawn; @@ -54,6 +55,7 @@ pub struct 
ArrowReaderBuilder { batch_size: Option, file_io: FileIO, concurrency_limit_data_files: usize, + row_group_filtering_enabled: bool, } impl ArrowReaderBuilder { @@ -65,13 +67,13 @@ impl ArrowReaderBuilder { batch_size: None, file_io, concurrency_limit_data_files: num_cpus, + row_group_filtering_enabled: true, } } /// Sets the max number of in flight data files that are being fetched pub fn with_data_file_concurrency_limit(mut self, val: usize) -> Self { self.concurrency_limit_data_files = val; - self } @@ -82,12 +84,19 @@ impl ArrowReaderBuilder { self } + /// Determines whether to enable row group filtering. + pub fn with_row_group_filtering_enabled(mut self, row_group_filtering_enabled: bool) -> Self { + self.row_group_filtering_enabled = row_group_filtering_enabled; + self + } + /// Build the ArrowReader. pub fn build(self) -> ArrowReader { ArrowReader { batch_size: self.batch_size, file_io: self.file_io, concurrency_limit_data_files: self.concurrency_limit_data_files, + row_group_filtering_enabled: self.row_group_filtering_enabled, } } } @@ -100,6 +109,8 @@ pub struct ArrowReader { /// the maximum number of data files that can be fetched at the same time concurrency_limit_data_files: usize, + + row_group_filtering_enabled: bool, } impl ArrowReader { @@ -109,6 +120,7 @@ impl ArrowReader { let file_io = self.file_io.clone(); let batch_size = self.batch_size; let concurrency_limit_data_files = self.concurrency_limit_data_files; + let row_group_filtering_enabled = self.row_group_filtering_enabled; let (tx, rx) = channel(concurrency_limit_data_files); let mut channel_for_error = tx.clone(); @@ -124,8 +136,14 @@ impl ArrowReader { let file_path = task.data_file_path().to_string(); spawn(async move { - Self::process_file_scan_task(task, batch_size, file_io, tx) - .await + Self::process_file_scan_task( + task, + batch_size, + file_io, + tx, + row_group_filtering_enabled, + ) + .await }) .await .map_err(|e| e.with_context("file_path", file_path)) @@ -149,55 +167,95 @@ 
impl ArrowReader { batch_size: Option, file_io: FileIO, mut tx: Sender>, + row_group_filtering_enabled: bool, ) -> Result<()> { - // Collect Parquet column indices from field ids - let mut collector = CollectFieldIdVisitor { - field_ids: HashSet::default(), - }; - - if let Some(predicates) = task.predicate() { - visit(&mut collector, predicates)?; - } - + // Get the metadata for the Parquet file we need to read and build + // a reader for the data within let parquet_file = file_io.new_input(task.data_file_path())?; - let (parquet_metadata, parquet_reader) = try_join!(parquet_file.metadata(), parquet_file.reader())?; - let arrow_file_reader = ArrowFileReader::new(parquet_metadata, parquet_reader); + let parquet_file_reader = ArrowFileReader::new(parquet_metadata, parquet_reader); - let mut batch_stream_builder = - ParquetRecordBatchStreamBuilder::new(arrow_file_reader).await?; + // Start creating the record batch stream, which wraps the parquet file reader + let mut record_batch_stream_builder = ParquetRecordBatchStreamBuilder::new_with_options( + parquet_file_reader, + // Page index will be required in upcoming row selection PR + ArrowReaderOptions::new().with_page_index(false), + ) + .await?; - let parquet_schema = batch_stream_builder.parquet_schema(); - let arrow_schema = batch_stream_builder.schema(); + // Create a projection mask for the batch stream to select which columns in the + // Parquet file that we want in the response let projection_mask = Self::get_arrow_projection_mask( task.project_field_ids(), task.schema(), - parquet_schema, - arrow_schema, + record_batch_stream_builder.parquet_schema(), + record_batch_stream_builder.schema(), )?; - batch_stream_builder = batch_stream_builder.with_projection(projection_mask); - - let parquet_schema = batch_stream_builder.parquet_schema(); - let row_filter = Self::get_row_filter(task.predicate(), parquet_schema, &collector)?; - - if let Some(row_filter) = row_filter { - batch_stream_builder = 
batch_stream_builder.with_row_filter(row_filter); - } + record_batch_stream_builder = record_batch_stream_builder.with_projection(projection_mask); if let Some(batch_size) = batch_size { - batch_stream_builder = batch_stream_builder.with_batch_size(batch_size); + record_batch_stream_builder = record_batch_stream_builder.with_batch_size(batch_size); } - let mut batch_stream = batch_stream_builder.build()?; + if let Some(predicate) = task.predicate() { + let (iceberg_field_ids, field_id_map) = Self::build_field_id_set_and_map( + record_batch_stream_builder.parquet_schema(), + predicate, + )?; + + let row_filter = Self::get_row_filter( + predicate, + record_batch_stream_builder.parquet_schema(), + &iceberg_field_ids, + &field_id_map, + )?; + record_batch_stream_builder = record_batch_stream_builder.with_row_filter(row_filter); + + let mut selected_row_groups = None; + if row_group_filtering_enabled { + let result = Self::get_selected_row_group_indices( + predicate, + record_batch_stream_builder.metadata(), + &field_id_map, + task.schema(), + )?; + + selected_row_groups = Some(result); + } + + if let Some(selected_row_groups) = selected_row_groups { + record_batch_stream_builder = + record_batch_stream_builder.with_row_groups(selected_row_groups); + } + } - while let Some(batch) = batch_stream.try_next().await? { + // Build the batch stream and send all the RecordBatches that it generates + // to the requester. + let mut record_batch_stream = record_batch_stream_builder.build()?; + while let Some(batch) = record_batch_stream.try_next().await? { tx.send(Ok(batch)).await? 
} Ok(()) } + fn build_field_id_set_and_map( + parquet_schema: &SchemaDescriptor, + predicate: &BoundPredicate, + ) -> Result<(HashSet, HashMap)> { + // Collects all Iceberg field IDs referenced in the filter predicate + let mut collector = CollectFieldIdVisitor { + field_ids: HashSet::default(), + }; + visit(&mut collector, predicate)?; + + let iceberg_field_ids = collector.field_ids(); + let field_id_map = build_field_id_map(parquet_schema)?; + + Ok((iceberg_field_ids, field_id_map)) + } + fn get_arrow_projection_mask( field_ids: &[i32], iceberg_schema_of_task: &Schema, @@ -269,43 +327,59 @@ impl ArrowReader { } fn get_row_filter( - predicates: Option<&BoundPredicate>, + predicates: &BoundPredicate, parquet_schema: &SchemaDescriptor, - collector: &CollectFieldIdVisitor, - ) -> Result> { - if let Some(predicates) = predicates { - let field_id_map = build_field_id_map(parquet_schema)?; - - // Collect Parquet column indices from field ids. - // If the field id is not found in Parquet schema, it will be ignored due to schema evolution. - let mut column_indices = collector - .field_ids - .iter() - .filter_map(|field_id| field_id_map.get(field_id).cloned()) - .collect::>(); - - column_indices.sort(); - - // The converter that converts `BoundPredicates` to `ArrowPredicates` - let mut converter = PredicateConverter { - parquet_schema, - column_map: &field_id_map, - column_indices: &column_indices, - }; - - // After collecting required leaf column indices used in the predicate, - // creates the projection mask for the Arrow predicates. - let projection_mask = ProjectionMask::leaves(parquet_schema, column_indices.clone()); - let predicate_func = visit(&mut converter, predicates)?; - let arrow_predicate = ArrowPredicateFn::new(projection_mask, predicate_func); - Ok(Some(RowFilter::new(vec![Box::new(arrow_predicate)]))) - } else { - Ok(None) + iceberg_field_ids: &HashSet, + field_id_map: &HashMap, + ) -> Result { + // Collect Parquet column indices from field ids. 
+ // If the field id is not found in Parquet schema, it will be ignored due to schema evolution. + let mut column_indices = iceberg_field_ids + .iter() + .filter_map(|field_id| field_id_map.get(field_id).cloned()) + .collect::>(); + column_indices.sort(); + + // The converter that converts `BoundPredicates` to `ArrowPredicates` + let mut converter = PredicateConverter { + parquet_schema, + column_map: field_id_map, + column_indices: &column_indices, + }; + + // After collecting required leaf column indices used in the predicate, + // creates the projection mask for the Arrow predicates. + let projection_mask = ProjectionMask::leaves(parquet_schema, column_indices.clone()); + let predicate_func = visit(&mut converter, predicates)?; + let arrow_predicate = ArrowPredicateFn::new(projection_mask, predicate_func); + Ok(RowFilter::new(vec![Box::new(arrow_predicate)])) + } + + fn get_selected_row_group_indices( + predicate: &BoundPredicate, + parquet_metadata: &Arc, + field_id_map: &HashMap, + snapshot_schema: &Schema, + ) -> Result> { + let row_groups_metadata = parquet_metadata.row_groups(); + let mut results = Vec::with_capacity(row_groups_metadata.len()); + + for (idx, row_group_metadata) in row_groups_metadata.iter().enumerate() { + if RowGroupMetricsEvaluator::eval( + predicate, + row_group_metadata, + field_id_map, + snapshot_schema, + )? { + results.push(idx); + } } + + Ok(results) } } -/// Build the map of field id to Parquet column index in the schema. +/// Build the map of parquet field id to Parquet column index in the schema. 
fn build_field_id_map(parquet_schema: &SchemaDescriptor) -> Result> { let mut column_map = HashMap::new(); for (idx, field) in parquet_schema.columns().iter().enumerate() { @@ -345,6 +419,12 @@ struct CollectFieldIdVisitor { field_ids: HashSet, } +impl CollectFieldIdVisitor { + fn field_ids(self) -> HashSet { + self.field_ids + } +} + impl BoundPredicateVisitor for CollectFieldIdVisitor { type T = (); diff --git a/crates/iceberg/src/arrow/schema.rs b/crates/iceberg/src/arrow/schema.rs index a41243756..2ff43e0f0 100644 --- a/crates/iceberg/src/arrow/schema.rs +++ b/crates/iceberg/src/arrow/schema.rs @@ -30,7 +30,9 @@ use arrow_array::{ use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit}; use bitvec::macros::internal::funty::Fundamental; use parquet::arrow::PARQUET_FIELD_ID_META_KEY; +use parquet::file::statistics::Statistics; use rust_decimal::prelude::ToPrimitive; +use uuid::Uuid; use crate::error::Result; use crate::spec::{ @@ -652,6 +654,107 @@ pub(crate) fn get_arrow_datum(datum: &Datum) -> Result { + paste::paste! 
{ + /// Gets the $limit_type value from a parquet Statistics struct, as a Datum + pub(crate) fn []( + primitive_type: &PrimitiveType, stats: &Statistics + ) -> Result> { + Ok(Some(match (primitive_type, stats) { + (PrimitiveType::Boolean, Statistics::Boolean(stats)) => Datum::bool(*stats.$limit_type()), + (PrimitiveType::Int, Statistics::Int32(stats)) => Datum::int(*stats.$limit_type()), + (PrimitiveType::Date, Statistics::Int32(stats)) => Datum::date(*stats.$limit_type()), + (PrimitiveType::Long, Statistics::Int64(stats)) => Datum::long(*stats.$limit_type()), + (PrimitiveType::Time, Statistics::Int64(stats)) => Datum::time_micros(*stats.$limit_type())?, + (PrimitiveType::Timestamp, Statistics::Int64(stats)) => { + Datum::timestamp_micros(*stats.$limit_type()) + } + (PrimitiveType::Timestamptz, Statistics::Int64(stats)) => { + Datum::timestamptz_micros(*stats.$limit_type()) + } + (PrimitiveType::TimestampNs, Statistics::Int64(stats)) => { + Datum::timestamp_nanos(*stats.$limit_type()) + } + (PrimitiveType::TimestamptzNs, Statistics::Int64(stats)) => { + Datum::timestamptz_nanos(*stats.$limit_type()) + } + (PrimitiveType::Float, Statistics::Float(stats)) => Datum::float(*stats.$limit_type()), + (PrimitiveType::Double, Statistics::Double(stats)) => Datum::double(*stats.$limit_type()), + (PrimitiveType::String, Statistics::ByteArray(stats)) => { + Datum::string(stats.$limit_type().as_utf8()?) 
+ } + (PrimitiveType::Decimal { + precision: _, + scale: _, + }, Statistics::ByteArray(stats)) => { + Datum::new( + primitive_type.clone(), + PrimitiveLiteral::Int128(i128::from_le_bytes(stats.[<$limit_type _bytes>]().try_into()?)), + ) + } + ( + PrimitiveType::Decimal { + precision: _, + scale: _, + }, + Statistics::Int32(stats)) => { + Datum::new( + primitive_type.clone(), + PrimitiveLiteral::Int128(i128::from(*stats.$limit_type())), + ) + } + + ( + PrimitiveType::Decimal { + precision: _, + scale: _, + }, + Statistics::Int64(stats), + ) => { + Datum::new( + primitive_type.clone(), + PrimitiveLiteral::Int128(i128::from(*stats.$limit_type())), + ) + } + (PrimitiveType::Uuid, Statistics::FixedLenByteArray(stats)) => { + let raw = stats.[<$limit_type _bytes>](); + if raw.len() != 16 { + return Err(Error::new( + ErrorKind::Unexpected, + "Invalid length of uuid bytes.", + )); + } + Datum::uuid(Uuid::from_bytes( + raw[..16].try_into().unwrap(), + )) + } + (PrimitiveType::Fixed(len), Statistics::FixedLenByteArray(stat)) => { + let raw = stat.[<$limit_type _bytes>](); + if raw.len() != *len as usize { + return Err(Error::new( + ErrorKind::Unexpected, + "Invalid length of fixed bytes.", + )); + } + Datum::fixed(raw.to_vec()) + } + (PrimitiveType::Binary, Statistics::ByteArray(stat)) => { + Datum::binary(stat.[<$limit_type _bytes>]().to_vec()) + } + _ => { + return Ok(None); + } + })) + } + } + } +} + +get_parquet_stat_as_datum!(min); + +get_parquet_stat_as_datum!(max); + impl TryFrom<&ArrowSchema> for crate::spec::Schema { type Error = Error; diff --git a/crates/iceberg/src/expr/visitors/mod.rs b/crates/iceberg/src/expr/visitors/mod.rs index d686b1173..06bfd8cda 100644 --- a/crates/iceberg/src/expr/visitors/mod.rs +++ b/crates/iceberg/src/expr/visitors/mod.rs @@ -20,3 +20,4 @@ pub(crate) mod expression_evaluator; pub(crate) mod inclusive_metrics_evaluator; pub(crate) mod inclusive_projection; pub(crate) mod manifest_evaluator; +pub(crate) mod row_group_metrics_evaluator; 
diff --git a/crates/iceberg/src/expr/visitors/row_group_metrics_evaluator.rs b/crates/iceberg/src/expr/visitors/row_group_metrics_evaluator.rs new file mode 100644 index 000000000..4bf53d6ee --- /dev/null +++ b/crates/iceberg/src/expr/visitors/row_group_metrics_evaluator.rs @@ -0,0 +1,1872 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Evaluates Parquet Row Group metrics + +use std::collections::HashMap; + +use fnv::FnvHashSet; +use parquet::file::metadata::RowGroupMetaData; +use parquet::file::statistics::Statistics; + +use crate::arrow::{get_parquet_stat_max_as_datum, get_parquet_stat_min_as_datum}; +use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor}; +use crate::expr::{BoundPredicate, BoundReference}; +use crate::spec::{Datum, PrimitiveLiteral, PrimitiveType, Schema}; +use crate::{Error, ErrorKind, Result}; + +pub(crate) struct RowGroupMetricsEvaluator<'a> { + row_group_metadata: &'a RowGroupMetaData, + iceberg_field_id_to_parquet_column_index: &'a HashMap, + snapshot_schema: &'a Schema, +} + +const IN_PREDICATE_LIMIT: usize = 200; +const ROW_GROUP_MIGHT_MATCH: Result = Ok(true); +const ROW_GROUP_CANT_MATCH: Result = Ok(false); + +impl<'a> RowGroupMetricsEvaluator<'a> { + fn new( + row_group_metadata: &'a RowGroupMetaData, + field_id_map: &'a HashMap, + snapshot_schema: &'a Schema, + ) -> Self { + Self { + row_group_metadata, + iceberg_field_id_to_parquet_column_index: field_id_map, + snapshot_schema, + } + } + + /// Evaluate this `RowGroupMetricsEvaluator`'s filter predicate against the + /// provided [`RowGroupMetaData`]'. Used by [`ArrowReader`] to + /// see if a Parquet file RowGroup could possibly contain data that matches + /// the scan's filter. 
+ pub(crate) fn eval( + filter: &'a BoundPredicate, + row_group_metadata: &'a RowGroupMetaData, + field_id_map: &'a HashMap, + snapshot_schema: &'a Schema, + ) -> Result { + if row_group_metadata.num_rows() == 0 { + return ROW_GROUP_CANT_MATCH; + } + + let mut evaluator = Self::new(row_group_metadata, field_id_map, snapshot_schema); + + visit(&mut evaluator, filter) + } + + fn stats_for_field_id(&self, field_id: i32) -> Option<&Statistics> { + let parquet_column_index = *self + .iceberg_field_id_to_parquet_column_index + .get(&field_id)?; + self.row_group_metadata + .column(parquet_column_index) + .statistics() + } + + fn null_count(&self, field_id: i32) -> Option { + self.stats_for_field_id(field_id) + .map(|stats| stats.null_count()) + } + + fn value_count(&self) -> u64 { + self.row_group_metadata.num_rows() as u64 + } + + fn contains_nulls_only(&self, field_id: i32) -> bool { + let null_count = self.null_count(field_id); + let value_count = self.value_count(); + + null_count == Some(value_count) + } + + fn may_contain_null(&self, field_id: i32) -> bool { + if let Some(null_count) = self.null_count(field_id) { + null_count > 0 + } else { + true + } + } + + fn stats_and_type_for_field_id( + &self, + field_id: i32, + ) -> Result> { + let Some(stats) = self.stats_for_field_id(field_id) else { + // No statistics for column + return Ok(None); + }; + + let Some(field) = self.snapshot_schema.field_by_id(field_id) else { + return Err(Error::new( + ErrorKind::Unexpected, + format!( + "Could not find a field with id '{}' in the snapshot schema", + &field_id + ), + )); + }; + + let Some(primitive_type) = field.field_type.as_primitive_type() else { + return Err(Error::new( + ErrorKind::Unexpected, + format!( + "Could not determine the PrimitiveType for field id '{}'", + &field_id + ), + )); + }; + + Ok(Some((stats, primitive_type.clone()))) + } + + fn min_value(&self, field_id: i32) -> Result> { + let Some((stats, primitive_type)) = 
self.stats_and_type_for_field_id(field_id)? else { + return Ok(None); + }; + + if !stats.has_min_max_set() { + return Ok(None); + } + + get_parquet_stat_min_as_datum(&primitive_type, stats) + } + + fn max_value(&self, field_id: i32) -> Result> { + let Some((stats, primitive_type)) = self.stats_and_type_for_field_id(field_id)? else { + return Ok(None); + }; + + if !stats.has_min_max_set() { + return Ok(None); + } + + get_parquet_stat_max_as_datum(&primitive_type, stats) + } + + fn visit_inequality( + &mut self, + reference: &BoundReference, + datum: &Datum, + cmp_fn: fn(&Datum, &Datum) -> bool, + use_lower_bound: bool, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + if datum.is_nan() { + // NaN indicates unreliable bounds. + // See the InclusiveMetricsEvaluator docs for more. + return ROW_GROUP_MIGHT_MATCH; + } + + let bound = if use_lower_bound { + self.min_value(field_id) + } else { + self.max_value(field_id) + }?; + + if let Some(bound) = bound { + if cmp_fn(&bound, datum) { + return ROW_GROUP_MIGHT_MATCH; + } + + return ROW_GROUP_CANT_MATCH; + } + + ROW_GROUP_MIGHT_MATCH + } +} + +impl BoundPredicateVisitor for RowGroupMetricsEvaluator<'_> { + type T = bool; + + fn always_true(&mut self) -> Result { + ROW_GROUP_MIGHT_MATCH + } + + fn always_false(&mut self) -> Result { + ROW_GROUP_CANT_MATCH + } + + fn and(&mut self, lhs: bool, rhs: bool) -> Result { + Ok(lhs && rhs) + } + + fn or(&mut self, lhs: bool, rhs: bool) -> Result { + Ok(lhs || rhs) + } + + fn not(&mut self, inner: bool) -> Result { + Ok(!inner) + } + + fn is_null(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result { + let field_id = reference.field().id; + + match self.null_count(field_id) { + Some(0) => ROW_GROUP_CANT_MATCH, + Some(_) => ROW_GROUP_MIGHT_MATCH, + None => ROW_GROUP_MIGHT_MATCH, + } + } + + fn not_null( + &mut self, + reference: &BoundReference, + _predicate: 
&BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + ROW_GROUP_MIGHT_MATCH + } + + fn is_nan(&mut self, _reference: &BoundReference, _predicate: &BoundPredicate) -> Result { + // NaN counts not in ColumnChunkMetadata Statistics + ROW_GROUP_MIGHT_MATCH + } + + fn not_nan( + &mut self, + _reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> Result { + // NaN counts not in ColumnChunkMetadata Statistics + ROW_GROUP_MIGHT_MATCH + } + + fn less_than( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + self.visit_inequality(reference, datum, PartialOrd::lt, true) + } + + fn less_than_or_eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + self.visit_inequality(reference, datum, PartialOrd::le, true) + } + + fn greater_than( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + self.visit_inequality(reference, datum, PartialOrd::gt, false) + } + + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + self.visit_inequality(reference, datum, PartialOrd::ge, false) + } + + fn eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + if let Some(lower_bound) = self.min_value(field_id)? { + if lower_bound.is_nan() { + // NaN indicates unreliable bounds. + // See the InclusiveMetricsEvaluator docs for more. + return ROW_GROUP_MIGHT_MATCH; + } else if lower_bound.gt(datum) { + return ROW_GROUP_CANT_MATCH; + } + } + + if let Some(upper_bound) = self.max_value(field_id)? { + if upper_bound.is_nan() { + // NaN indicates unreliable bounds. 
+ // See the InclusiveMetricsEvaluator docs for more. + return ROW_GROUP_MIGHT_MATCH; + } else if upper_bound.lt(datum) { + return ROW_GROUP_CANT_MATCH; + } + } + + ROW_GROUP_MIGHT_MATCH + } + + fn not_eq( + &mut self, + _reference: &BoundReference, + _datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + // Because the bounds are not necessarily a min or max value, + // this cannot be answered using them. notEq(col, X) with (X, Y) + // doesn't guarantee that X is a value in col. + ROW_GROUP_MIGHT_MATCH + } + + fn starts_with( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + let PrimitiveLiteral::String(datum) = datum.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string values", + )); + }; + + if let Some(lower_bound) = self.min_value(field_id)? { + let PrimitiveLiteral::String(lower_bound) = lower_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string lower_bound value", + )); + }; + + let prefix_length = lower_bound.chars().count().min(datum.chars().count()); + + // truncate lower bound so that its length + // is not greater than the length of prefix + let truncated_lower_bound = lower_bound.chars().take(prefix_length).collect::(); + if datum < &truncated_lower_bound { + return ROW_GROUP_CANT_MATCH; + } + } + + if let Some(upper_bound) = self.max_value(field_id)? 
{ + let PrimitiveLiteral::String(upper_bound) = upper_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string upper_bound value", + )); + }; + + let prefix_length = upper_bound.chars().count().min(datum.chars().count()); + + // truncate upper bound so that its length + // is not greater than the length of prefix + let truncated_upper_bound = upper_bound.chars().take(prefix_length).collect::(); + if datum > &truncated_upper_bound { + return ROW_GROUP_CANT_MATCH; + } + } + + ROW_GROUP_MIGHT_MATCH + } + + fn not_starts_with( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.may_contain_null(field_id) { + return ROW_GROUP_MIGHT_MATCH; + } + + // notStartsWith will match unless all values must start with the prefix. + // This happens when the lower and upper bounds both start with the prefix. + + let PrimitiveLiteral::String(prefix) = datum.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string values", + )); + }; + + let Some(lower_bound) = self.min_value(field_id)? else { + return ROW_GROUP_MIGHT_MATCH; + }; + + let PrimitiveLiteral::String(lower_bound_str) = lower_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use NotStartsWith operator on non-string lower_bound value", + )); + }; + + if lower_bound_str < prefix { + // if lower is shorter than the prefix then lower doesn't start with the prefix + return ROW_GROUP_MIGHT_MATCH; + } + + let prefix_len = prefix.chars().count(); + + if lower_bound_str.chars().take(prefix_len).collect::() == *prefix { + // lower bound matches the prefix + + let Some(upper_bound) = self.max_value(field_id)? 
else { + return ROW_GROUP_MIGHT_MATCH; + }; + + let PrimitiveLiteral::String(upper_bound) = upper_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use NotStartsWith operator on non-string upper_bound value", + )); + }; + + // if upper is shorter than the prefix then upper can't start with the prefix + if upper_bound.chars().count() < prefix_len { + return ROW_GROUP_MIGHT_MATCH; + } + + if upper_bound.chars().take(prefix_len).collect::() == *prefix { + // both bounds match the prefix, so all rows must match the + // prefix and therefore do not satisfy the predicate + return ROW_GROUP_CANT_MATCH; + } + } + + ROW_GROUP_MIGHT_MATCH + } + + fn r#in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + if literals.len() > IN_PREDICATE_LIMIT { + // skip evaluating the predicate if the number of values is too big + return ROW_GROUP_MIGHT_MATCH; + } + + if let Some(lower_bound) = self.min_value(field_id)? { + if lower_bound.is_nan() { + // NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. + return ROW_GROUP_MIGHT_MATCH; + } + + if !literals.iter().any(|datum| datum.ge(&lower_bound)) { + // if all values are less than lower bound, rows cannot match. + return ROW_GROUP_CANT_MATCH; + } + } + + if let Some(upper_bound) = self.max_value(field_id)? { + if upper_bound.is_nan() { + // NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. + return ROW_GROUP_MIGHT_MATCH; + } + + if !literals.iter().any(|datum| datum.le(&upper_bound)) { + // if all values are greater than upper bound, rows cannot match. 
+ return ROW_GROUP_CANT_MATCH; + } + } + + ROW_GROUP_MIGHT_MATCH + } + + fn not_in( + &mut self, + _reference: &BoundReference, + _literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result { + // Because the bounds are not necessarily a min or max value, + // this cannot be answered using them. notIn(col, {X, ...}) + // with (X, Y) doesn't guarantee that X is a value in col. + ROW_GROUP_MIGHT_MATCH + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use parquet::basic::{LogicalType as ParquetLogicalType, Type as ParquetPhysicalType}; + use parquet::data_type::ByteArray; + use parquet::file::metadata::{ColumnChunkMetaData, RowGroupMetaData}; + use parquet::file::statistics::Statistics; + use parquet::schema::types::{ + ColumnDescriptor, ColumnPath, SchemaDescriptor, Type as parquetSchemaType, + }; + use rand::{thread_rng, Rng}; + + use super::RowGroupMetricsEvaluator; + use crate::expr::{Bind, Reference}; + use crate::spec::{Datum, NestedField, PrimitiveType, Schema, Type}; + use crate::Result; + + #[test] + fn eval_matches_no_rows_for_empty_row_group() -> Result<()> { + let row_group_metadata = create_row_group_metadata(0, 0, None, 0, None)?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + + Ok(()) + } + + #[test] + fn eval_true_for_row_group_no_bounds_present() -> Result<()> { + let row_group_metadata = create_row_group_metadata(1, 1, None, 1, None)?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = 
RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_null_filter_not_null() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 1, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_not_null() + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_all_null_filter_is_null() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 1, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_null() + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_none_null_filter_not_null() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_not_null() + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_none_null_filter_is_null() -> Result<()> { + let row_group_metadata = 
create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_null() + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_nulls_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 1, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_datum_nan_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(f32::NAN)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_missing_bound_valid_other_bound_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = 
build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_failing_bound_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(0.9), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_passing_bound_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_nulls_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 1, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = 
RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_lower_nan_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(f32::NAN), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_lower_gt_than_datum_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(1.5), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_upper_nan_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(f32::NAN), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn 
eval_false_for_meta_upper_lt_than_datum_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(0.5), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_good_bounds_than_datum_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_bounds_eq_datum_filter_neq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(1.0), Some(1.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .not_equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_nulls_filter_starts_with() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + 
Some(Statistics::byte_array(None, None, None, 1, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_error_for_starts_with_non_string_filter_datum() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + Some(Statistics::byte_array(None, None, None, 0, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .starts_with(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_error_for_starts_with_non_utf8_lower_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // min val of 0xff is not valid utf-8 string. 
Max val of 0x20 is valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from(vec![255u8])), + Some(ByteArray::from(vec![32u8])), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_error_for_starts_with_non_utf8_upper_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("ice".as_bytes())), + Some(ByteArray::from(vec![255u8])), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_false_for_starts_with_meta_all_nulls() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array(None, None, None, 1, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn 
eval_false_for_starts_with_datum_below_min_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("id".as_bytes())), + Some(ByteArray::from("ie".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_false_for_starts_with_datum_above_max_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("h".as_bytes())), + Some(ByteArray::from("ib".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_starts_with_datum_between_bounds() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("h".as_bytes())), + Some(ByteArray::from("j".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + 
.bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_all_nulls_filter_not_starts_with() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + Some(Statistics::byte_array(None, None, None, 1, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_error_for_not_starts_with_non_utf8_lower_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // min val of 0xff is not valid utf-8 string. 
Max val of 0x20 is valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from(vec![255u8])), + Some(ByteArray::from(vec![32u8])), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_error_for_not_starts_with_non_utf8_upper_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from(vec![255u8])), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_true_for_not_starts_with_no_min_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + None, + Some(ByteArray::from("iceberg".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + 
assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_not_starts_with_datum_longer_min_max_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("ice".as_bytes())), + Some(ByteArray::from("iceberg".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_not_starts_with_datum_matches_lower_no_upper() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + None, + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_not_starts_with_datum_matches_lower_upper_shorter() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from("icy".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = 
Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_not_starts_with_datum_matches_lower_and_upper() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from("iceberg".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_nulls_filter_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from("iceberg".as_bytes())), + None, + 1, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .is_in([Datum::string("ice"), Datum::string("berg")]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_too_many_literals_filter_is_in() -> Result<()> { + let mut rng = thread_rng(); + + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(11.0), Some(12.0), 
None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in(std::iter::repeat_with(|| Datum::float(rng.gen_range(0.0..10.0))).take(1000)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_missing_bounds_filter_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + Some(Statistics::byte_array(None, None, None, 0, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .is_in([Datum::string("ice")]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_lower_bound_is_nan_filter_is_in() -> Result<()> { + // TODO: should this be false, since the max stat + // is lower than the min val in the set? 
+ let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(f32::NAN), Some(1.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in([Datum::float(2.0), Datum::float(3.0)]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_lower_bound_greater_than_all_vals_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(4.0), None, None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in([Datum::float(2.0), Datum::float(3.0)]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_nan_upper_bound_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(f32::NAN), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in([Datum::float(2.0), Datum::float(3.0)]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_upper_bound_below_all_vals_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(1.0), None, 0, 
false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in([Datum::float(2.0), Datum::float(3.0)]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_not_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from("iceberg".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .is_not_in([Datum::string("iceberg")]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + fn build_iceberg_schema_and_field_map() -> Result<(Arc, HashMap)> { + let iceberg_schema = Schema::builder() + .with_fields([ + Arc::new(NestedField::new( + 1, + "col_float", + Type::Primitive(PrimitiveType::Float), + false, + )), + Arc::new(NestedField::new( + 2, + "col_string", + Type::Primitive(PrimitiveType::String), + false, + )), + ]) + .build()?; + let iceberg_schema_ref = Arc::new(iceberg_schema); + + let field_id_map = HashMap::from_iter([(1, 0), (2, 1)]); + + Ok((iceberg_schema_ref, field_id_map)) + } + + fn build_parquet_schema_descriptor() -> Result> { + let field_1 = Arc::new( + parquetSchemaType::primitive_type_builder("col_float", ParquetPhysicalType::FLOAT) + .with_id(Some(1)) + .build()?, + ); + + let field_2 = Arc::new( + parquetSchemaType::primitive_type_builder( + "col_string", + ParquetPhysicalType::BYTE_ARRAY, + ) + 
.with_id(Some(2)) + .with_logical_type(Some(ParquetLogicalType::String)) + .build()?, + ); + + let group_type = Arc::new( + parquetSchemaType::group_type_builder("all") + .with_id(Some(1000)) + .with_fields(vec![field_1, field_2]) + .build()?, + ); + + let schema_descriptor = SchemaDescriptor::new(group_type); + let schema_descriptor_arc = Arc::new(schema_descriptor); + Ok(schema_descriptor_arc) + } + + fn create_row_group_metadata( + num_rows: i64, + col_1_num_vals: i64, + col_1_stats: Option, + col_2_num_vals: i64, + col_2_stats: Option, + ) -> Result { + let schema_descriptor_arc = build_parquet_schema_descriptor()?; + + let column_1_desc_ptr = Arc::new(ColumnDescriptor::new( + schema_descriptor_arc.column(0).self_type_ptr(), + 1, + 1, + ColumnPath::new(vec!["col_float".to_string()]), + )); + + let column_2_desc_ptr = Arc::new(ColumnDescriptor::new( + schema_descriptor_arc.column(1).self_type_ptr(), + 1, + 1, + ColumnPath::new(vec!["col_string".to_string()]), + )); + + let mut col_1_meta = + ColumnChunkMetaData::builder(column_1_desc_ptr).set_num_values(col_1_num_vals); + if let Some(stats1) = col_1_stats { + col_1_meta = col_1_meta.set_statistics(stats1) + } + + let mut col_2_meta = + ColumnChunkMetaData::builder(column_2_desc_ptr).set_num_values(col_2_num_vals); + if let Some(stats2) = col_2_stats { + col_2_meta = col_2_meta.set_statistics(stats2) + } + + let row_group_metadata = RowGroupMetaData::builder(schema_descriptor_arc) + .set_num_rows(num_rows) + .set_column_metadata(vec![ + col_1_meta.build()?, + // .set_statistics(Statistics::float(None, None, None, 1, false)) + col_2_meta.build()?, + ]) + .build(); + + Ok(row_group_metadata?) 
+ } +} diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs index 04aa1f577..45d7d4fd1 100644 --- a/crates/iceberg/src/scan.rs +++ b/crates/iceberg/src/scan.rs @@ -60,6 +60,7 @@ pub struct TableScanBuilder<'a> { concurrency_limit_data_files: usize, concurrency_limit_manifest_entries: usize, concurrency_limit_manifest_files: usize, + row_group_filtering_enabled: bool, } impl<'a> TableScanBuilder<'a> { @@ -76,6 +77,7 @@ impl<'a> TableScanBuilder<'a> { concurrency_limit_data_files: num_cpus, concurrency_limit_manifest_entries: num_cpus, concurrency_limit_manifest_files: num_cpus, + row_group_filtering_enabled: true, } } @@ -142,9 +144,16 @@ impl<'a> TableScanBuilder<'a> { self } - /// Sets the manifest file concurrency limit for this scan - pub fn with_manifest_file_concurrency_limit(mut self, limit: usize) -> Self { - self.concurrency_limit_manifest_files = limit; + /// Determines whether to enable row group filtering. + /// When enabled, if a read is performed with a filter predicate, + /// then the metadata for each row group in the parquet file is + /// evaluated against the filter predicate and row groups + /// that can't contain matching rows will be skipped entirely. + /// + /// Defaults to enabled, as it generally improves performance or + /// keeps it the same, with performance degradation unlikely. 
+ pub fn with_row_group_filtering_enabled(mut self, row_group_filtering_enabled: bool) -> Self { + self.row_group_filtering_enabled = row_group_filtering_enabled; self } @@ -258,6 +267,7 @@ impl<'a> TableScanBuilder<'a> { concurrency_limit_data_files: self.concurrency_limit_data_files, concurrency_limit_manifest_entries: self.concurrency_limit_manifest_entries, concurrency_limit_manifest_files: self.concurrency_limit_manifest_files, + row_group_filtering_enabled: self.row_group_filtering_enabled, }) } } @@ -280,6 +290,8 @@ pub struct TableScan { /// The maximum number of [`ManifestEntry`]s that will /// be processed in parallel concurrency_limit_data_files: usize, + + row_group_filtering_enabled: bool, } /// PlanContext wraps a [`SnapshotRef`] alongside all the other @@ -346,7 +358,7 @@ impl TableScan { .try_for_each_concurrent( concurrency_limit_manifest_entries, |(manifest_entry_context, tx)| async move { - crate::runtime::spawn(async move { + spawn(async move { Self::process_manifest_entry(manifest_entry_context, tx).await }) .await @@ -365,7 +377,8 @@ impl TableScan { /// Returns an [`ArrowRecordBatchStream`]. 
pub async fn to_arrow(&self) -> Result { let mut arrow_reader_builder = ArrowReaderBuilder::new(self.file_io.clone()) - .with_data_file_concurrency_limit(self.concurrency_limit_data_files); + .with_data_file_concurrency_limit(self.concurrency_limit_data_files) + .with_row_group_filtering_enabled(self.row_group_filtering_enabled); if let Some(batch_size) = self.batch_size { arrow_reader_builder = arrow_reader_builder.with_batch_size(batch_size); diff --git a/crates/iceberg/src/writer/file_writer/parquet_writer.rs b/crates/iceberg/src/writer/file_writer/parquet_writer.rs index 11ba04f6a..3e2db5855 100644 --- a/crates/iceberg/src/writer/file_writer/parquet_writer.rs +++ b/crates/iceberg/src/writer/file_writer/parquet_writer.rs @@ -27,23 +27,20 @@ use futures::future::BoxFuture; use itertools::Itertools; use parquet::arrow::async_writer::AsyncFileWriter as ArrowAsyncFileWriter; use parquet::arrow::AsyncArrowWriter; -use parquet::data_type::{ - BoolType, ByteArray, ByteArrayType, DataType as ParquetDataType, DoubleType, FixedLenByteArray, - FixedLenByteArrayType, FloatType, Int32Type, Int64Type, -}; use parquet::file::properties::WriterProperties; -use parquet::file::statistics::{from_thrift, Statistics, TypedStatistics}; +use parquet::file::statistics::{from_thrift, Statistics}; use parquet::format::FileMetaData; -use uuid::Uuid; use super::location_generator::{FileNameGenerator, LocationGenerator}; use super::track_writer::TrackWriter; use super::{FileWriter, FileWriterBuilder}; -use crate::arrow::DEFAULT_MAP_FIELD_NAME; +use crate::arrow::{ + get_parquet_stat_max_as_datum, get_parquet_stat_min_as_datum, DEFAULT_MAP_FIELD_NAME, +}; use crate::io::{FileIO, FileWrite, OutputFile}; use crate::spec::{ visit_schema, DataFileBuilder, DataFileFormat, Datum, ListType, MapType, NestedFieldRef, - PrimitiveLiteral, PrimitiveType, Schema, SchemaRef, SchemaVisitor, StructType, Type, + PrimitiveType, Schema, SchemaRef, SchemaVisitor, StructType, Type, }; use 
crate::writer::CurrentFileStatus; use crate::{Error, ErrorKind, Result}; @@ -237,34 +234,26 @@ impl MinMaxColAggregator { } } - fn update_state( - &mut self, - field_id: i32, - state: &TypedStatistics, - convert_func: impl Fn(::T) -> Result, - ) { - if state.min_is_exact() { - let val = convert_func(state.min().clone()).unwrap(); - self.lower_bounds - .entry(field_id) - .and_modify(|e| { - if *e > val { - *e = val.clone() - } - }) - .or_insert(val); - } - if state.max_is_exact() { - let val = convert_func(state.max().clone()).unwrap(); - self.upper_bounds - .entry(field_id) - .and_modify(|e| { - if *e < val { - *e = val.clone() - } - }) - .or_insert(val); - } + fn update_state_min(&mut self, field_id: i32, datum: Datum) { + self.lower_bounds + .entry(field_id) + .and_modify(|e| { + if *e > datum { + *e = datum.clone() + } + }) + .or_insert(datum); + } + + fn update_state_max(&mut self, field_id: i32, datum: Datum) { + self.upper_bounds + .entry(field_id) + .and_modify(|e| { + if *e < datum { + *e = datum.clone() + } + }) + .or_insert(datum); } fn update(&mut self, field_id: i32, value: Statistics) -> Result<()> { @@ -287,142 +276,28 @@ impl MinMaxColAggregator { )); }; - match (&ty, value) { - (PrimitiveType::Boolean, Statistics::Boolean(stat)) => { - let convert_func = |v: bool| Result::::Ok(Datum::bool(v)); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Int, Statistics::Int32(stat)) => { - let convert_func = |v: i32| Result::::Ok(Datum::int(v)); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Long, Statistics::Int64(stat)) => { - let convert_func = |v: i64| Result::::Ok(Datum::long(v)); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Float, Statistics::Float(stat)) => { - let convert_func = |v: f32| Result::::Ok(Datum::float(v)); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Double, Statistics::Double(stat)) => { - let convert_func = |v: f64| 
Result::::Ok(Datum::double(v)); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::String, Statistics::ByteArray(stat)) => { - let convert_func = |v: ByteArray| { - Result::::Ok(Datum::string( - String::from_utf8(v.data().to_vec()).unwrap(), - )) - }; - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Binary, Statistics::ByteArray(stat)) => { - let convert_func = - |v: ByteArray| Result::::Ok(Datum::binary(v.data().to_vec())); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Date, Statistics::Int32(stat)) => { - let convert_func = |v: i32| Result::::Ok(Datum::date(v)); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Time, Statistics::Int64(stat)) => { - let convert_func = |v: i64| Datum::time_micros(v); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Timestamp, Statistics::Int64(stat)) => { - let convert_func = |v: i64| Result::::Ok(Datum::timestamp_micros(v)); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Timestamptz, Statistics::Int64(stat)) => { - let convert_func = |v: i64| Result::::Ok(Datum::timestamptz_micros(v)); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::TimestampNs, Statistics::Int64(stat)) => { - let convert_func = |v: i64| Result::::Ok(Datum::timestamp_nanos(v)); - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::TimestamptzNs, Statistics::Int64(stat)) => { - let convert_func = |v: i64| Result::::Ok(Datum::timestamptz_nanos(v)); - self.update_state::(field_id, &stat, convert_func) - } - ( - PrimitiveType::Decimal { - precision: _, - scale: _, - }, - Statistics::ByteArray(stat), - ) => { - let convert_func = |v: ByteArray| -> Result { - Result::::Ok(Datum::new( - ty.clone(), - PrimitiveLiteral::Int128(i128::from_le_bytes(v.data().try_into().unwrap())), - )) - }; - self.update_state::(field_id, &stat, convert_func) - } - ( - 
PrimitiveType::Decimal { - precision: _, - scale: _, - }, - Statistics::Int32(stat), - ) => { - let convert_func = |v: i32| { - Result::::Ok(Datum::new( - ty.clone(), - PrimitiveLiteral::Int128(i128::from(v)), - )) - }; - self.update_state::(field_id, &stat, convert_func) - } - ( - PrimitiveType::Decimal { - precision: _, - scale: _, - }, - Statistics::Int64(stat), - ) => { - let convert_func = |v: i64| { - Result::::Ok(Datum::new( - ty.clone(), - PrimitiveLiteral::Int128(i128::from(v)), - )) - }; - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Uuid, Statistics::FixedLenByteArray(stat)) => { - let convert_func = |v: FixedLenByteArray| { - if v.len() != 16 { - return Err(Error::new( - ErrorKind::Unexpected, - "Invalid length of uuid bytes.", - )); - } - Ok(Datum::uuid(Uuid::from_bytes( - v.data()[..16].try_into().unwrap(), - ))) - }; - self.update_state::(field_id, &stat, convert_func) - } - (PrimitiveType::Fixed(len), Statistics::FixedLenByteArray(stat)) => { - let convert_func = |v: FixedLenByteArray| { - if v.len() != *len as usize { - return Err(Error::new( - ErrorKind::Unexpected, - "Invalid length of fixed bytes.", - )); - } - Ok(Datum::fixed(v.data().to_vec())) - }; - self.update_state::(field_id, &stat, convert_func) - } - (ty, value) => { + if value.min_is_exact() { + let Some(min_datum) = get_parquet_stat_min_as_datum(&ty, &value)? else { return Err(Error::new( ErrorKind::Unexpected, format!("Statistics {} is not match with field type {}.", value, ty), - )) - } + )); + }; + + self.update_state_min(field_id, min_datum); } + + if value.max_is_exact() { + let Some(max_datum) = get_parquet_stat_max_as_datum(&ty, &value)? 
else { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Statistics {} is not match with field type {}.", value, ty), + )); + }; + + self.update_state_max(field_id, max_datum); + } + Ok(()) } @@ -609,6 +484,7 @@ mod tests { use arrow_select::concat::concat_batches; use parquet::arrow::PARQUET_FIELD_ID_META_KEY; use tempfile::TempDir; + use uuid::Uuid; use super::*; use crate::io::FileIOBuilder;