Merge pull request #1661 from roeap/partition-values
fix: more consistent handling of partition values and file paths
rtyler authored Sep 25, 2023
2 parents a482e4f + a6d8c56 commit 56e1e87
Showing 18 changed files with 366 additions and 67 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -27,3 +27,4 @@ Cargo.lock
 !/delta-inspect/Cargo.lock
 !/proofs/Cargo.lock
 
+justfile
20 changes: 20 additions & 0 deletions python/deltalake/_util.py
@@ -0,0 +1,20 @@
+from datetime import date, datetime
+from typing import Any
+
+
+def encode_partition_value(val: Any) -> str:
+    # Rules based on: https://github.com/delta-io/delta/blob/master/PROTOCOL.md#partition-value-serialization
+    if isinstance(val, bool):
+        return str(val).lower()
+    if isinstance(val, str):
+        return val
+    elif isinstance(val, (int, float)):
+        return str(val)
+    elif isinstance(val, date):
+        return val.isoformat()
+    elif isinstance(val, datetime):
+        return val.isoformat(sep=" ")
+    elif isinstance(val, bytes):
+        return val.decode("unicode_escape", "backslashreplace")
+    else:
+        raise ValueError(f"Could not encode partition value for type: {val}")
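
For reference, these are the concrete encodings the helper above is expected to produce under the protocol rules it cites. A minimal sketch; `deltalake._util` is an internal module, so the direct import is shown only for illustration:

from datetime import date, datetime

from deltalake._util import encode_partition_value  # internal helper introduced above

assert encode_partition_value(True) == "true"                  # booleans are lower-cased
assert encode_partition_value(7) == "7"                        # ints and floats use str()
assert encode_partition_value("x=y") == "x=y"                  # strings pass through unescaped
assert encode_partition_value(date(2023, 9, 15)) == "2023-09-15"                        # ISO date
assert encode_partition_value(datetime(2023, 9, 15, 12, 30)) == "2023-09-15 12:30:00"   # space-separated datetime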
5 changes: 3 additions & 2 deletions python/deltalake/table.py
@@ -31,6 +31,7 @@
     import pandas
 
 from ._internal import RawDeltaTable
+from ._util import encode_partition_value
 from .data_catalog import DataCatalog
 from .exceptions import DeltaProtocolError
 from .fs import DeltaStorageHandler
@@ -625,9 +626,9 @@ def __stringify_partition_values(
         for field, op, value in partition_filters:
             str_value: Union[str, List[str]]
             if isinstance(value, (list, tuple)):
-                str_value = [str(val) for val in value]
+                str_value = [encode_partition_value(val) for val in value]
             else:
-                str_value = str(value)
+                str_value = encode_partition_value(value)
             out.append((field, op, str_value))
         return out

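Routing `__stringify_partition_values` through `encode_partition_value` means partition filters can be given as typed Python values and still match the serialized partition values in the log. A minimal sketch of the user-facing behaviour; the table path and the column name "bool_col" are hypothetical:

from deltalake import DeltaTable

dt = DeltaTable("path/to/table")  # hypothetical table partitioned by a boolean column "bool_col"

# Passing the typed Python value or the already-encoded protocol string should
# now select the same number of files.
typed = dt.file_uris(partition_filters=[("bool_col", "=", True)])
encoded = dt.file_uris(partition_filters=[("bool_col", "=", "true")])
assert len(typed) == len(encoded)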
25 changes: 5 additions & 20 deletions python/deltalake/writer.py
@@ -17,9 +17,12 @@
     Tuple,
     Union,
 )
+from urllib.parse import unquote
 
 from deltalake.fs import DeltaStorageHandler
 
+from ._util import encode_partition_value
+
 if TYPE_CHECKING:
     import pandas as pd
 
@@ -262,7 +265,7 @@ def check_data_is_aligned_with_partition_filtering(
         for i in range(partition_values.num_rows):
             # Map will maintain order of partition_columns
             partition_map = {
-                column_name: __encode_partition_value(
+                column_name: encode_partition_value(
                     batch.column(column_name)[i].as_py()
                 )
                 for column_name in table.metadata().partition_columns
@@ -422,7 +425,7 @@ def get_partitions_from_path(path: str) -> Tuple[str, Dict[str, Optional[str]]]:
         if value == "__HIVE_DEFAULT_PARTITION__":
             out[key] = None
         else:
-            out[key] = value
+            out[key] = unquote(value)
     return path, out
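
The added `unquote` matters because partition values are percent-encoded when they are embedded in file paths, so decoding is needed to recover the original value when a path is parsed back. A self-contained sketch of that round trip; the file name below is made up:

from urllib.parse import quote, unquote

value = "a/b c"                     # a partition value with characters that cannot appear in a path
encoded = quote(value, safe="")     # "a%2Fb%20c", as it would be embedded in the file path
path = f"string={encoded}/part-00000.parquet"

key, raw = path.split("/")[0].split("=", maxsplit=1)
assert key == "string"
assert unquote(raw) == value        # unquoting restores the original partition value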


@@ -489,21 +492,3 @@ def iter_groups(metadata: Any) -> Iterator[Any]:
             maximum for maximum in maximums if maximum is not None
         )
     return stats
-
-
-def __encode_partition_value(val: Any) -> str:
-    # Rules based on: https://github.com/delta-io/delta/blob/master/PROTOCOL.md#partition-value-serialization
-    if isinstance(val, bool):
-        return str(val).lower()
-    if isinstance(val, str):
-        return val
-    elif isinstance(val, (int, float)):
-        return str(val)
-    elif isinstance(val, date):
-        return val.isoformat()
-    elif isinstance(val, datetime):
-        return val.isoformat(sep=" ")
-    elif isinstance(val, bytes):
-        return val.decode("unicode_escape", "backslashreplace")
-    else:
-        raise ValueError(f"Could not encode partition value for type: {val}")
34 changes: 33 additions & 1 deletion python/tests/pyspark_integration/test_writer_readable.py
@@ -12,6 +12,7 @@
     import delta
     import delta.pip_utils
     import delta.tables
+    import pyspark.pandas as ps
 
     spark = get_spark()
 except ModuleNotFoundError:
@@ -34,7 +35,7 @@ def test_basic_read(sample_data: pa.Table, existing_table: DeltaTable):
 @pytest.mark.pyspark
 @pytest.mark.integration
 def test_partitioned(tmp_path: pathlib.Path, sample_data: pa.Table):
-    partition_cols = ["date32", "utf8"]
+    partition_cols = ["date32", "utf8", "timestamp", "bool"]
 
     # Add null values to sample data to verify we can read null partitions
     sample_data_with_null = sample_data
@@ -63,3 +64,34 @@ def test_overwrite(
 
     write_deltalake(path, sample_data, mode="overwrite")
     assert_spark_read_equal(sample_data, path)
+
+
+@pytest.mark.pyspark
+@pytest.mark.integration
+def test_issue_1591_roundtrip_special_characters(tmp_path: pathlib.Path):
+    test_string = r'$%&/()=^"[]#*?.:_-{=}|`<>~/\r\n+'
+    poisoned = "}|`<>~"
+    for char in poisoned:
+        test_string = test_string.replace(char, "")
+
+    data = pa.table(
+        {
+            "string": pa.array([test_string], type=pa.utf8()),
+            "data": pa.array(["python-module-test-write"]),
+        }
+    )
+
+    deltalake_path = tmp_path / "deltalake"
+    write_deltalake(
+        table_or_uri=deltalake_path, mode="append", data=data, partition_by=["string"]
+    )
+
+    loaded = ps.read_delta(str(deltalake_path), index_col=None).to_pandas()
+    assert loaded.shape == data.shape
+
+    spark_path = tmp_path / "spark"
+    spark_df = spark.createDataFrame(data.to_pandas())
+    spark_df.write.format("delta").partitionBy(["string"]).save(str(spark_path))
+
+    loaded = DeltaTable(spark_path).to_pandas()
+    assert loaded.shape == data.shape
46 changes: 46 additions & 0 deletions python/tests/test_table_read.py
@@ -648,3 +648,49 @@ def assert_scan_equals(table, predicate, expected):
     assert_num_fragments(table, predicate, 2)
     expected = pa.table({"part": ["a", "a", "b", "b"], "value": [1, 1, None, None]})
     assert_scan_equals(table, predicate, expected)
+
+
+def test_issue_1653_filter_bool_partition(tmp_path: Path):
+    ta = pa.Table.from_pydict(
+        {
+            "bool_col": [True, False, True, False],
+            "int_col": [0, 1, 2, 3],
+            "str_col": ["a", "b", "c", "d"],
+        }
+    )
+    write_deltalake(
+        tmp_path, ta, partition_by=["bool_col", "int_col"], mode="overwrite"
+    )
+    dt = DeltaTable(tmp_path)
+
+    assert (
+        dt.to_pyarrow_table(
+            filters=[
+                ("int_col", "=", 0),
+                ("bool_col", "=", True),
+            ]
+        ).num_rows
+        == 1
+    )
+    assert (
+        len(
+            dt.file_uris(
+                partition_filters=[
+                    ("int_col", "=", 0),
+                    ("bool_col", "=", "true"),
+                ]
+            )
+        )
+        == 1
+    )
+    assert (
+        len(
+            dt.file_uris(
+                partition_filters=[
+                    ("int_col", "=", 0),
+                    ("bool_col", "=", True),
+                ]
+            )
+        )
+        == 1
+    )
17 changes: 17 additions & 0 deletions python/tests/test_writer.py
@@ -168,6 +168,7 @@ def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table):
         "bool",
         "binary",
         "date32",
+        "timestamp",
     ],
 )
 def test_roundtrip_partitioned(
@@ -888,3 +889,19 @@ def comp():
         "a concurrent transaction deleted the same data your transaction deletes"
         in str(exception)
     )
+
+
+def test_issue_1651_roundtrip_timestamp(tmp_path: pathlib.Path):
+    data = pa.table(
+        {
+            "id": pa.array([425], type=pa.int32()),
+            "data": pa.array(["python-module-test-write"]),
+            "t": pa.array([datetime(2023, 9, 15)]),
+        }
+    )
+
+    write_deltalake(table_or_uri=tmp_path, mode="append", data=data, partition_by=["t"])
+    dt = DeltaTable(table_uri=tmp_path)
+    dataset = dt.to_pyarrow_dataset()
+
+    assert dataset.count_rows() == 1
4 changes: 4 additions & 0 deletions rust/Cargo.toml
@@ -185,6 +185,10 @@ harness = false
 name = "basic_operations"
 required-features = ["datafusion"]
 
+[[example]]
+name = "load_table"
+required-features = ["datafusion"]
+
 [[example]]
 name = "recordbatch-writer"
 required-features = ["arrow"]
41 changes: 33 additions & 8 deletions rust/examples/basic_operations.rs
@@ -1,6 +1,6 @@
 use arrow::{
-    array::{Int32Array, StringArray},
-    datatypes::{DataType, Field, Schema as ArrowSchema},
+    array::{Int32Array, StringArray, TimestampMicrosecondArray},
+    datatypes::{DataType, Field, Schema as ArrowSchema, TimeUnit},
     record_batch::RecordBatch,
 };
 use deltalake::operations::collect_sendable_stream;
@@ -26,34 +26,59 @@ fn get_table_columns() -> Vec<SchemaField> {
             true,
             Default::default(),
         ),
+        SchemaField::new(
+            String::from("timestamp"),
+            SchemaDataType::primitive(String::from("timestamp")),
+            true,
+            Default::default(),
+        ),
     ]
 }
 
 fn get_table_batches() -> RecordBatch {
     let schema = Arc::new(ArrowSchema::new(vec![
         Field::new("int", DataType::Int32, false),
         Field::new("string", DataType::Utf8, true),
+        Field::new(
+            "timestamp",
+            DataType::Timestamp(TimeUnit::Microsecond, None),
+            true,
+        ),
     ]));
 
     let int_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
     let str_values = StringArray::from(vec!["A", "B", "A", "B", "A", "A", "A", "B", "B", "A", "A"]);
-
-    RecordBatch::try_new(schema, vec![Arc::new(int_values), Arc::new(str_values)]).unwrap()
+    let ts_values = TimestampMicrosecondArray::from(vec![
+        1000000012, 1000000012, 1000000012, 1000000012, 500012305, 500012305, 500012305, 500012305,
+        500012305, 500012305, 500012305,
+    ]);
+    RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(int_values),
+            Arc::new(str_values),
+            Arc::new(ts_values),
+        ],
+    )
+    .unwrap()
 }
 
 #[tokio::main(flavor = "current_thread")]
 async fn main() -> Result<(), deltalake::errors::DeltaTableError> {
-    // Create a delta operations client pointing at an un-initialized in-memory location.
-    // In a production environment this would be created with "try_new" and point at
-    // a real storage location.
-    let ops = DeltaOps::new_in_memory();
+    // Create a delta operations client pointing at an un-initialized location.
+    let ops = if let Ok(table_uri) = std::env::var("TABLE_URI") {
+        DeltaOps::try_from_uri(table_uri).await?
+    } else {
+        DeltaOps::new_in_memory()
+    };
 
     // The operations module uses a builder pattern that allows specifying several options
     // on how the command behaves. The builders implement `Into<Future>`, so once
     // options are set you can run the command using `.await`.
     let table = ops
         .create()
         .with_columns(get_table_columns())
+        .with_partition_columns(["timestamp"])
         .with_table_name("my_table")
         .with_comment("A table to show how delta-rs works")
         .await?;
20 changes: 20 additions & 0 deletions rust/examples/load_table.rs
@@ -0,0 +1,20 @@
+use arrow::record_batch::RecordBatch;
+use deltalake::operations::collect_sendable_stream;
+use deltalake::DeltaOps;
+
+#[tokio::main(flavor = "current_thread")]
+async fn main() -> Result<(), deltalake::errors::DeltaTableError> {
+    // Create a delta operations client pointing at an un-initialized location.
+    let ops = if let Ok(table_uri) = std::env::var("TABLE_URI") {
+        DeltaOps::try_from_uri(table_uri).await?
+    } else {
+        DeltaOps::try_from_uri("./rust/tests/data/delta-0.8.0").await?
+    };
+
+    let (_table, stream) = ops.load().await?;
+    let data: Vec<RecordBatch> = collect_sendable_stream(stream).await?;
+
+    println!("{:?}", data);
+
+    Ok(())
+}
2 changes: 1 addition & 1 deletion rust/examples/recordbatch-writer.rs
@@ -36,7 +36,7 @@ async fn main() -> Result<(), DeltaTableError> {
     })?;
     info!("Using the location of: {:?}", table_uri);
 
-    let table_path = Path::from(table_uri.as_ref());
+    let table_path = Path::parse(&table_uri)?;
 
     let maybe_table = deltalake::open_table(&table_path).await;
     let mut table = match maybe_table {