fix: casting when data to be written does not match table schema (#1427)
# Description
Suppose a user has a table with a column of type `int`. The user can create a
record batch with a `Utf8` column and write those values to the table. My
expectation is that the writer either returns an error or implements ANSI SQL
behavior, where non-numeric strings are turned into nulls.
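
A minimal sketch of the resulting behavior, mirroring the `test_write_different_types` test added below (the `deltalake::DeltaOps` import path is assumed for a user of the published crate): the default cast options reject a mismatched batch with an error, while `CastOptions { safe: true }` coerces unparseable strings to nulls.

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, StringArray};
use arrow_cast::CastOptions;
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use deltalake::DeltaOps; // assumed crate-root re-export

#[tokio::main]
async fn main() {
    // Seed an in-memory table whose `value` column is Int32.
    let int_schema = Arc::new(ArrowSchema::new(vec![Field::new(
        "value",
        DataType::Int32,
        true,
    )]));
    let int_batch = RecordBatch::try_new(
        Arc::clone(&int_schema),
        vec![Arc::new(Int32Array::from(vec![Some(0), None]))],
    )
    .unwrap();
    let table = DeltaOps::new_in_memory().write(vec![int_batch]).await.unwrap();

    // A Utf8 batch that does not match the table schema.
    let str_schema = Arc::new(ArrowSchema::new(vec![Field::new(
        "value",
        DataType::Utf8,
        true,
    )]));
    let str_batch = RecordBatch::try_new(
        Arc::clone(&str_schema),
        vec![Arc::new(StringArray::from(vec![Some("Test123"), Some("123"), None]))],
    )
    .unwrap();

    // Safe casting turns the non-numeric "Test123" into a null (ANSI-style).
    let table = DeltaOps::from(table)
        .write(vec![str_batch.clone()])
        .with_cast_options(CastOptions { safe: true })
        .await
        .unwrap();

    // The default (`safe: false`) rejects the same batch with an error.
    assert!(DeltaOps::from(table).write(vec![str_batch]).await.is_err());
}
```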
Blajda authored Jun 3, 2023
1 parent 98af4dc commit 98d9dcc
Showing 3 changed files with 158 additions and 22 deletions.
26 changes: 9 additions & 17 deletions rust/src/operations/delete.rs
@@ -29,6 +29,7 @@ use arrow::datatypes::Field;
use arrow::datatypes::Schema as ArrowSchema;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use arrow_cast::CastOptions;
use datafusion::datasource::file_format::{parquet::ParquetFormat, FileFormat};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::MemTable;
@@ -460,6 +461,7 @@ async fn excute_non_empty_expr(
Some(snapshot.table_config().target_file_size() as usize),
None,
writer_properties,
&CastOptions { safe: false },
)
.await?;
metrics.rewrite_time_ms = Instant::now().duration_since(write_start).as_millis();
@@ -680,6 +682,7 @@ mod tests {

use crate::action::*;
use crate::operations::DeltaOps;
use crate::writer::test_utils::datafusion::get_data;
use crate::writer::test_utils::{get_arrow_schema, get_delta_schema};
use crate::DeltaTable;
use arrow::array::Int32Array;
@@ -703,17 +706,6 @@ mod tests {
table
}

async fn get_data(table: DeltaTable) -> Vec<RecordBatch> {
let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table)).unwrap();
ctx.sql("select * from test")
.await
.unwrap()
.collect()
.await
.unwrap()
}

#[tokio::test]
async fn test_delete_default() {
let schema = get_arrow_schema(&None);
@@ -849,7 +841,7 @@ mod tests {
"+----+-------+------------+",
];

let actual = get_data(table).await;
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}

@@ -897,7 +889,7 @@ mod tests {
"| 2 |",
"+-------+",
];
let actual = get_data(table).await;
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);

// Validate behaviour of less than
@@ -918,7 +910,7 @@ mod tests {
"| 4 |",
"+-------+",
];
let actual = get_data(table).await;
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);

// Validate behaviour of less plus not null
@@ -937,7 +929,7 @@ mod tests {
"| 4 |",
"+-------+",
];
let actual = get_data(table).await;
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}

@@ -996,7 +988,7 @@ mod tests {
"+----+-------+------------+",
];

let actual = get_data(table).await;
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}

@@ -1057,7 +1049,7 @@ mod tests {
"| B | 20 | 2021-02-03 |",
"+----+-------+------------+",
];
let actual = get_data(table).await;
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}

134 changes: 129 additions & 5 deletions rust/src/operations/write.rs
@@ -20,7 +20,7 @@ use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use arrow_array::RecordBatch;
use arrow_cast::{can_cast_types, cast};
use arrow_cast::{can_cast_types, cast_with_options, CastOptions};
use arrow_schema::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef};
use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
use datafusion::physical_plan::{memory::MemoryExec, ExecutionPlan};
@@ -97,6 +97,9 @@ pub struct WriteBuilder {
write_batch_size: Option<usize>,
/// RecordBatches to be written into the table
batches: Option<Vec<RecordBatch>>,
/// CastOptions determines how data types that do not match the underlying table are handled.
/// By default an error is returned.
cast_options: CastOptions,
}

impl WriteBuilder {
@@ -113,6 +116,7 @@ impl WriteBuilder {
target_file_size: None,
write_batch_size: None,
batches: None,
cast_options: CastOptions { safe: false },
}
}

@@ -168,6 +172,12 @@ impl WriteBuilder {
self
}

/// Specify the cast options to use when casting columns that do not match the table's schema.
pub fn with_cast_options(mut self, cast_options: CastOptions) -> Self {
self.cast_options = cast_options;
self
}

async fn check_preconditions(&self) -> DeltaResult<Vec<Action>> {
match self.store.is_delta_table_location().await? {
true => {
@@ -217,35 +227,41 @@ pub(crate) async fn write_execution_plan(
target_file_size: Option<usize>,
write_batch_size: Option<usize>,
writer_properties: Option<WriterProperties>,
cast_options: &CastOptions,
) -> DeltaResult<Vec<Add>> {
let invariants = snapshot
.current_metadata()
.and_then(|meta| meta.schema.get_invariants().ok())
.unwrap_or_default();

// Use the input schema to prevent wrapping partition columns into a dictionary.
let schema = snapshot.input_schema().unwrap_or(plan.schema());

let checker = DeltaDataChecker::new(invariants);

// Write data to disk
let mut tasks = vec![];
for i in 0..plan.output_partitioning().partition_count() {
let inner_plan = plan.clone();
let inner_schema = schema.clone();
let task_ctx = Arc::new(TaskContext::from(&state));
let inner_cast = cast_options.clone();
let config = WriterConfig::new(
inner_plan.schema(),
inner_schema.clone(),
partition_columns.clone(),
writer_properties.clone(),
target_file_size,
write_batch_size,
);
let mut writer = DeltaWriter::new(object_store.clone(), config);
let checker_stream = checker.clone();
let schema = inner_plan.schema().clone();
let mut stream = inner_plan.execute(i, task_ctx)?;
let handle: tokio::task::JoinHandle<DeltaResult<Vec<Add>>> =
tokio::task::spawn(async move {
while let Some(maybe_batch) = stream.next().await {
let batch = maybe_batch?;
checker_stream.check_batch(&batch).await?;
let arr = cast_record_batch(&batch, schema.clone())?;
let arr = cast_record_batch(&batch, inner_schema.clone(), &inner_cast)?;
writer.write(&arr).await?;
}
writer.close().await
@@ -376,6 +392,7 @@ impl std::future::IntoFuture for WriteBuilder {
this.target_file_size,
this.write_batch_size,
None,
&this.cast_options,
)
.await?;
actions.extend(add_actions.into_iter().map(Action::add));
@@ -464,14 +481,17 @@ fn can_cast_batch(from_schema: &ArrowSchema, to_schema: &ArrowSchema) -> bool {
fn cast_record_batch(
batch: &RecordBatch,
target_schema: ArrowSchemaRef,
cast_options: &CastOptions,
) -> DeltaResult<RecordBatch> {
//let cast_options = CastOptions { safe: false };

let columns = target_schema
.all_fields()
.iter()
.map(|f| {
let col = batch.column_by_name(f.name()).unwrap();
if !col.data_type().equals_datatype(f.data_type()) {
cast(col, f.data_type())
cast_with_options(col, f.data_type(), cast_options)
} else {
Ok(col.clone())
}
@@ -484,7 +504,13 @@ fn cast_record_batch(
mod tests {
use super::*;
use crate::operations::DeltaOps;
use crate::writer::test_utils::datafusion::get_data;
use crate::writer::test_utils::{get_delta_schema, get_record_batch};
use arrow::datatypes::Field;
use arrow::datatypes::Schema as ArrowSchema;
use arrow_array::{Int32Array, StringArray, TimestampMicrosecondArray};
use arrow_schema::{DataType, TimeUnit};
use datafusion::assert_batches_sorted_eq;
use serde_json::json;

#[tokio::test]
@@ -527,6 +553,104 @@ mod tests {
assert_eq!(table.get_file_uris().count(), 1)
}

#[tokio::test]
async fn test_write_different_types() {
// Ensure written data is cast when its type differs from the table schema.

// Validate String -> Int is err
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"value",
DataType::Int32,
true,
)]));

let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(Int32Array::from(vec![Some(0), None]))],
)
.unwrap();
let table = DeltaOps::new_in_memory().write(vec![batch]).await.unwrap();

let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"value",
DataType::Utf8,
true,
)]));

let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(StringArray::from(vec![
Some("Test123".to_owned()),
Some("123".to_owned()),
None,
]))],
)
.unwrap();

// Test cast options
let table = DeltaOps::from(table)
.write(vec![batch.clone()])
.with_cast_options(CastOptions { safe: true })
.await
.unwrap();

let expected = [
"+-------+",
"| value |",
"+-------+",
"| |",
"| |",
"| |",
"| 123 |",
"| 0 |",
"+-------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);

let res = DeltaOps::from(table).write(vec![batch]).await;
assert!(res.is_err());

// Validate the datetime -> string behavior
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"value",
arrow::datatypes::DataType::Utf8,
true,
)]));

let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(StringArray::from(vec![Some(
"2023-06-03 15:35:00".to_owned(),
)]))],
)
.unwrap();
let table = DeltaOps::new_in_memory().write(vec![batch]).await.unwrap();

let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"value",
DataType::Timestamp(TimeUnit::Microsecond, None),
true,
)]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(TimestampMicrosecondArray::from(vec![Some(10000)]))],
)
.unwrap();

let _res = DeltaOps::from(table).write(vec![batch]).await.unwrap();
let expected = [
"+-------------------------+",
"| value |",
"+-------------------------+",
"| 1970-01-01T00:00:00.010 |",
"| 2023-06-03 15:35:00 |",
"+-------------------------+",
];
let actual = get_data(&_res).await;
assert_batches_sorted_eq!(&expected, &actual);
}

#[tokio::test]
async fn test_write_nonexistent() {
let batch = get_record_batch(None, false);
20 changes: 20 additions & 0 deletions rust/src/writer/test_utils.rs
@@ -211,3 +211,23 @@ pub async fn create_initialized_table(partition_cols: &[String]) -> DeltaTable {

table
}

#[cfg(feature = "datafusion")]
pub mod datafusion {
use crate::DeltaTable;
use arrow_array::RecordBatch;
use datafusion::prelude::SessionContext;
use std::sync::Arc;

pub async fn get_data(table: &DeltaTable) -> Vec<RecordBatch> {
let table = DeltaTable::new_with_state(table.object_store(), table.state.clone());
let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table)).unwrap();
ctx.sql("select * from test")
.await
.unwrap()
.collect()
.await
.unwrap()
}
}
