From 23d1e900f192e20ae4008207e23ebacc8f3b3427 Mon Sep 17 00:00:00 2001
From: Ion Koutsouris
Date: Wed, 27 Sep 2023 17:09:34 +0200
Subject: [PATCH] fix: change partitioning schema from large to normal string for pyarrow<12 (#1671)

# Description
If pyarrow is below v12.0.0 it changes the partitioning schema fields
from large_string to string.

# Related Issue(s)
closes #1669

# Documentation
https://github.com/apache/arrow/issues/34546#issuecomment-1466587443

---------

Co-authored-by: Will Jones
---
 python/deltalake/writer.py  | 20 +++++++++++++++++++-
 python/tests/test_writer.py | 19 +++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index 1fb5403b76..db399e857e 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -208,8 +208,26 @@ def write_deltalake(
     else:  # creating a new table
         current_version = -1
 
+    dtype_map = {
+        pa.large_string(): pa.string(),  # type: ignore
+    }
+
+    def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType:
+        try:
+            return dtype_map[dtype]
+        except KeyError:
+            return dtype
+
     if partition_by:
-        partition_schema = pa.schema([schema.field(name) for name in partition_by])
+        if PYARROW_MAJOR_VERSION < 12:
+            partition_schema = pa.schema(
+                [
+                    pa.field(name, _large_to_normal_dtype(schema.field(name).type))
+                    for name in partition_by
+                ]
+            )
+        else:
+            partition_schema = pa.schema([schema.field(name) for name in partition_by])
         partitioning = ds.partitioning(partition_schema, flavor="hive")
     else:
         partitioning = None
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index e72d0ac8cd..3385e32175 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -841,6 +841,25 @@ def test_large_arrow_types(tmp_path: pathlib.Path):
     assert table.schema == dt.schema().to_pyarrow(as_large_types=True)
 
 
+def test_partition_large_arrow_types(tmp_path: pathlib.Path):
+    table = pa.table(
+        {
+            "foo": pa.array(["1", "1", "2", "2"], pa.large_string()),
+            "bar": pa.array([1, 2, 1, 2], pa.int64()),
+            "baz": pa.array([1, 1, 1, 1], pa.int64()),
+        }
+    )
+
+    write_deltalake(tmp_path, table, partition_by=["foo"])
+
+    dt = DeltaTable(tmp_path)
+    files = dt.files()
+    expected = ["foo=1", "foo=2"]
+
+    result = sorted([file.split("/")[0] for file in files])
+    assert expected == result
+
+
 def test_uint_arrow_types(tmp_path: pathlib.Path):
     pylist = [
         {"num1": 3, "num2": 3, "num3": 3, "num4": 5},