Skip to content

Commit

Permalink
fix: change partitioning schema from large to normal string for pyarr…
Browse files Browse the repository at this point in the history
…ow<12 (#1671)

# Description
When the installed pyarrow version is below v12.0.0, the writer converts
partitioning schema fields of type large_string to string.

# Related Issue(s)
closes #1669 

# Documentation
apache/arrow#34546 (comment)

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
  • Loading branch information
ion-elgreco and wjones127 authored Sep 27, 2023
1 parent 113fd0f commit 23d1e90
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
20 changes: 19 additions & 1 deletion python/deltalake/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,26 @@ def write_deltalake(
else: # creating a new table
current_version = -1

dtype_map = {
pa.large_string(): pa.string(), # type: ignore
}

def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType:
    """Return the replacement registered for ``dtype`` in ``dtype_map``.

    Types that have no entry in ``dtype_map`` are returned unchanged.
    """
    # Falling back to the input itself keeps unmapped types untouched.
    return dtype_map.get(dtype, dtype)

if partition_by:
partition_schema = pa.schema([schema.field(name) for name in partition_by])
if PYARROW_MAJOR_VERSION < 12:
partition_schema = pa.schema(
[
pa.field(name, _large_to_normal_dtype(schema.field(name).type))
for name in partition_by
]
)
else:
partition_schema = pa.schema([schema.field(name) for name in partition_by])
partitioning = ds.partitioning(partition_schema, flavor="hive")
else:
partitioning = None
Expand Down
19 changes: 19 additions & 0 deletions python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,25 @@ def test_large_arrow_types(tmp_path: pathlib.Path):
assert table.schema == dt.schema().to_pyarrow(as_large_types=True)


def test_partition_large_arrow_types(tmp_path: pathlib.Path):
    """Write a table partitioned on a ``large_string`` column and check its layout.

    The first path component of every written file must be one of the
    ``foo=<value>`` partition directories.
    """
    data = pa.table(
        {
            "foo": pa.array(["1", "1", "2", "2"], pa.large_string()),
            "bar": pa.array([1, 2, 1, 2], pa.int64()),
            "baz": pa.array([1, 1, 1, 1], pa.int64()),
        }
    )

    write_deltalake(tmp_path, data, partition_by=["foo"])

    # Keep only the leading directory of each file path reported by the table.
    partition_dirs = sorted(
        path.split("/")[0] for path in DeltaTable(tmp_path).files()
    )
    assert ["foo=1", "foo=2"] == partition_dirs


def test_uint_arrow_types(tmp_path: pathlib.Path):
pylist = [
{"num1": 3, "num2": 3, "num3": 3, "num4": 5},
Expand Down

0 comments on commit 23d1e90

Please sign in to comment.