diff --git a/python/pyiceberg/avro/__init__.py b/python/pyiceberg/avro/__init__.py
index d7d8b55ef913..75440db77c1d 100644
--- a/python/pyiceberg/avro/__init__.py
+++ b/python/pyiceberg/avro/__init__.py
@@ -16,5 +16,8 @@
 # under the License.
 import struct
 
+STRUCT_BOOL = struct.Struct("?")
 STRUCT_FLOAT = struct.Struct("<f")
 STRUCT_DOUBLE = struct.Struct("<d")
+STRUCT_INT32 = struct.Struct("<i")
+STRUCT_INT64 = struct.Struct("<q")
diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ ... @@
     def map_value_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
         return partner_map.items if isinstance(partner_map, pa.MapArray) else None
+
+
+_PRIMITIVE_TO_PHYSICAL = {
+    BooleanType(): "BOOLEAN",
+    IntegerType(): "INT32",
+    LongType(): "INT64",
+    FloatType(): "FLOAT",
+    DoubleType(): "DOUBLE",
+    DateType(): "INT32",
+    TimeType(): "INT64",
+    TimestampType(): "INT64",
+    TimestamptzType(): "INT64",
+    StringType(): "BYTE_ARRAY",
+    UUIDType(): "FIXED_LEN_BYTE_ARRAY",
+    BinaryType(): "BYTE_ARRAY",
+}
+_PHYSICAL_TYPES = set(_PRIMITIVE_TO_PHYSICAL.values()).union({"INT96"})
+
+
+class StatsAggregator:
+    current_min: Any
+    current_max: Any
+    trunc_length: Optional[int]
+
+    def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc_length: Optional[int] = None) -> None:
+        self.current_min = None
+        self.current_max = None
+        self.trunc_length = trunc_length
+
+        if physical_type_string not in _PHYSICAL_TYPES:
+            raise ValueError(f"Unknown physical type {physical_type_string}")
+
+        if physical_type_string == "INT96":
+            raise NotImplementedError("Statistics not implemented for INT96 physical type")
+
+        expected_physical_type = _PRIMITIVE_TO_PHYSICAL[iceberg_type]
+        if expected_physical_type != physical_type_string:
+            raise ValueError(
+                f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
+            )
+
+        self.primitive_type = iceberg_type
+
+    def serialize(self, value: Any) -> bytes:
+        return to_bytes(self.primitive_type, value)
+
+    def update_min(self, val: Any) -> None:
+        self.current_min = val if self.current_min is None else min(val, self.current_min)
+
+    def update_max(self, val: Any) -> None:
+        self.current_max = val if self.current_max is None else max(val, self.current_max)
+
+    def min_as_bytes(self) -> bytes:
+        return self.serialize(
+            self.current_min
+            if self.trunc_length is None
+            else TruncateTransform(width=self.trunc_length).transform(self.primitive_type)(self.current_min)
+        )
+
+    def max_as_bytes(self) -> Optional[bytes]:
+        if self.current_max is None:
+            return None
+
+        if self.primitive_type == StringType():
+            if type(self.current_max) != str:
+                raise ValueError("Expected the current_max to be a string")
+            s_result = truncate_upper_bound_text_string(self.current_max, self.trunc_length)
+            return self.serialize(s_result) if s_result is not None else None
+        elif self.primitive_type == BinaryType():
+            if type(self.current_max) != bytes:
+                raise ValueError("Expected the current_max to be bytes")
+            b_result = truncate_upper_bound_binary_string(self.current_max, self.trunc_length)
+            return self.serialize(b_result) if b_result is not None else None
+        else:
+            if self.trunc_length is not None:
+                raise ValueError(f"{self.primitive_type} cannot be truncated")
+            return self.serialize(self.current_max)
+
+
+DEFAULT_TRUNCATION_LENGTH = 16
+TRUNCATION_EXPR = r"^truncate\((\d+)\)$"
+
+
+class MetricModeTypes(Enum):
+    TRUNCATE = "truncate"
+    NONE = "none"
+    COUNTS = "counts"
+    FULL = "full"
+
+
+DEFAULT_METRICS_MODE_KEY = "write.metadata.metrics.default"
+COLUMN_METRICS_MODE_KEY_PREFIX = "write.metadata.metrics.column"
+
+
+@dataclass(frozen=True)
+class MetricsMode(Singleton):
+    
type: MetricModeTypes + length: Optional[int] = None + + +_DEFAULT_METRICS_MODE = MetricsMode(MetricModeTypes.TRUNCATE, DEFAULT_TRUNCATION_LENGTH) + + +def match_metrics_mode(mode: str) -> MetricsMode: + sanitized_mode = mode.strip().lower() + if sanitized_mode.startswith("truncate"): + m = re.match(TRUNCATION_EXPR, sanitized_mode) + if m: + length = int(m[1]) + if length < 1: + raise ValueError("Truncation length must be larger than 0") + return MetricsMode(MetricModeTypes.TRUNCATE, int(m[1])) + else: + raise ValueError(f"Malformed truncate: {mode}") + elif sanitized_mode == "none": + return MetricsMode(MetricModeTypes.NONE) + elif sanitized_mode == "counts": + return MetricsMode(MetricModeTypes.COUNTS) + elif sanitized_mode == "full": + return MetricsMode(MetricModeTypes.FULL) + else: + raise ValueError(f"Unsupported metrics mode: {mode}") + + +@dataclass(frozen=True) +class StatisticsCollector: + field_id: int + iceberg_type: PrimitiveType + mode: MetricsMode + column_name: str + + +class PyArrowStatisticsCollector(PreOrderSchemaVisitor[List[StatisticsCollector]]): + _field_id: int = 0 + _schema: Schema + _properties: Dict[str, str] + _default_mode: Optional[str] + + def __init__(self, schema: Schema, properties: Dict[str, str]): + self._schema = schema + self._properties = properties + self._default_mode = self._properties.get(DEFAULT_METRICS_MODE_KEY) + + def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]: + return struct_result() + + def struct( + self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]] + ) -> List[StatisticsCollector]: + return list(chain(*[result() for result in field_results])) + + def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]: + self._field_id = field.field_id + return field_result() + + def list(self, list_type: ListType, element_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]: + self._field_id = list_type.element_id + return element_result() + + def map( + self, + map_type: MapType, + key_result: Callable[[], List[StatisticsCollector]], + value_result: Callable[[], List[StatisticsCollector]], + ) -> List[StatisticsCollector]: + self._field_id = map_type.key_id + k = key_result() + self._field_id = map_type.value_id + v = value_result() + return k + v + + def primitive(self, primitive: PrimitiveType) -> List[StatisticsCollector]: + column_name = self._schema.find_column_name(self._field_id) + if column_name is None: + return [] + + metrics_mode = _DEFAULT_METRICS_MODE + + if self._default_mode: + metrics_mode = match_metrics_mode(self._default_mode) + + col_mode = self._properties.get(f"{COLUMN_METRICS_MODE_KEY_PREFIX}.{column_name}") + if col_mode: + metrics_mode = match_metrics_mode(col_mode) + + if ( + not (isinstance(primitive, StringType) or isinstance(primitive, BinaryType)) + and metrics_mode.type == MetricModeTypes.TRUNCATE + ): + metrics_mode = MetricsMode(MetricModeTypes.FULL) + + is_nested = column_name.find(".") >= 0 + + if is_nested and metrics_mode.type in [MetricModeTypes.TRUNCATE, MetricModeTypes.FULL]: + metrics_mode = MetricsMode(MetricModeTypes.COUNTS) + + return [StatisticsCollector(field_id=self._field_id, iceberg_type=primitive, mode=metrics_mode, column_name=column_name)] + + +def compute_statistics_plan( + schema: Schema, + table_properties: Dict[str, str], +) -> Dict[int, StatisticsCollector]: + """ + Compute the statistics plan for all 
columns.
+
+    The resulting list is assumed to have the same length and same order as the columns in the pyarrow table.
+    This allows the list to map from the column index to the Iceberg column ID.
+    For each element, the desired metrics collection that was provided by the user in the configuration
+    is computed and then adjusted according to the data type of the column. For nested columns the minimum
+    and maximum values are not computed, and truncation is only applied to text or binary strings.
+
+    Args:
+        schema (pyiceberg.schema.Schema): The current table schema.
+        table_properties (from pyiceberg.table.metadata.TableMetadata): The Iceberg table metadata properties.
+            They are required to compute the mapping of column position to iceberg schema type id. They are also
+            used to set the mode for column metrics collection.
+    """
+    stats_cols = pre_order_visit(schema, PyArrowStatisticsCollector(schema, table_properties))
+    result: Dict[int, StatisticsCollector] = {}
+    for stats_col in stats_cols:
+        result[stats_col.field_id] = stats_col
+    return result
+
+
+@dataclass(frozen=True)
+class ID2ParquetPath:
+    field_id: int
+    parquet_path: str
+
+
+class ID2ParquetPathVisitor(PreOrderSchemaVisitor[List[ID2ParquetPath]]):
+    _field_id: int = 0
+    _path: List[str]
+
+    def __init__(self) -> None:
+        self._path = []
+
+    def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
+        return struct_result()
+
+    def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]:
+        return list(chain(*[result() for result in field_results]))
+
+    def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
+        self._field_id = field.field_id
+        self._path.append(field.name)
+        result = field_result()
+        self._path.pop()
+        return result
+
+    def list(self, list_type: ListType, element_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
+        self._field_id = list_type.element_id
+        self._path.append("list.element")
+        result = element_result()
+        self._path.pop()
+        return result
+
+    def map(
+        self,
+        map_type: MapType,
+        key_result: Callable[[], List[ID2ParquetPath]],
+        value_result: Callable[[], List[ID2ParquetPath]],
+    ) -> List[ID2ParquetPath]:
+        self._field_id = map_type.key_id
+        self._path.append("key_value.key")
+        k = key_result()
+        self._path.pop()
+        self._field_id = map_type.value_id
+        self._path.append("key_value.value")
+        v = value_result()
+        self._path.pop()
+        return k + v
+
+    def primitive(self, primitive: PrimitiveType) -> List[ID2ParquetPath]:
+        return [ID2ParquetPath(field_id=self._field_id, parquet_path=".".join(self._path))]
+
+
+def parquet_path_to_id_mapping(
+    schema: Schema,
+) -> Dict[str, int]:
+    """
+    Compute the mapping of parquet column path to Iceberg ID.
+
+    For each column, the parquet file metadata has a path_in_schema attribute that follows
+    a specific naming scheme for nested columns. This function computes a mapping of
+    the full paths to the corresponding Iceberg IDs.
+
+    Args:
+        schema (pyiceberg.schema.Schema): The current table schema.
+    """
+    result: Dict[str, int] = {}
+    for pair in pre_order_visit(schema, ID2ParquetPathVisitor()):
+        result[pair.parquet_path] = pair.field_id
+    return result
+
+
+def fill_parquet_file_metadata(
+    df: DataFile,
+    parquet_metadata: pq.FileMetaData,
+    file_size: int,
+    stats_columns: Dict[int, StatisticsCollector],
+    parquet_column_mapping: Dict[str, int],
+) -> None:
+    """
+    Compute and fill the following fields of the DataFile object.
+ + - file_format + - record_count + - file_size_in_bytes + - column_sizes + - value_counts + - null_value_counts + - nan_value_counts + - lower_bounds + - upper_bounds + - split_offsets + + Args: + df (DataFile): A DataFile object representing the Parquet file for which metadata is to be filled. + parquet_metadata (pyarrow.parquet.FileMetaData): A pyarrow metadata object. + file_size (int): The total compressed file size cannot be retrieved from the metadata and hence has to + be passed here. Depending on the kind of file system and pyarrow library call used, different + ways to obtain this value might be appropriate. + stats_columns (Dict[int, StatisticsCollector]): The statistics gathering plan. It is required to + set the mode for column metrics collection + """ + if parquet_metadata.num_columns != len(stats_columns): + raise ValueError( + f"Number of columns in statistics configuration ({len(stats_columns)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})" + ) + + if parquet_metadata.num_columns != len(parquet_column_mapping): + raise ValueError( + f"Number of columns in column mapping ({len(parquet_column_mapping)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})" + ) + + column_sizes: Dict[int, int] = {} + value_counts: Dict[int, int] = {} + split_offsets: List[int] = [] + + null_value_counts: Dict[int, int] = {} + nan_value_counts: Dict[int, int] = {} + + col_aggs = {} + + for r in range(parquet_metadata.num_row_groups): + # References: + # https://github.com/apache/iceberg/blob/fc381a81a1fdb8f51a0637ca27cd30673bd7aad3/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java#L232 + # https://github.com/apache/parquet-mr/blob/ac29db4611f86a07cc6877b416aa4b183e09b353/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/metadata/ColumnChunkMetaData.java#L184 + + row_group = parquet_metadata.row_group(r) + + data_offset = row_group.column(0).data_page_offset + dictionary_offset = row_group.column(0).dictionary_page_offset + + if row_group.column(0).has_dictionary_page and dictionary_offset < data_offset: + split_offsets.append(dictionary_offset) + else: + split_offsets.append(data_offset) + + invalidate_col: Set[int] = set() + + for pos in range(0, parquet_metadata.num_columns): + column = row_group.column(pos) + field_id = parquet_column_mapping[column.path_in_schema] + + stats_col = stats_columns[field_id] + + column_sizes.setdefault(field_id, 0) + column_sizes[field_id] += column.total_compressed_size + + if stats_col.mode == MetricsMode(MetricModeTypes.NONE): + continue + + value_counts[field_id] = value_counts.get(field_id, 0) + column.num_values + + if column.is_stats_set: + try: + statistics = column.statistics + + if statistics.has_null_count: + null_value_counts[field_id] = null_value_counts.get(field_id, 0) + statistics.null_count + + if stats_col.mode == MetricsMode(MetricModeTypes.COUNTS): + continue + + if field_id not in col_aggs: + col_aggs[field_id] = StatsAggregator( + stats_col.iceberg_type, statistics.physical_type, stats_col.mode.length + ) + + col_aggs[field_id].update_min(statistics.min) + col_aggs[field_id].update_max(statistics.max) + + except pyarrow.lib.ArrowNotImplementedError as e: + invalidate_col.add(field_id) + logger.warning(e) + else: + invalidate_col.add(field_id) + logger.warning("PyArrow statistics missing for column %d when writing file", pos) + + split_offsets.sort() + + lower_bounds = {} + upper_bounds = {} + + for k, agg in col_aggs.items(): + _min = 
agg.min_as_bytes() + if _min is not None: + lower_bounds[k] = _min + _max = agg.max_as_bytes() + if _max is not None: + upper_bounds[k] = _max + + for field_id in invalidate_col: + del lower_bounds[field_id] + del upper_bounds[field_id] + del null_value_counts[field_id] + + df.file_format = FileFormat.PARQUET + df.record_count = parquet_metadata.num_rows + df.file_size_in_bytes = file_size + df.column_sizes = column_sizes + df.value_counts = value_counts + df.null_value_counts = null_value_counts + df.nan_value_counts = nan_value_counts + df.lower_bounds = lower_bounds + df.upper_bounds = upper_bounds + df.split_offsets = split_offsets diff --git a/python/pyiceberg/utils/truncate.py b/python/pyiceberg/utils/truncate.py new file mode 100644 index 000000000000..4ddb2401c42d --- /dev/null +++ b/python/pyiceberg/utils/truncate.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + + +def truncate_upper_bound_text_string(value: str, trunc_length: Optional[int]) -> Optional[str]: + result = value[:trunc_length] + if result != value: + chars = [*result] + + for i in range(-1, -len(result) - 1, -1): + try: + to_inc = ord(chars[i]) + # will raise exception if the highest unicode code is reached + _next = chr(to_inc + 1) + chars[i] = _next + return "".join(chars) + except ValueError: + pass + return None # didn't find a valid upper bound + return result + + +def truncate_upper_bound_binary_string(value: bytes, trunc_length: Optional[int]) -> Optional[bytes]: + result = value[:trunc_length] + if result != value: + _bytes = [*result] + for i in range(-1, -len(result) - 1, -1): + if _bytes[i] < 255: + _bytes[i] += 1 + return b"".join([i.to_bytes(1, byteorder="little") for i in _bytes]) + return None + + return result diff --git a/python/tests/io/test_pyarrow_stats.py b/python/tests/io/test_pyarrow_stats.py new file mode 100644 index 000000000000..74297fe52627 --- /dev/null +++ b/python/tests/io/test_pyarrow_stats.py @@ -0,0 +1,798 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
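
The two helpers added in `python/pyiceberg/utils/truncate.py` above exist because a truncated upper bound must still compare greater than or equal to every value in the column: the value is cut to the truncation length, and the last character or byte that can still be incremented is bumped; if every position is already at its maximum, there is no valid truncated upper bound and `None` is returned, so no upper bound gets written for that column. A small illustrative sketch of that behavior (the input values here are made up, not taken from this change):

```python
# Illustrative sketch of the upper-bound truncation helpers added in this change.
from pyiceberg.utils.truncate import (
    truncate_upper_bound_binary_string,
    truncate_upper_bound_text_string,
)

# Cut to the truncation length, then increment the last kept character/byte so the
# result still sorts >= every original value: 'z' (0x7A) becomes '{' (0x7B).
assert truncate_upper_bound_text_string("zzzzzzzz", 4) == "zzz{"
assert truncate_upper_bound_binary_string(b"\x01\x02\xfe", 2) == b"\x01\x03"

# Values already within the truncation length are returned unchanged.
assert truncate_upper_bound_text_string("abc", 16) == "abc"

# If no character/byte can be incremented, there is no valid truncated upper bound
# and None is returned, so the writer records no upper bound for the column.
assert truncate_upper_bound_text_string(chr(0x10FFFF) * 3, 2) is None
assert truncate_upper_bound_binary_string(b"\xff\xff\x00", 2) is None
```
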
+# pylint: disable=protected-access,unused-argument,redefined-outer-name + +import math +import tempfile +import uuid +from dataclasses import asdict, dataclass +from datetime import ( + date, + datetime, + time, + timedelta, + timezone, +) +from typing import ( + Any, + Dict, + List, + Optional, + Tuple, + Union, +) + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from pyiceberg.avro import ( + STRUCT_BOOL, + STRUCT_DOUBLE, + STRUCT_FLOAT, + STRUCT_INT32, + STRUCT_INT64, +) +from pyiceberg.io.pyarrow import ( + MetricModeTypes, + MetricsMode, + PyArrowStatisticsCollector, + compute_statistics_plan, + fill_parquet_file_metadata, + match_metrics_mode, + parquet_path_to_id_mapping, + schema_to_pyarrow, +) +from pyiceberg.manifest import DataFile +from pyiceberg.schema import Schema, pre_order_visit +from pyiceberg.table.metadata import ( + TableMetadata, + TableMetadataUtil, + TableMetadataV1, + TableMetadataV2, +) +from pyiceberg.types import ( + BooleanType, + FloatType, + IntegerType, + StringType, +) +from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros + + +@dataclass(frozen=True) +class TestStruct: + x: Optional[int] + y: Optional[float] + + +def construct_test_table() -> Tuple[Any, Any, Union[TableMetadataV1, TableMetadataV2]]: + table_metadata = { + "format-version": 2, + "location": "s3://bucket/test/location", + "last-column-id": 7, + "current-schema-id": 0, + "schemas": [ + { + "type": "struct", + "schema-id": 0, + "fields": [ + {"id": 1, "name": "strings", "required": False, "type": "string"}, + {"id": 2, "name": "floats", "required": False, "type": "float"}, + { + "id": 3, + "name": "list", + "required": False, + "type": {"type": "list", "element-id": 6, "element": "long", "element-required": False}, + }, + { + "id": 4, + "name": "maps", + "required": False, + "type": { + "type": "map", + "key-id": 7, + "key": "long", + "value-id": 8, + "value": "long", + "value-required": False, + }, + }, + { + "id": 5, + "name": "structs", + "required": False, + "type": { + "type": "struct", + "fields": [ + {"id": 9, "name": "x", "required": False, "type": "long"}, + {"id": 10, "name": "y", "required": False, "type": "float", "doc": "comment"}, + ], + }, + }, + ], + }, + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "properties": {}, + } + + table_metadata = TableMetadataUtil.parse_obj(table_metadata) + arrow_schema = schema_to_pyarrow(table_metadata.schemas[0]) + + _strings = ["zzzzzzzzzzzzzzzzzzzz", "rrrrrrrrrrrrrrrrrrrr", None, "aaaaaaaaaaaaaaaaaaaa"] + + _floats = [3.14, math.nan, 1.69, 100] + + _list = [[1, 2, 3], [4, 5, 6], None, [7, 8, 9]] + + _maps: List[Optional[Dict[int, int]]] = [ + {1: 2, 3: 4}, + None, + {5: 6}, + {}, + ] + + _structs = [ + asdict(TestStruct(1, 0.2)), + asdict(TestStruct(None, -1.34)), + None, + asdict(TestStruct(54, None)), + ] + + table = pa.Table.from_pydict( + { + "strings": _strings, + "floats": _floats, + "list": _list, + "maps": _maps, + "structs": _structs, + }, + schema=arrow_schema, + ) + metadata_collector: List[Any] = [] + + with pa.BufferOutputStream() as f: + with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector) as writer: + writer.write_table(table) + + return f.getvalue(), metadata_collector[0], table_metadata + + +def get_current_schema( + table_metadata: TableMetadata, +) -> Schema: + return next(filter(lambda s: s.schema_id == table_metadata.current_schema_id, table_metadata.schemas)) + + +def test_record_count() -> None: + 
(file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + assert datafile.record_count == 4 + + +def test_file_size() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert datafile.file_size_in_bytes == len(file_bytes) + + +def test_value_counts() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.value_counts) == 7 + assert datafile.value_counts[1] == 4 + assert datafile.value_counts[2] == 4 + assert datafile.value_counts[6] == 10 # 3 lists with 3 items and a None value + assert datafile.value_counts[7] == 5 + assert datafile.value_counts[8] == 5 + assert datafile.value_counts[9] == 4 + assert datafile.value_counts[10] == 4 + + +def test_column_sizes() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.column_sizes) == 7 + # these values are an artifact of how the write_table encodes the columns + assert datafile.column_sizes[1] > 0 + assert datafile.column_sizes[2] > 0 + assert datafile.column_sizes[6] > 0 + assert datafile.column_sizes[7] > 0 + assert datafile.column_sizes[8] > 0 + + +def test_null_and_nan_counts() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.null_value_counts) == 7 + assert datafile.null_value_counts[1] == 1 + assert datafile.null_value_counts[2] == 0 + assert datafile.null_value_counts[6] == 1 + assert datafile.null_value_counts[7] == 2 + assert datafile.null_value_counts[8] == 2 + assert datafile.null_value_counts[9] == 2 + assert datafile.null_value_counts[10] == 2 + + # #arrow does not include this in the statistics + # assert len(datafile.nan_value_counts) == 3 + # assert datafile.nan_value_counts[1] == 0 + # assert datafile.nan_value_counts[2] == 1 + # assert datafile.nan_value_counts[3] == 0 + + +def test_bounds() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.lower_bounds) == 2 + assert datafile.lower_bounds[1].decode() == "aaaaaaaaaaaaaaaa" + assert 
datafile.lower_bounds[2] == STRUCT_FLOAT.pack(1.69) + + assert len(datafile.upper_bounds) == 2 + assert datafile.upper_bounds[1].decode() == "zzzzzzzzzzzzzzz{" + assert datafile.upper_bounds[2] == STRUCT_FLOAT.pack(100) + + +def test_metrics_mode_parsing() -> None: + assert match_metrics_mode("none") == MetricsMode(MetricModeTypes.NONE) + assert match_metrics_mode("nOnE") == MetricsMode(MetricModeTypes.NONE) + assert match_metrics_mode("counts") == MetricsMode(MetricModeTypes.COUNTS) + assert match_metrics_mode("Counts") == MetricsMode(MetricModeTypes.COUNTS) + assert match_metrics_mode("full") == MetricsMode(MetricModeTypes.FULL) + assert match_metrics_mode("FuLl") == MetricsMode(MetricModeTypes.FULL) + assert match_metrics_mode(" FuLl") == MetricsMode(MetricModeTypes.FULL) + + assert match_metrics_mode("truncate(16)") == MetricsMode(MetricModeTypes.TRUNCATE, 16) + assert match_metrics_mode("trUncatE(16)") == MetricsMode(MetricModeTypes.TRUNCATE, 16) + assert match_metrics_mode("trUncatE(7)") == MetricsMode(MetricModeTypes.TRUNCATE, 7) + assert match_metrics_mode("trUncatE(07)") == MetricsMode(MetricModeTypes.TRUNCATE, 7) + + with pytest.raises(ValueError) as exc_info: + match_metrics_mode("trUncatE(-7)") + assert "Malformed truncate: trUncatE(-7)" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + match_metrics_mode("trUncatE(0)") + assert "Truncation length must be larger than 0" in str(exc_info.value) + + +def test_metrics_mode_none() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + table_metadata.properties["write.metadata.metrics.default"] = "none" + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.value_counts) == 0 + assert len(datafile.null_value_counts) == 0 + assert len(datafile.nan_value_counts) == 0 + assert len(datafile.lower_bounds) == 0 + assert len(datafile.upper_bounds) == 0 + + +def test_metrics_mode_counts() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + table_metadata.properties["write.metadata.metrics.default"] = "counts" + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.value_counts) == 7 + assert len(datafile.null_value_counts) == 7 + assert len(datafile.nan_value_counts) == 0 + assert len(datafile.lower_bounds) == 0 + assert len(datafile.upper_bounds) == 0 + + +def test_metrics_mode_full() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + table_metadata.properties["write.metadata.metrics.default"] = "full" + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.value_counts) == 7 + assert len(datafile.null_value_counts) == 7 + assert len(datafile.nan_value_counts) == 0 + + assert len(datafile.lower_bounds) == 2 + assert datafile.lower_bounds[1].decode() == "aaaaaaaaaaaaaaaaaaaa" + assert datafile.lower_bounds[2] == STRUCT_FLOAT.pack(1.69) + + assert len(datafile.upper_bounds) == 2 + assert 
datafile.upper_bounds[1].decode() == "zzzzzzzzzzzzzzzzzzzz" + assert datafile.upper_bounds[2] == STRUCT_FLOAT.pack(100) + + +def test_metrics_mode_non_default_trunc() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + table_metadata.properties["write.metadata.metrics.default"] = "truncate(2)" + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.value_counts) == 7 + assert len(datafile.null_value_counts) == 7 + assert len(datafile.nan_value_counts) == 0 + + assert len(datafile.lower_bounds) == 2 + assert datafile.lower_bounds[1].decode() == "aa" + assert datafile.lower_bounds[2] == STRUCT_FLOAT.pack(1.69) + + assert len(datafile.upper_bounds) == 2 + assert datafile.upper_bounds[1].decode() == "z{" + assert datafile.upper_bounds[2] == STRUCT_FLOAT.pack(100) + + +def test_column_metrics_mode() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + table_metadata.properties["write.metadata.metrics.default"] = "truncate(2)" + table_metadata.properties["write.metadata.metrics.column.strings"] = "none" + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.value_counts) == 6 + assert len(datafile.null_value_counts) == 6 + assert len(datafile.nan_value_counts) == 0 + + assert len(datafile.lower_bounds) == 1 + assert datafile.lower_bounds[2] == STRUCT_FLOAT.pack(1.69) + assert 1 not in datafile.lower_bounds + + assert len(datafile.upper_bounds) == 1 + assert datafile.upper_bounds[2] == STRUCT_FLOAT.pack(100) + assert 1 not in datafile.upper_bounds + + +def construct_test_table_primitive_types() -> Tuple[Any, Any, Union[TableMetadataV1, TableMetadataV2]]: + table_metadata = { + "format-version": 2, + "location": "s3://bucket/test/location", + "last-column-id": 7, + "current-schema-id": 0, + "schemas": [ + { + "type": "struct", + "schema-id": 0, + "fields": [ + {"id": 1, "name": "booleans", "required": False, "type": "boolean"}, + {"id": 2, "name": "ints", "required": False, "type": "int"}, + {"id": 3, "name": "longs", "required": False, "type": "long"}, + {"id": 4, "name": "floats", "required": False, "type": "float"}, + {"id": 5, "name": "doubles", "required": False, "type": "double"}, + {"id": 6, "name": "dates", "required": False, "type": "date"}, + {"id": 7, "name": "times", "required": False, "type": "time"}, + {"id": 8, "name": "timestamps", "required": False, "type": "timestamp"}, + {"id": 9, "name": "timestamptzs", "required": False, "type": "timestamptz"}, + {"id": 10, "name": "strings", "required": False, "type": "string"}, + {"id": 11, "name": "uuids", "required": False, "type": "uuid"}, + {"id": 12, "name": "binaries", "required": False, "type": "binary"}, + ], + }, + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "properties": {}, + } + + table_metadata = TableMetadataUtil.parse_obj(table_metadata) + arrow_schema = schema_to_pyarrow(table_metadata.schemas[0]) + tz = timezone(timedelta(seconds=19800)) + + booleans = [True, False] + ints = [23, 89] + longs = [54, 2] + floats = [454.1223, 24342.29] + doubles = [8542.12, -43.9] + dates = [date(2022, 1, 2), date(2023, 2, 4)] + 
times = [time(17, 30, 34), time(13, 21, 4)] + timestamps = [datetime(2022, 1, 2, 17, 30, 34, 399), datetime(2023, 2, 4, 13, 21, 4, 354)] + timestamptzs = [datetime(2022, 1, 2, 17, 30, 34, 399, tz), datetime(2023, 2, 4, 13, 21, 4, 354, tz)] + strings = ["hello", "world"] + uuids = [uuid.uuid3(uuid.NAMESPACE_DNS, "foo").bytes, uuid.uuid3(uuid.NAMESPACE_DNS, "bar").bytes] + binaries = [b"hello", b"world"] + + table = pa.Table.from_pydict( + { + "booleans": booleans, + "ints": ints, + "longs": longs, + "floats": floats, + "doubles": doubles, + "dates": dates, + "times": times, + "timestamps": timestamps, + "timestamptzs": timestamptzs, + "strings": strings, + "uuids": uuids, + "binaries": binaries, + }, + schema=arrow_schema, + ) + + metadata_collector: List[Any] = [] + + with pa.BufferOutputStream() as f: + with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector) as writer: + writer.write_table(table) + + return f.getvalue(), metadata_collector[0], table_metadata + + +def test_metrics_primitive_types() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table_primitive_types() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + table_metadata.properties["write.metadata.metrics.default"] = "truncate(2)" + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.value_counts) == 12 + assert len(datafile.null_value_counts) == 12 + assert len(datafile.nan_value_counts) == 0 + + tz = timezone(timedelta(seconds=19800)) + + assert len(datafile.lower_bounds) == 12 + assert datafile.lower_bounds[1] == STRUCT_BOOL.pack(False) + assert datafile.lower_bounds[2] == STRUCT_INT32.pack(23) + assert datafile.lower_bounds[3] == STRUCT_INT64.pack(2) + assert datafile.lower_bounds[4] == STRUCT_FLOAT.pack(454.1223) + assert datafile.lower_bounds[5] == STRUCT_DOUBLE.pack(-43.9) + assert datafile.lower_bounds[6] == STRUCT_INT32.pack(date_to_days(date(2022, 1, 2))) + assert datafile.lower_bounds[7] == STRUCT_INT64.pack(time_to_micros(time(13, 21, 4))) + assert datafile.lower_bounds[8] == STRUCT_INT64.pack(datetime_to_micros(datetime(2022, 1, 2, 17, 30, 34, 399))) + assert datafile.lower_bounds[9] == STRUCT_INT64.pack(datetime_to_micros(datetime(2022, 1, 2, 17, 30, 34, 399, tz))) + assert datafile.lower_bounds[10] == b"he" + assert datafile.lower_bounds[11] == uuid.uuid3(uuid.NAMESPACE_DNS, "foo").bytes + assert datafile.lower_bounds[12] == b"he" + + assert len(datafile.upper_bounds) == 12 + assert datafile.upper_bounds[1] == STRUCT_BOOL.pack(True) + assert datafile.upper_bounds[2] == STRUCT_INT32.pack(89) + assert datafile.upper_bounds[3] == STRUCT_INT64.pack(54) + assert datafile.upper_bounds[4] == STRUCT_FLOAT.pack(24342.29) + assert datafile.upper_bounds[5] == STRUCT_DOUBLE.pack(8542.12) + assert datafile.upper_bounds[6] == STRUCT_INT32.pack(date_to_days(date(2023, 2, 4))) + assert datafile.upper_bounds[7] == STRUCT_INT64.pack(time_to_micros(time(17, 30, 34))) + assert datafile.upper_bounds[8] == STRUCT_INT64.pack(datetime_to_micros(datetime(2023, 2, 4, 13, 21, 4, 354))) + assert datafile.upper_bounds[9] == STRUCT_INT64.pack(datetime_to_micros(datetime(2023, 2, 4, 13, 21, 4, 354, tz))) + assert datafile.upper_bounds[10] == b"wp" + assert datafile.upper_bounds[11] == uuid.uuid3(uuid.NAMESPACE_DNS, "bar").bytes + assert datafile.upper_bounds[12] == b"wp" + + +def construct_test_table_invalid_upper_bound() -> Tuple[Any, 
Any, Union[TableMetadataV1, TableMetadataV2]]: + table_metadata = { + "format-version": 2, + "location": "s3://bucket/test/location", + "last-column-id": 7, + "current-schema-id": 0, + "schemas": [ + { + "type": "struct", + "schema-id": 0, + "fields": [ + {"id": 1, "name": "valid_upper_binary", "required": False, "type": "binary"}, + {"id": 2, "name": "invalid_upper_binary", "required": False, "type": "binary"}, + {"id": 3, "name": "valid_upper_string", "required": False, "type": "string"}, + {"id": 4, "name": "invalid_upper_string", "required": False, "type": "string"}, + ], + }, + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "properties": {}, + } + + table_metadata = TableMetadataUtil.parse_obj(table_metadata) + arrow_schema = schema_to_pyarrow(table_metadata.schemas[0]) + + valid_binaries = [b"\x00\x00\x00", b"\xff\xfe\x00"] + invalid_binaries = [b"\x00\x00\x00", b"\xff\xff\x00"] + + valid_strings = ["\x00\x00\x00", "".join([chr(0x10FFFF), chr(0x10FFFE), chr(0x0)])] + invalid_strings = ["\x00\x00\x00", "".join([chr(0x10FFFF), chr(0x10FFFF), chr(0x0)])] + + table = pa.Table.from_pydict( + { + "valid_upper_binary": valid_binaries, + "invalid_upper_binary": invalid_binaries, + "valid_upper_string": valid_strings, + "invalid_upper_string": invalid_strings, + }, + schema=arrow_schema, + ) + + metadata_collector: List[Any] = [] + + with pa.BufferOutputStream() as f: + with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector) as writer: + writer.write_table(table) + + return f.getvalue(), metadata_collector[0], table_metadata + + +def test_metrics_invalid_upper_bound() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table_invalid_upper_bound() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + table_metadata.properties["write.metadata.metrics.default"] = "truncate(2)" + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert len(datafile.value_counts) == 4 + assert len(datafile.null_value_counts) == 4 + assert len(datafile.nan_value_counts) == 0 + + assert len(datafile.lower_bounds) == 4 + assert datafile.lower_bounds[1] == b"\x00\x00" + assert datafile.lower_bounds[2] == b"\x00\x00" + assert datafile.lower_bounds[3] == b"\x00\x00" + assert datafile.lower_bounds[4] == b"\x00\x00" + + assert len(datafile.upper_bounds) == 2 + assert datafile.upper_bounds[1] == b"\xff\xff" + assert datafile.upper_bounds[3] == "".join([chr(0x10FFFF), chr(0x10FFFF)]).encode() + + +def test_offsets() -> None: + (file_bytes, metadata, table_metadata) = construct_test_table() + + schema = get_current_schema(table_metadata) + datafile = DataFile() + fill_parquet_file_metadata( + datafile, + metadata, + len(file_bytes), + compute_statistics_plan(schema, table_metadata.properties), + parquet_path_to_id_mapping(schema), + ) + + assert datafile.split_offsets is not None + assert len(datafile.split_offsets) == 1 + assert datafile.split_offsets[0] == 4 + + +def test_write_and_read_stats_schema(table_schema_nested: Schema) -> None: + tbl = pa.Table.from_pydict( + { + "foo": ["a", "b"], + "bar": [1, 2], + "baz": [False, True], + "qux": [["a", "b"], ["c", "d"]], + "quux": [[("a", (("aa", 1), ("ab", 2)))], [("b", (("ba", 3), ("bb", 4)))]], + "location": [[(52.377956, 4.897070), (4.897070, -122.431297)], [(43.618881, -116.215019), (41.881832, -87.623177)]], + "person": [("Fokko", 33), ("Max", 42)], # 
Possible data quality issue + }, + schema=schema_to_pyarrow(table_schema_nested), + ) + stats_columns = pre_order_visit(table_schema_nested, PyArrowStatisticsCollector(table_schema_nested, {})) + + visited_paths = [] + + def file_visitor(written_file: Any) -> None: + visited_paths.append(written_file) + + with tempfile.TemporaryDirectory() as tmpdir: + pq.write_to_dataset(tbl, tmpdir, file_visitor=file_visitor) + + assert visited_paths[0].metadata.num_columns == len(stats_columns) + + +def test_stats_types(table_schema_nested: Schema) -> None: + stats_columns = pre_order_visit(table_schema_nested, PyArrowStatisticsCollector(table_schema_nested, {})) + + # the field-ids should be sorted + assert all(stats_columns[i].field_id <= stats_columns[i + 1].field_id for i in range(len(stats_columns) - 1)) + assert [col.iceberg_type for col in stats_columns] == [ + StringType(), + IntegerType(), + BooleanType(), + StringType(), + StringType(), + StringType(), + IntegerType(), + FloatType(), + FloatType(), + StringType(), + IntegerType(), + ] + + +# This is commented out for now because write_to_dataset drops the partition +# columns making it harder to calculate the mapping from the column index to +# datatype id +# +# def test_dataset() -> pa.Buffer: + +# table_metadata = { +# "format-version": 2, +# "location": "s3://bucket/test/location", +# "last-column-id": 7, +# "current-schema-id": 0, +# "schemas": [ +# { +# "type": "struct", +# "schema-id": 0, +# "fields": [ +# {"id": 1, "name": "ints", "required": False, "type": "long"}, +# {"id": 2, "name": "even", "required": False, "type": "boolean"}, +# ], +# }, +# ], +# "default-spec-id": 0, +# "partition-specs": [{"spec-id": 0, "fields": []}], +# "properties": {}, +# } + +# table_metadata = TableMetadataUtil.parse_obj(table_metadata) +# schema = schema_to_pyarrow(table_metadata.schemas[0]) + +# _ints = [0, 2, 4, 8, 1, 3, 5, 7] +# parity = [True, True, True, True, False, False, False, False] + +# table = pa.Table.from_pydict({"ints": _ints, "even": parity}, schema=schema) + +# visited_paths = [] + +# def file_visitor(written_file: Any) -> None: +# visited_paths.append(written_file) + +# with TemporaryDirectory() as tmpdir: +# pq.write_to_dataset(table, tmpdir, partition_cols=["even"], file_visitor=file_visitor) + +# even = None +# odd = None + +# assert len(visited_paths) == 2 + +# for written_file in visited_paths: +# df = DataFile() + +# fill_parquet_file_metadata(df, written_file.metadata, written_file.size, table_metadata) + +# if "even=true" in written_file.path: +# even = df + +# if "even=false" in written_file.path: +# odd = df + +# assert even is not None +# assert odd is not None + +# assert len(even.value_counts) == 1 +# assert even.value_counts[1] == 4 +# assert len(even.lower_bounds) == 1 +# assert even.lower_bounds[1] == STRUCT_INT64.pack(0) +# assert len(even.upper_bounds) == 1 +# assert even.upper_bounds[1] == STRUCT_INT64.pack(8) + +# assert len(odd.value_counts) == 1 +# assert odd.value_counts[1] == 4 +# assert len(odd.lower_bounds) == 1 +# assert odd.lower_bounds[1] == STRUCT_INT64.pack(1) +# assert len(odd.upper_bounds) == 1 +# assert odd.upper_bounds[1] == STRUCT_INT64.pack(7) diff --git a/python/tests/utils/test_truncate.py b/python/tests/utils/test_truncate.py new file mode 100644 index 000000000000..b9c3c10335fa --- /dev/null +++ b/python/tests/utils/test_truncate.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string + + +def test_upper_bound_string_truncation() -> None: + assert truncate_upper_bound_text_string("aaaa", 2) == "ab" + assert truncate_upper_bound_text_string("".join([chr(0x10FFFF), chr(0x10FFFF), chr(0x0)]), 2) is None + + +def test_upper_bound_binary_truncation() -> None: + assert truncate_upper_bound_binary_string(b"\x01\x02\x03", 2) == b"\x01\x03" + assert truncate_upper_bound_binary_string(b"\xff\xff\x00", 2) is None
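
Taken together, the pieces in this change are meant to be used right after a Parquet file is written: build the per-column statistics plan and the Parquet-path-to-field-id mapping from the Iceberg schema and table properties, then let `fill_parquet_file_metadata` populate the `DataFile`. A minimal sketch of that flow, mirroring the in-memory writes used in the tests above (the schema, data, and truncation length below are illustrative, not part of this change):

```python
# Minimal sketch: write a small Arrow table to an in-memory Parquet file and fill a
# DataFile with the collected metrics. Schema, data, and properties are illustrative.
import pyarrow as pa
import pyarrow.parquet as pq

from pyiceberg.io.pyarrow import (
    compute_statistics_plan,
    fill_parquet_file_metadata,
    parquet_path_to_id_mapping,
    schema_to_pyarrow,
)
from pyiceberg.manifest import DataFile
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType

schema = Schema(NestedField(field_id=1, name="name", field_type=StringType(), required=False), schema_id=0)
table = pa.Table.from_pydict({"name": ["alpha", "omega", None]}, schema=schema_to_pyarrow(schema))

# Truncate string/binary bounds to 8 characters instead of the default 16.
properties = {"write.metadata.metrics.default": "truncate(8)"}

metadata_collector: list = []
with pa.BufferOutputStream() as buffer:
    with pq.ParquetWriter(buffer, table.schema, metadata_collector=metadata_collector) as writer:
        writer.write_table(table)
    file_bytes = buffer.getvalue()

datafile = DataFile()
fill_parquet_file_metadata(
    datafile,
    metadata_collector[0],
    len(file_bytes),
    compute_statistics_plan(schema, properties),
    parquet_path_to_id_mapping(schema),
)
print(datafile.record_count, datafile.lower_bounds, datafile.upper_bounds)
```

This is the same sequence the tests above exercise; in a real write path the file size would come from the file system rather than the length of an in-memory buffer.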