diff --git a/Cargo.lock b/Cargo.lock index 813da210e..33d01f6cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -939,7 +939,7 @@ dependencies = [ [[package]] name = "datafusion" version = "16.0.0" -source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=1d8c71a4703d27564c1a2bd60b164769f33fbe8f#1d8c71a4703d27564c1a2bd60b164769f33fbe8f" +source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=e3f156f4acc51fc45a7aa6a99085b091f539d5fa#e3f156f4acc51fc45a7aa6a99085b091f539d5fa" dependencies = [ "ahash 0.8.0", "arrow", @@ -985,7 +985,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "16.0.0" -source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=1d8c71a4703d27564c1a2bd60b164769f33fbe8f#1d8c71a4703d27564c1a2bd60b164769f33fbe8f" +source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=e3f156f4acc51fc45a7aa6a99085b091f539d5fa#e3f156f4acc51fc45a7aa6a99085b091f539d5fa" dependencies = [ "arrow", "chrono", @@ -999,7 +999,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "16.0.0" -source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=1d8c71a4703d27564c1a2bd60b164769f33fbe8f#1d8c71a4703d27564c1a2bd60b164769f33fbe8f" +source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=e3f156f4acc51fc45a7aa6a99085b091f539d5fa#e3f156f4acc51fc45a7aa6a99085b091f539d5fa" dependencies = [ "ahash 0.8.0", "arrow", @@ -1011,7 +1011,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "16.0.0" -source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=1d8c71a4703d27564c1a2bd60b164769f33fbe8f#1d8c71a4703d27564c1a2bd60b164769f33fbe8f" +source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=e3f156f4acc51fc45a7aa6a99085b091f539d5fa#e3f156f4acc51fc45a7aa6a99085b091f539d5fa" dependencies = [ "arrow", "async-trait", @@ -1027,7 +1027,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "16.0.0" -source = 
"git+https://github.com/jonmmease/arrow-datafusion.git?rev=1d8c71a4703d27564c1a2bd60b164769f33fbe8f#1d8c71a4703d27564c1a2bd60b164769f33fbe8f" +source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=e3f156f4acc51fc45a7aa6a99085b091f539d5fa#e3f156f4acc51fc45a7aa6a99085b091f539d5fa" dependencies = [ "ahash 0.8.0", "arrow", @@ -1057,7 +1057,7 @@ dependencies = [ [[package]] name = "datafusion-row" version = "16.0.0" -source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=1d8c71a4703d27564c1a2bd60b164769f33fbe8f#1d8c71a4703d27564c1a2bd60b164769f33fbe8f" +source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=e3f156f4acc51fc45a7aa6a99085b091f539d5fa#e3f156f4acc51fc45a7aa6a99085b091f539d5fa" dependencies = [ "arrow", "datafusion-common", @@ -1068,7 +1068,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "16.0.0" -source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=1d8c71a4703d27564c1a2bd60b164769f33fbe8f#1d8c71a4703d27564c1a2bd60b164769f33fbe8f" +source = "git+https://github.com/jonmmease/arrow-datafusion.git?rev=e3f156f4acc51fc45a7aa6a99085b091f539d5fa#e3f156f4acc51fc45a7aa6a99085b091f539d5fa" dependencies = [ "arrow-schema", "datafusion-common", diff --git a/Cargo.toml b/Cargo.toml index 4d350b72f..2c79db270 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,6 @@ arrow-array = { git = "https://github.com/jonmmease/arrow-rs.git", rev = "59a528 arrow-select = { git = "https://github.com/jonmmease/arrow-rs.git", rev = "59a5289b3603db08433e5309eed488a0abbf5d0d"} # DataFusion 16.0 with backports -datafusion = { git = "https://github.com/jonmmease/arrow-datafusion.git", rev = "1d8c71a4703d27564c1a2bd60b164769f33fbe8f"} -datafusion-common = { git = "https://github.com/jonmmease/arrow-datafusion.git", rev = "1d8c71a4703d27564c1a2bd60b164769f33fbe8f"} -datafusion-expr = { git = "https://github.com/jonmmease/arrow-datafusion.git", rev = "1d8c71a4703d27564c1a2bd60b164769f33fbe8f"} +datafusion = { git = 
"https://github.com/jonmmease/arrow-datafusion.git", rev = "e3f156f4acc51fc45a7aa6a99085b091f539d5fa"} +datafusion-common = { git = "https://github.com/jonmmease/arrow-datafusion.git", rev = "e3f156f4acc51fc45a7aa6a99085b091f539d5fa"} +datafusion-expr = { git = "https://github.com/jonmmease/arrow-datafusion.git", rev = "e3f156f4acc51fc45a7aa6a99085b091f539d5fa"} diff --git a/python/vegafusion/vegafusion/transformer.py b/python/vegafusion/vegafusion/transformer.py index 5dd2c1ee8..2e2c2ab32 100644 --- a/python/vegafusion/vegafusion/transformer.py +++ b/python/vegafusion/vegafusion/transformer.py @@ -67,8 +67,9 @@ def to_arrow_table(data): except IndexError: pass - # Convert DataFrame to table - table = pa.Table.from_pandas(data) + # Convert DataFrame to table. Keep index only if named + preserve_index = bool([name for name in getattr(data.index, "names", []) if name]) + table = pa.Table.from_pandas(data, preserve_index=preserve_index) return table diff --git a/vegafusion-core/src/data/mod.rs b/vegafusion-core/src/data/mod.rs index 96bdac2a4..805ae6128 100644 --- a/vegafusion-core/src/data/mod.rs +++ b/vegafusion-core/src/data/mod.rs @@ -1,4 +1,9 @@ +use arrow::datatypes::DataType; + pub mod json_writer; pub mod scalar; pub mod table; pub mod tasks; + +pub const ORDER_COL: &str = "_vf_order"; +pub const ORDER_COL_DTYPE: DataType = DataType::UInt32; diff --git a/vegafusion-core/src/data/table.rs b/vegafusion-core/src/data/table.rs index b21dca836..686c60366 100644 --- a/vegafusion-core/src/data/table.rs +++ b/vegafusion-core/src/data/table.rs @@ -19,8 +19,9 @@ use super::scalar::ScalarValue; use crate::arrow::array::ArrayRef; use crate::data::json_writer::record_batches_to_json_rows; -use arrow::array::StructArray; -use arrow::datatypes::Field; +use crate::data::{ORDER_COL, ORDER_COL_DTYPE}; +use arrow::array::{StructArray, UInt32Array}; +use arrow::datatypes::{Field, Schema}; use arrow::json::reader::DecoderOptions; use serde_json::{json, Value}; @@ -77,6 +78,50 @@ 
impl VegaFusionTable { } } + pub fn with_ordering(self) -> Result { + // Build new schema with leading ORDER_COL + let mut new_fields = self.schema.fields.clone(); + let mut start_idx = 0; + let leading_field = new_fields + .get(0) + .expect("VegaFusionTable must have at least one column"); + let has_order_col = if leading_field.name() == ORDER_COL { + // There is already a leading ORDER_COL, remove it and replace below + new_fields.remove(0); + true + } else { + // We need to add a new leading field for the ORDER_COL + false + }; + new_fields.insert(0, Field::new(ORDER_COL, ORDER_COL_DTYPE, false)); + + let new_schema = Arc::new(Schema::new(new_fields)) as SchemaRef; + + let new_batches = self + .batches + .into_iter() + .map(|batch| { + let order_array = Arc::new(UInt32Array::from_iter_values( + start_idx..(start_idx + batch.num_rows() as u32), + )) as ArrayRef; + + let mut new_columns = Vec::from(batch.columns()); + + if has_order_col { + new_columns[0] = order_array; + } else { + new_columns.insert(0, order_array); + } + + start_idx += batch.num_rows() as u32; + + Ok(RecordBatch::try_new(new_schema.clone(), new_columns)?) 
+ }) + .collect::>>()?; + + Self::try_new(new_schema, new_batches) + } + pub fn batches(&self) -> &Vec { &self.batches } @@ -235,3 +280,53 @@ impl Hash for VegaFusionTable { self.to_ipc_bytes().unwrap().hash(state) } } + +#[cfg(test)] +mod tests { + use crate::data::table::VegaFusionTable; + use serde_json::json; + + #[test] + fn test_with_ordering() { + let table1 = VegaFusionTable::from_json( + &json!([ + {"a": 1, "b": "A"}, + {"a": 2, "b": "BB"}, + {"a": 10, "b": "CCC"}, + {"a": 20, "b": "DDDD"}, + ]), + 2, + ) + .unwrap(); + assert_eq!(table1.batches.len(), 2); + + let table2 = VegaFusionTable::from_json( + &json!([ + {"_vf_order": 10u32, "a": 1, "b": "A"}, + {"_vf_order": 9u32, "a": 2, "b": "BB"}, + {"_vf_order": 8u32, "a": 10, "b": "CCC"}, + {"_vf_order": 7u32, "a": 20, "b": "DDDD"}, + ]), + 2, + ) + .unwrap(); + assert_eq!(table2.batches.len(), 2); + + let expected_json = json!([ + {"_vf_order": 0u32, "a": 1, "b": "A"}, + {"_vf_order": 1u32, "a": 2, "b": "BB"}, + {"_vf_order": 2u32, "a": 10, "b": "CCC"}, + {"_vf_order": 3u32, "a": 20, "b": "DDDD"}, + ]); + + // Add ordering column to table without one + let result_table1 = table1.with_ordering().unwrap(); + assert_eq!(result_table1.batches.len(), 2); + assert_eq!(result_table1.to_json().unwrap(), expected_json); + + // Override prior ordering column + let result_table2 = table2.with_ordering().unwrap(); + assert_eq!(result_table2.batches.len(), 2); + assert_eq!(result_table2.to_json().unwrap(), expected_json); + } +} diff --git a/vegafusion-core/src/spec/transform/impute.rs b/vegafusion-core/src/spec/transform/impute.rs index f9139fc01..22a5916ed 100644 --- a/vegafusion-core/src/spec/transform/impute.rs +++ b/vegafusion-core/src/spec/transform/impute.rs @@ -8,6 +8,10 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; +fn default_value() -> Option { + Some(Value::Number(serde_json::Number::from(0))) +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 
pub struct ImputeTransformSpec { pub field: Field, @@ -22,7 +26,8 @@ pub struct ImputeTransformSpec { #[serde(skip_serializing_if = "Option::is_none")] pub groupby: Option>, - #[serde(skip_serializing_if = "Option::is_none")] + // Default to zero but serialize even if null + #[serde(default = "default_value")] pub value: Option, #[serde(flatten)] @@ -63,7 +68,6 @@ impl TransformSpecTrait for ImputeTransformSpec { && self.keyvals.is_none() && self.method() == ImputeMethodSpec::Value && num_unique_groupby <= 1 - && self.value.is_some() } fn transform_columns( diff --git a/vegafusion-rt-datafusion/src/data/tasks.rs b/vegafusion-rt-datafusion/src/data/tasks.rs index dedcba197..dffa02ec1 100644 --- a/vegafusion-rt-datafusion/src/data/tasks.rs +++ b/vegafusion-rt-datafusion/src/data/tasks.rs @@ -30,7 +30,7 @@ use crate::expression::escape::flat_col; use crate::sql::connection::datafusion_conn::DataFusionConnection; use crate::sql::dataframe::SqlDataFrame; use crate::task_graph::timezone::RuntimeTzConfig; -use crate::transform::pipeline::TransformPipelineUtils; +use crate::transform::pipeline::{remove_order_col, TransformPipelineUtils}; use vegafusion_core::data::scalar::{ScalarValue, ScalarValueHelpers}; use vegafusion_core::data::table::VegaFusionTable; use vegafusion_core::error::{Result, ResultWithContext, ToExternalError, VegaFusionError}; @@ -110,8 +110,13 @@ impl TaskCall for DataUrlTask { let inline_name = inline_name.trim().to_string(); if let Some(inline_dataset) = inline_datasets.get(&inline_name) { let sql_df = match inline_dataset { - VegaFusionDataset::Table { table, .. } => table.to_sql_dataframe().await?, - VegaFusionDataset::SqlDataFrame(sql_df) => sql_df.clone(), + VegaFusionDataset::Table { table, .. } => { + table.clone().with_ordering()?.to_sql_dataframe().await? 
+ } + VegaFusionDataset::SqlDataFrame(sql_df) => { + // TODO: if no ordering column present, create with a window expression + sql_df.clone() + } }; let sql_df = process_datetimes(&parse, sql_df, &config.tz_config).await?; return eval_sql_df(sql_df.clone(), &self.pipeline, &config).await; @@ -162,7 +167,8 @@ async fn eval_sql_df( let pipeline = pipeline.as_ref().unwrap(); pipeline.eval_sql(sql_df, config).await? } else { - // No transforms + // No transforms, just remove any ordering column + let sql_df = remove_order_col(sql_df).await?; (sql_df.collect().await?, Vec::new()) }; @@ -416,6 +422,9 @@ impl TaskCall for DataValuesTask { return Ok((TaskValue::Table(values_table), Default::default())); } + // Add ordering column + let values_table = values_table.with_ordering()?; + // Get parse format for date processing let parse = self.format_type.as_ref().and_then(|fmt| fmt.parse.clone()); @@ -469,6 +478,9 @@ impl TaskCall for DataSourceTask { ) }); + // Add ordering column + let source_table = source_table.with_ordering()?; + // Apply transforms (if any) let (transformed_table, output_values) = if self .pipeline @@ -531,12 +543,18 @@ async fn read_csv(url: String, parse: &Option) -> Result { // Load through VegaFusionTable so that temp file can be deleted let df = ctx.read_csv(path, csv_opts).await.unwrap(); let table = VegaFusionTable::from_dataframe(df).await.unwrap(); + let table = table.with_ordering()?; let df = table.to_dataframe().await.unwrap(); Ok(df) } else { let schema = build_csv_schema(&csv_opts, &url, parse).await?; let csv_opts = csv_opts.schema(&schema); - Ok(ctx.read_csv(url, csv_opts).await?) + + let df = ctx.read_csv(url, csv_opts).await.unwrap(); + let table = VegaFusionTable::from_dataframe(df).await.unwrap(); + let table = table.with_ordering()?; + let df = table.to_dataframe().await.unwrap(); + Ok(df) } } @@ -622,6 +640,7 @@ async fn read_json(url: &str, batch_size: usize) -> Result { }; VegaFusionTable::from_json(&value, batch_size)? 
+ .with_ordering()? .to_dataframe() .await } @@ -678,6 +697,7 @@ async fn read_arrow(url: &str) -> Result { }; VegaFusionTable::try_new(schema, batches)? + .with_ordering()? .to_dataframe() .await } diff --git a/vegafusion-rt-datafusion/src/transform/aggregate.rs b/vegafusion-rt-datafusion/src/transform/aggregate.rs index e432feb3e..b2b8988d0 100644 --- a/vegafusion-rt-datafusion/src/transform/aggregate.rs +++ b/vegafusion-rt-datafusion/src/transform/aggregate.rs @@ -9,13 +9,11 @@ use crate::expression::escape::{flat_col, unescaped_col}; use crate::sql::dataframe::SqlDataFrame; use async_trait::async_trait; use datafusion::common::{DFSchema, ScalarValue}; +use datafusion_expr::aggregate_function; use datafusion_expr::expr; -use datafusion_expr::{ - aggregate_function, BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunction, -}; use std::sync::Arc; use vegafusion_core::arrow::datatypes::DataType; +use vegafusion_core::data::ORDER_COL; use vegafusion_core::error::{Result, VegaFusionError}; use vegafusion_core::expression::escape::unescape_field; use vegafusion_core::proto::gen::transforms::{Aggregate, AggregateOp}; @@ -32,34 +30,11 @@ impl TransformTrait for Aggregate { let group_exprs: Vec<_> = self.groupby.iter().map(|c| unescaped_col(c)).collect(); let (mut agg_exprs, projections) = get_agg_and_proj_exprs(self, &dataframe.schema_df())?; - // Add __row_number column if groupby columns is not empty - let dataframe = if !self.groupby.is_empty() { - // Add row_number column that we can sort by - let row_number_expr = make_row_number_expr(); - - // Add min(__row_number) aggregation that we can sort by later - agg_exprs.push(min(flat_col("__row_number")).alias("__min_row_number")); - - dataframe - .select(vec![Expr::Wildcard, row_number_expr]) - .await? 
- } else { - dataframe - }; + // Append ordering column to aggregations + agg_exprs.push(min(flat_col(ORDER_COL)).alias(ORDER_COL)); // Perform aggregation - let mut grouped_dataframe = dataframe.aggregate(group_exprs, agg_exprs).await?; - - // Maybe sort by min row number - if !self.groupby.is_empty() { - // Sort groups according to the lowest row number of a value in that group - let sort_exprs = vec![Expr::Sort(expr::Sort { - expr: Box::new(flat_col("__min_row_number")), - asc: true, - nulls_first: false, - })]; - grouped_dataframe = grouped_dataframe.sort(sort_exprs, None).await?; - } + let grouped_dataframe = dataframe.aggregate(group_exprs, agg_exprs).await?; // Make final projection let grouped_dataframe = grouped_dataframe.select(projections).await?; @@ -68,21 +43,6 @@ impl TransformTrait for Aggregate { } } -pub fn make_row_number_expr() -> Expr { - Expr::WindowFunction(expr::WindowFunction { - fun: WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), - args: Vec::new(), - partition_by: Vec::new(), - order_by: Vec::new(), - window_frame: WindowFrame { - units: WindowFrameUnits::Rows, - start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), - end_bound: WindowFrameBound::CurrentRow, - }, - }) - .alias("__row_number") -} - fn get_agg_and_proj_exprs(tx: &Aggregate, schema: &DFSchema) -> Result<(Vec, Vec)> { // DataFusion does not allow repeated (field, op) combinations in an aggregate expression, // so if there are duplicates we need to use a projection after the aggregation to alias @@ -92,6 +52,9 @@ fn get_agg_and_proj_exprs(tx: &Aggregate, schema: &DFSchema) -> Result<(Vec = tx.groupby.iter().map(|f| unescaped_col(f)).collect(); + // Prepend ORDER_COL + projections.insert(0, flat_col(ORDER_COL)); + for (i, (field, op_code)) in tx.fields.iter().zip(tx.ops.iter()).enumerate() { let op = AggregateOp::from_i32(*op_code).unwrap(); diff --git a/vegafusion-rt-datafusion/src/transform/collect.rs 
b/vegafusion-rt-datafusion/src/transform/collect.rs index 711d18d9d..e5fddda69 100644 --- a/vegafusion-rt-datafusion/src/transform/collect.rs +++ b/vegafusion-rt-datafusion/src/transform/collect.rs @@ -7,9 +7,14 @@ use std::sync::Arc; use vegafusion_core::error::{Result, ResultWithContext}; use vegafusion_core::proto::gen::transforms::{Collect, SortOrder}; -use crate::expression::escape::unescaped_col; +use crate::expression::escape::{flat_col, unescaped_col}; use crate::sql::dataframe::SqlDataFrame; use async_trait::async_trait; +use datafusion::common::ScalarValue; +use datafusion_expr::{ + window_function, BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, +}; +use vegafusion_core::data::ORDER_COL; use vegafusion_core::task_graph::task_value::TaskValue; #[async_trait] @@ -19,6 +24,7 @@ impl TransformTrait for Collect { dataframe: Arc, _config: &CompilationConfig, ) -> Result<(Arc, Vec)> { + // Build vector of sort expressions let sort_exprs: Vec<_> = self .fields .clone() @@ -33,8 +39,41 @@ impl TransformTrait for Collect { }) .collect(); + // We don't actually sort here, use a row number window function sorted by the sort + // criteria. This column becomes the new ORDER_COL, which will be sorted at the end of + // the pipeline. 
+ let order_col = Expr::WindowFunction(expr::WindowFunction { + fun: window_function::WindowFunction::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), + args: vec![], + partition_by: vec![], + order_by: sort_exprs, + window_frame: WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }, + }) + .alias(ORDER_COL); + + // Build vector of selections + let mut selections = dataframe + .schema() + .fields + .iter() + .filter_map(|field| { + if field.name() == ORDER_COL { + None + } else { + Some(flat_col(field.name())) + } + }) + .collect::>(); + selections.insert(0, order_col); + let result = dataframe - .sort(sort_exprs, None) + .select(selections) .await .with_context(|| "Collect transform failed".to_string())?; Ok((result, Default::default())) diff --git a/vegafusion-rt-datafusion/src/transform/identifier.rs b/vegafusion-rt-datafusion/src/transform/identifier.rs index b54781aa1..7cb045664 100644 --- a/vegafusion-rt-datafusion/src/transform/identifier.rs +++ b/vegafusion-rt-datafusion/src/transform/identifier.rs @@ -1,6 +1,7 @@ use crate::expression::compiler::config::CompilationConfig; use crate::transform::TransformTrait; +use crate::expression::escape::flat_col; use crate::sql::dataframe::SqlDataFrame; use async_trait::async_trait; use datafusion_expr::{ @@ -9,6 +10,7 @@ use datafusion_expr::{ }; use std::sync::Arc; use vegafusion_core::data::scalar::ScalarValue; +use vegafusion_core::data::ORDER_COL; use vegafusion_core::error::Result; use vegafusion_core::proto::gen::transforms::Identifier; use vegafusion_core::task_graph::task_value::TaskValue; @@ -20,12 +22,16 @@ impl TransformTrait for Identifier { dataframe: Arc, _config: &CompilationConfig, ) -> Result<(Arc, Vec)> { - // Add row number column with the desired name + // Add row number column with the desired name, sorted by the input order column let row_number_expr = 
Expr::WindowFunction(expr::WindowFunction { fun: WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), args: Vec::new(), partition_by: Vec::new(), - order_by: Vec::new(), + order_by: vec![Expr::Sort(expr::Sort { + expr: Box::new(flat_col(ORDER_COL)), + asc: true, + nulls_first: false, + })], window_frame: WindowFrame { units: WindowFrameUnits::Rows, start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), diff --git a/vegafusion-rt-datafusion/src/transform/impute.rs b/vegafusion-rt-datafusion/src/transform/impute.rs index 243373138..8eef2cf01 100644 --- a/vegafusion-rt-datafusion/src/transform/impute.rs +++ b/vegafusion-rt-datafusion/src/transform/impute.rs @@ -1,20 +1,22 @@ use crate::expression::compiler::config::CompilationConfig; use crate::expression::escape::{flat_col, unescaped_col}; -use crate::sql::compile::order::ToSqlOrderByExpr; use crate::sql::compile::select::ToSqlSelectItem; use crate::sql::dataframe::SqlDataFrame; -use crate::transform::aggregate::make_row_number_expr; use crate::transform::TransformTrait; use async_trait::async_trait; use datafusion::common::ScalarValue; use datafusion_expr::expr::Cast; -use datafusion_expr::{expr, lit, when, Expr}; +use datafusion_expr::{ + expr, lit, when, window_function, BuiltInWindowFunction, Expr, WindowFrame, WindowFrameBound, + WindowFrameUnits, +}; use itertools::Itertools; use sqlgen::dialect::DialectDisplay; use std::sync::Arc; use vegafusion_core::arrow::datatypes::DataType; use vegafusion_core::data::scalar::ScalarValueHelpers; -use vegafusion_core::error::{Result, VegaFusionError}; +use vegafusion_core::data::ORDER_COL; +use vegafusion_core::error::{Result, ResultWithContext, VegaFusionError}; use vegafusion_core::proto::gen::transforms::Impute; use vegafusion_core::task_graph::task_value::TaskValue; @@ -26,13 +28,19 @@ impl TransformTrait for Impute { _config: &CompilationConfig, ) -> Result<(Arc, Vec)> { // Create ScalarValue used to fill in null values - let 
json_value: serde_json::Value = - serde_json::from_str(self.value_json.as_ref().unwrap())?; + let json_value: serde_json::Value = serde_json::from_str( + &self + .value_json + .clone() + .unwrap_or_else(|| "null".to_string()), + )?; // JSON numbers are always interpreted as floats, but if the value is an integer we'd // like the fill value to be an integer as well to avoid converting an integer input // column to floats - let value = if json_value.is_i64() { + let value = if json_value.is_null() { + ScalarValue::Float64(None) + } else if json_value.is_i64() { ScalarValue::from(json_value.as_i64().unwrap()) } else if json_value.is_f64() && json_value.as_f64().unwrap().fract() == 0.0 { ScalarValue::from(json_value.as_f64().unwrap() as i64) @@ -112,18 +120,6 @@ async fn single_groupby_sql( let group_col = unescaped_col(groupby); let group_col_str = group_col.to_sql_select()?.sql(dataframe.dialect())?; - // Build row number expr to apply to input table - let row_number_expr = make_row_number_expr(); - let row_number_expr_str = row_number_expr.to_sql_select()?.sql(dataframe.dialect())?; - - // Build order by - let order_by_expr = Expr::Sort(expr::Sort { - expr: Box::new(flat_col("__row_number")), - asc: true, - nulls_first: false, - }); - let order_by_expr_str = order_by_expr.to_sql_order()?.sql(dataframe.dialect())?; - // Build final selection // Finally, select all of the original DataFrame columns, filling in missing values // of the `field` columns @@ -167,15 +163,53 @@ async fn single_groupby_sql( let dataframe = dataframe.chain_query_str(&format!( "SELECT {select_column_csv} from (SELECT DISTINCT {key} from {parent} WHERE {key} IS NOT NULL) AS _key \ CROSS JOIN (SELECT DISTINCT {group} from {parent} WHERE {group} IS NOT NULL) AS _group \ - LEFT OUTER JOIN (SELECT *, {row_number_expr_str} from {parent}) AS _inner USING ({key}, {group}) \ - ORDER BY {order_by_expr_str}", + LEFT OUTER JOIN {parent} \ + USING ({key}, {group})", select_column_csv = select_column_csv, 
key = key_col_str, group = group_col_str, - row_number_expr_str = row_number_expr_str, - order_by_expr_str = order_by_expr_str, parent = dataframe.parent_name(), )).await?; - Ok(dataframe) + // Override ordering column since null values may have been introduced in the query above. + // Match input ordering with imputed rows (those with null ordering column) pushed + to the end. + let order_col = Expr::WindowFunction(expr::WindowFunction { + fun: window_function::WindowFunction::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), + args: vec![], + partition_by: vec![], + order_by: vec![Expr::Sort(expr::Sort { + expr: Box::new(flat_col(ORDER_COL)), + asc: true, + nulls_first: false, + })], + window_frame: WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }, + }) + .alias(ORDER_COL); + + // Build vector of selections + let mut selections = dataframe + .schema() + .fields + .iter() + .filter_map(|field| { + if field.name() == ORDER_COL { + None + } else { + Some(flat_col(field.name())) + } + }) + .collect::>(); + selections.insert(0, order_col); + + dataframe + .select(selections) + .await + .with_context(|| "Impute transform failed".to_string()) } diff --git a/vegafusion-rt-datafusion/src/transform/joinaggregate.rs b/vegafusion-rt-datafusion/src/transform/joinaggregate.rs index 7fc12445e..64cb19d05 100644 --- a/vegafusion-rt-datafusion/src/transform/joinaggregate.rs +++ b/vegafusion-rt-datafusion/src/transform/joinaggregate.rs @@ -7,7 +7,7 @@ use datafusion::logical_expr::Expr; use crate::sql::compile::expr::ToSqlExpr; use crate::sql::compile::select::ToSqlSelectItem; use crate::sql::dataframe::SqlDataFrame; -use crate::transform::aggregate::{make_aggr_expr, make_row_number_expr}; +use crate::transform::aggregate::make_aggr_expr; use async_trait::async_trait; use datafusion::common::Column; use sqlgen::dialect::DialectDisplay; @@ -95,10 +95,6
@@ impl TransformTrait for JoinAggregate { .collect::>>()?; let input_col_csv = input_col_strs.join(", "); - // Build row_number select expression - let row_number_expr = make_row_number_expr(); - let row_number_str = row_number_expr.to_sql_select()?.sql(dataframe.dialect())?; - // Perform join aggregation let sql_group_expr_strs = group_exprs .iter() @@ -115,13 +111,11 @@ impl TransformTrait for JoinAggregate { dataframe .chain_query_str(&format!( "select {input_col_csv}, {new_col_csv} \ - from (select *, {row_number_str} from {parent}) \ - CROSS JOIN (select {aggr_csv} from {parent}) as {inner_name} \ - ORDER BY __row_number", + from {parent} \ + CROSS JOIN (select {aggr_csv} from {parent}) as {inner_name}", aggr_csv = aggr_csv, parent = dataframe.parent_name(), input_col_csv = input_col_csv, - row_number_str = row_number_str, new_col_csv = new_col_csv, inner_name = inner_name, )) @@ -130,13 +124,11 @@ impl TransformTrait for JoinAggregate { let group_by_csv = sql_group_expr_strs.join(", "); dataframe.chain_query_str(&format!( "select {input_col_csv}, {new_col_csv} \ - from (select *, {row_number_str} from {parent}) \ - LEFT OUTER JOIN (select {aggr_csv}, {group_by_csv} from {parent} group by {group_by_csv}) as {inner_name} USING ({group_by_csv}) \ - ORDER BY __row_number", + from {parent} \ + LEFT OUTER JOIN (select {aggr_csv}, {group_by_csv} from {parent} group by {group_by_csv}) as {inner_name} USING ({group_by_csv})", aggr_csv = aggr_csv, parent = dataframe.parent_name(), input_col_csv = input_col_csv, - row_number_str = row_number_str, new_col_csv = new_col_csv, group_by_csv = group_by_csv, inner_name = inner_name, diff --git a/vegafusion-rt-datafusion/src/transform/pipeline.rs b/vegafusion-rt-datafusion/src/transform/pipeline.rs index 52011fa2e..7e36b3f0e 100644 --- a/vegafusion-rt-datafusion/src/transform/pipeline.rs +++ b/vegafusion-rt-datafusion/src/transform/pipeline.rs @@ -5,11 +5,14 @@ use itertools::Itertools; use std::collections::HashMap; use 
std::sync::Arc; -use vegafusion_core::error::Result; +use vegafusion_core::error::{Result, VegaFusionError}; +use crate::expression::escape::flat_col; use crate::sql::dataframe::SqlDataFrame; use async_trait::async_trait; +use datafusion_expr::{expr, Expr}; use vegafusion_core::data::table::VegaFusionTable; +use vegafusion_core::data::ORDER_COL; use vegafusion_core::proto::gen::tasks::{Variable, VariableNamespace}; use vegafusion_core::proto::gen::transforms::TransformPipeline; use vegafusion_core::task_graph::task_value::TaskValue; @@ -35,6 +38,13 @@ impl TransformPipelineUtils for TransformPipeline { let mut result_outputs: HashMap = Default::default(); let mut config = config.clone(); + if result_sql_df.schema().column_with_name(ORDER_COL).is_none() { + return Err(VegaFusionError::internal(format!( + "DataFrame input to eval_sql does not have the expected {} ordering column", + ORDER_COL + ))); + } + // Helper function to add variable value to config let add_output_var_to_config = |config: &mut CompilationConfig, var: &Variable, val: TaskValue| -> Result<()> { @@ -62,6 +72,12 @@ impl TransformPipelineUtils for TransformPipeline { result_sql_df = tx_result.0; + if result_sql_df.schema().column_with_name(ORDER_COL).is_none() { + return Err(VegaFusionError::internal( + format!("DataFrame output of transform does not have the expected {} ordering column: {:?}", ORDER_COL, tx) + )); + } + // Collect output variables for (var, val) in tx.output_vars().iter().zip(tx_result.1) { result_outputs.insert(var.clone(), val.clone()); @@ -72,6 +88,21 @@ impl TransformPipelineUtils for TransformPipeline { } } + // Sort by ordering column at the end + result_sql_df = result_sql_df + .sort( + vec![Expr::Sort(expr::Sort { + expr: Box::new(flat_col(ORDER_COL)), + asc: true, + nulls_first: false, + })], + None, + ) + .await?; + + // Remove ordering column + result_sql_df = remove_order_col(result_sql_df).await?; + let table = result_sql_df.collect().await?; // Sort result signal 
value by signal name @@ -83,3 +114,19 @@ impl TransformPipelineUtils for TransformPipeline { Ok((table, signals_values)) } } + +pub async fn remove_order_col(result_sql_df: Arc) -> Result> { + let selection = result_sql_df + .schema() + .fields + .iter() + .filter_map(|field| { + if field.name() == ORDER_COL { + None + } else { + Some(flat_col(field.name())) + } + }) + .collect::>(); + result_sql_df.select(selection).await +} diff --git a/vegafusion-rt-datafusion/src/transform/pivot.rs b/vegafusion-rt-datafusion/src/transform/pivot.rs index 1b59807aa..d93281032 100644 --- a/vegafusion-rt-datafusion/src/transform/pivot.rs +++ b/vegafusion-rt-datafusion/src/transform/pivot.rs @@ -4,7 +4,7 @@ use crate::expression::escape::{flat_col, unescaped_col}; use crate::sql::compile::expr::ToSqlExpr; use crate::sql::compile::select::ToSqlSelectItem; use crate::sql::dataframe::SqlDataFrame; -use crate::transform::aggregate::{make_aggr_expr, make_row_number_expr}; +use crate::transform::aggregate::make_aggr_expr; use crate::transform::utils::RecordBatchUtils; use crate::transform::TransformTrait; use async_trait::async_trait; @@ -14,6 +14,7 @@ use sqlgen::dialect::DialectDisplay; use std::sync::Arc; use vegafusion_core::arrow::array::StringArray; use vegafusion_core::arrow::datatypes::DataType; +use vegafusion_core::data::ORDER_COL; use vegafusion_core::error::{Result, ResultWithContext, VegaFusionError}; use vegafusion_core::proto::gen::transforms::{AggregateOp, Pivot}; use vegafusion_core::task_graph::task_value::TaskValue; @@ -168,6 +169,9 @@ async fn pivot_without_grouping( final_selections.push(select_expr) } + // Query will result in a single row, so add a constant valued ORDER_COL + final_selections.insert(0, lit(0u32).alias(ORDER_COL)); + // Build final query let final_selection_strs: Vec<_> = final_selections .iter() @@ -207,12 +211,6 @@ async fn pivot_with_grouping( return Err(VegaFusionError::internal("Unexpected empty pivot dataset")); } - // Add row_index column that 
we can sort by later - let row_number_expr = make_row_number_expr(); - let dataframe = dataframe - .select(vec![Expr::Wildcard, row_number_expr]) - .await?; - // Process aggregate operation let agg_op: AggregateOp = tx .op @@ -234,10 +232,11 @@ async fn pivot_with_grouping( .map(|col| col.to_sql().unwrap().sql(dialect).unwrap()) .collect(); let groupby_csv = groupby_strs.join(", "); + let grouped_dataframe = dataframe .aggregate( groupby_cols, - vec![min(flat_col("__row_number")).alias("__min_row_number")], + vec![min(flat_col(ORDER_COL)).alias(ORDER_COL)], ) .await?; @@ -247,6 +246,7 @@ async fn pivot_with_grouping( // Initialize vector of final selections let mut final_selections: Vec<_> = tx.groupby.iter().map(|c| unescaped_col(c)).collect(); + final_selections.insert(0, flat_col(ORDER_COL)); // Initialize empty query string let mut query_str = String::new(); @@ -302,9 +302,6 @@ async fn pivot_with_grouping( ), ); - // Append ordering - query_str.push_str("ORDER BY __min_row_number"); - // Perform query and apply final selections let dataframe_joined = grouped_dataframe.chain_query_str(&query_str).await?; let selected = dataframe_joined.select(final_selections).await?; diff --git a/vegafusion-rt-datafusion/src/transform/project.rs b/vegafusion-rt-datafusion/src/transform/project.rs index 8f19d1f40..c3ac42bfd 100644 --- a/vegafusion-rt-datafusion/src/transform/project.rs +++ b/vegafusion-rt-datafusion/src/transform/project.rs @@ -10,6 +10,7 @@ use vegafusion_core::proto::gen::transforms::Project; use crate::expression::escape::flat_col; use crate::sql::dataframe::SqlDataFrame; use async_trait::async_trait; +use vegafusion_core::data::ORDER_COL; use vegafusion_core::expression::escape::unescape_field; use vegafusion_core::task_graph::task_value::TaskValue; @@ -30,7 +31,7 @@ impl TransformTrait for Project { // Keep all of the project columns that are present in the dataframe. 
// Skip projection fields that are not found - let select_fields: Vec<_> = self + let mut select_fields: Vec<_> = self .fields .iter() .filter_map(|field| { @@ -43,6 +44,9 @@ impl TransformTrait for Project { }) .collect(); + // Always keep ordering column + select_fields.insert(0, ORDER_COL.to_string()); + let select_col_exprs: Vec<_> = select_fields.iter().map(|f| flat_col(f)).collect(); let result = dataframe.select(select_col_exprs).await?; Ok((result, Default::default())) diff --git a/vegafusion-rt-datafusion/src/transform/stack.rs b/vegafusion-rt-datafusion/src/transform/stack.rs index 2ad31c6b8..c3680c069 100644 --- a/vegafusion-rt-datafusion/src/transform/stack.rs +++ b/vegafusion-rt-datafusion/src/transform/stack.rs @@ -13,10 +13,10 @@ use datafusion_expr::{ use sqlgen::dialect::DialectDisplay; use crate::expression::escape::{flat_col, unescaped_col}; -use crate::transform::aggregate::make_row_number_expr; use std::ops::{Add, Div, Sub}; use std::sync::Arc; use vegafusion_core::data::scalar::ScalarValue; +use vegafusion_core::data::ORDER_COL; use vegafusion_core::error::{Result, VegaFusionError}; use vegafusion_core::proto::gen::transforms::{SortOrder, Stack, StackOffset}; use vegafusion_core::task_graph::task_value::TaskValue; @@ -40,7 +40,6 @@ impl TransformTrait for Stack { .collect(); // Build order by vector - // Order by row number last (and only if no explicit ordering provided) let mut order_by: Vec<_> = self .sort_fields .iter() @@ -54,19 +53,13 @@ impl TransformTrait for Stack { }) .collect(); + // Order by input row ordering last order_by.push(Expr::Sort(expr::Sort { - expr: Box::new(flat_col("__row_number")), + expr: Box::new(flat_col(ORDER_COL)), asc: true, nulls_first: true, })); - // Add row number column for sorting - let row_number_expr = make_row_number_expr(); - - let dataframe = dataframe - .select(vec![Expr::Wildcard, row_number_expr]) - .await?; - // Process according to offset let offset = 
StackOffset::from_i32(self.offset).expect("Failed to convert stack offset"); let dataframe = match offset { @@ -249,18 +242,6 @@ async fn eval_normalize_center_offset( _ => return Err(VegaFusionError::internal("Unexpected stack offset")), }; - // Restore original order - let dataframe = dataframe - .sort( - vec![Expr::Sort(expr::Sort { - expr: Box::new(flat_col("__row_number")), - asc: true, - nulls_first: false, - })], - None, - ) - .await?; - let dataframe = dataframe.select(final_selection.clone()).await?; Ok(dataframe) } @@ -318,18 +299,6 @@ async fn eval_zero_offset( )) .await?; - // Restore original order - let dataframe = dataframe - .sort( - vec![Expr::Sort(expr::Sort { - expr: Box::new(flat_col("__row_number")), - asc: true, - nulls_first: false, - })], - None, - ) - .await?; - // Build final selection let mut final_selection: Vec<_> = input_fields .iter() diff --git a/vegafusion-rt-datafusion/src/transform/window.rs b/vegafusion-rt-datafusion/src/transform/window.rs index 234402d9f..cde966004 100644 --- a/vegafusion-rt-datafusion/src/transform/window.rs +++ b/vegafusion-rt-datafusion/src/transform/window.rs @@ -15,11 +15,11 @@ use vegafusion_core::task_graph::task_value::TaskValue; use crate::expression::compiler::utils::to_numeric; use crate::expression::escape::{flat_col, unescaped_col}; use crate::sql::dataframe::SqlDataFrame; -use crate::transform::aggregate::make_row_number_expr; use datafusion::physical_plan::aggregates; use datafusion_expr::{ window_frame, BuiltInWindowFunction, WindowFrameBound, WindowFrameUnits, WindowFunction, }; +use vegafusion_core::data::ORDER_COL; #[async_trait] impl TransformTrait for Window { @@ -48,20 +48,13 @@ impl TransformTrait for Window { .map(|f| flat_col(f.field().name())) .collect(); - let dataframe = if order_by.is_empty() { - // If no order by fields provided, use the row number - let row_number_expr = make_row_number_expr(); - + if order_by.is_empty() { + // Order by input row if no ordering specified 
order_by.push(Expr::Sort(expr::Sort { - expr: Box::new(flat_col("__row_number")), + expr: Box::new(flat_col(ORDER_COL)), asc: true, nulls_first: true, })); - dataframe - .select(vec![Expr::Wildcard, row_number_expr]) - .await? - } else { - dataframe }; let partition_by: Vec<_> = self @@ -180,7 +173,6 @@ impl TransformTrait for Window { .collect(); // Add window expressions to original selections - // This will exclude the __row_number column if it was added above. selections.extend(window_exprs); let dataframe = dataframe.select(selections).await?; diff --git a/vegafusion-rt-datafusion/tests/specs/custom/area_streamgraph.comm_plan.json b/vegafusion-rt-datafusion/tests/specs/custom/area_streamgraph.comm_plan.json new file mode 100644 index 000000000..2a58d3181 --- /dev/null +++ b/vegafusion-rt-datafusion/tests/specs/custom/area_streamgraph.comm_plan.json @@ -0,0 +1,15 @@ +{ + "server_to_client": [ + { + "name": "source_0", + "namespace": "data", + "scope": [] + }, + { + "name": "source_0_color_domain_series", + "namespace": "data", + "scope": [] + } + ], + "client_to_server": [] +} diff --git a/vegafusion-rt-datafusion/tests/specs/custom/area_streamgraph.vg.json b/vegafusion-rt-datafusion/tests/specs/custom/area_streamgraph.vg.json new file mode 100644 index 000000000..d797f3ab3 --- /dev/null +++ b/vegafusion-rt-datafusion/tests/specs/custom/area_streamgraph.vg.json @@ -0,0 +1,240 @@ +{ + "$schema": "https://vega.github.io/schema/vega/v5.json", + "background": "white", + "padding": 5, + "width": 400, + "height": 300, + "style": "cell", + "data": [ + { + "name": "source_0", + "url": "https://cdn.jsdelivr.net/npm/vega-datasets@v1.29.0/data/unemployment-across-industries.json", + "format": { + "type": "json", + "parse": { + "date": "date" + } + }, + "transform": [ + { + "field": "date", + "type": "timeunit", + "units": [ + "year", + "month" + ], + "as": [ + "yearmonth_date", + "yearmonth_date_end" + ] + }, + { + "type": "aggregate", + "groupby": [ + "series", + 
"yearmonth_date" + ], + "ops": [ + "sum" + ], + "fields": [ + "count" + ], + "as": [ + "sum_count" + ] + }, + { + "type": "impute", + "field": "sum_count", + "groupby": [ + "series" + ], + "key": "yearmonth_date", + "method": "value", + "value": 0 + }, + { + "type": "stack", + "groupby": [ + "yearmonth_date" + ], + "field": "sum_count", + "sort": { + "field": [ + "series" + ], + "order": [ + "descending" + ] + }, + "as": [ + "sum_count_start", + "sum_count_end" + ], + "offset": "center" + } + ] + } + ], + "marks": [ + { + "name": "pathgroup", + "type": "group", + "from": { + "facet": { + "name": "faceted_path_main", + "data": "source_0", + "groupby": [ + "series" + ] + } + }, + "encode": { + "update": { + "width": { + "field": { + "group": "width" + } + }, + "height": { + "field": { + "group": "height" + } + } + } + }, + "marks": [ + { + "name": "marks", + "type": "area", + "style": [ + "area" + ], + "sort": { + "field": "datum[\"yearmonth_date\"]" + }, + "from": { + "data": "faceted_path_main" + }, + "encode": { + "update": { + "orient": { + "value": "vertical" + }, + "fill": { + "scale": "color", + "field": "series" + }, + "description": { + "signal": "\"series: \" + (isValid(datum[\"series\"]) ? 
datum[\"series\"] : \"\"+datum[\"series\"]) + \"; date (year-month): \" + (timeFormat(datum[\"yearmonth_date\"], '%Y')) + \"; Sum of count: \" + (format(datum[\"sum_count\"], \"\"))" + }, + "x": { + "scale": "x", + "field": "yearmonth_date" + }, + "y": { + "scale": "y", + "field": "sum_count_end" + }, + "y2": { + "scale": "y", + "field": "sum_count_start" + }, + "defined": { + "signal": "isValid(datum[\"yearmonth_date\"]) && isFinite(+datum[\"yearmonth_date\"]) && isValid(datum[\"sum_count\"]) && isFinite(+datum[\"sum_count\"])" + } + } + } + } + ] + } + ], + "scales": [ + { + "name": "x", + "type": "time", + "domain": { + "data": "source_0", + "field": "yearmonth_date" + }, + "range": [ + 0, + { + "signal": "width" + } + ] + }, + { + "name": "y", + "type": "linear", + "domain": { + "data": "source_0", + "fields": [ + "sum_count_start", + "sum_count_end" + ] + }, + "range": [ + { + "signal": "height" + }, + 0 + ], + "nice": true, + "zero": true + }, + { + "name": "color", + "type": "ordinal", + "domain": { + "data": "source_0", + "field": "series", + "sort": true + }, + "range": { + "scheme": "category20b" + } + } + ], + "axes": [ + { + "scale": "x", + "orient": "bottom", + "gridScale": "y", + "grid": true, + "tickCount": { + "signal": "ceil(width/40)" + }, + "domain": false, + "labels": false, + "aria": false, + "maxExtent": 0, + "minExtent": 0, + "ticks": false, + "zindex": 0 + }, + { + "scale": "x", + "orient": "bottom", + "grid": false, + "title": "date (year-month)", + "domain": false, + "format": "%Y", + "tickSize": 0, + "labelFlush": true, + "labelOverlap": true, + "tickCount": { + "signal": "ceil(width/40)" + }, + "zindex": 0 + } + ], + "legends": [ + { + "fill": "color", + "symbolType": "circle", + "title": "series" + } + ] +} \ No newline at end of file diff --git a/vegafusion-rt-datafusion/tests/specs/vegalite/stacked_bar_h_order_custom.vg.json b/vegafusion-rt-datafusion/tests/specs/vegalite/stacked_bar_h_order_custom.vg.json index a221bb4d4..19d8890cf 
100644 --- a/vegafusion-rt-datafusion/tests/specs/vegalite/stacked_bar_h_order_custom.vg.json +++ b/vegafusion-rt-datafusion/tests/specs/vegalite/stacked_bar_h_order_custom.vg.json @@ -12,7 +12,7 @@ "transform": [ { "type": "formula", - "expr": "if(datum.site === 'University Farm', '0', if(datum.site === 'Grand Rapids', '1', datum.site))", + "expr": "if(datum.site === 'University Farm', 0, if(datum.site === 'Grand Rapids', 1, 2))", "as": "siteOrder" }, { diff --git a/vegafusion-rt-datafusion/tests/test_image_comparison.rs b/vegafusion-rt-datafusion/tests/test_image_comparison.rs index bff1aea0e..806eb3e70 100644 --- a/vegafusion-rt-datafusion/tests/test_image_comparison.rs +++ b/vegafusion-rt-datafusion/tests/test_image_comparison.rs @@ -129,7 +129,8 @@ mod test_custom_specs { case("custom/datetime_scatter", 0.001, false), case("custom/stack_divide_by_zero_error", 0.001, false), case("custom/casestudy-us_population_pyramid_over_time", 0.001, true), - case("custom/sin_cos", 0.001, true) + case("custom/sin_cos", 0.001, true), + case("custom/area_streamgraph", 0.001, true) )] fn test_image_comparison(spec_name: &str, tolerance: f64, extract_inline_values: bool) { println!("spec_name: {}", spec_name); diff --git a/vegafusion-rt-datafusion/tests/test_transform_aggregate.rs b/vegafusion-rt-datafusion/tests/test_transform_aggregate.rs index c20c8f733..e3dc7a5e9 100644 --- a/vegafusion-rt-datafusion/tests/test_transform_aggregate.rs +++ b/vegafusion-rt-datafusion/tests/test_transform_aggregate.rs @@ -54,7 +54,7 @@ mod test_aggregate_single { eval_vegafusion_transforms(&dataset, transform_specs.as_slice(), &comp_config); } else { let eq_config = TablesEqualConfig { - row_order: false, + row_order: true, ..Default::default() }; @@ -115,7 +115,7 @@ mod test_aggregate_multi { } else { // Order of grouped rows is not defined, so set row_order to false let eq_config = TablesEqualConfig { - row_order: false, + row_order: true, ..Default::default() }; @@ -175,7 +175,7 @@ fn 
test_bin_aggregate() { let comp_config = Default::default(); let eq_config = TablesEqualConfig { - row_order: false, + row_order: true, ..Default::default() }; @@ -218,7 +218,7 @@ fn test_aggregate_overwrite() { // Order of grouped rows is not defined, so set row_order to false let eq_config = TablesEqualConfig { - row_order: false, + row_order: true, ..Default::default() }; @@ -279,7 +279,7 @@ mod test_aggregate_with_nulls { eval_vegafusion_transforms(&dataset, transform_specs.as_slice(), &comp_config); } else { let eq_config = TablesEqualConfig { - row_order: false, + row_order: true, ..Default::default() }; diff --git a/vegafusion-rt-datafusion/tests/test_transform_identifier.rs b/vegafusion-rt-datafusion/tests/test_transform_identifier.rs index 26559bf7b..00c2c420c 100644 --- a/vegafusion-rt-datafusion/tests/test_transform_identifier.rs +++ b/vegafusion-rt-datafusion/tests/test_transform_identifier.rs @@ -9,7 +9,7 @@ use vegafusion_core::spec::transform::identifier::IdentifierTransformSpec; use vegafusion_core::spec::transform::TransformSpec; #[test] -fn test_formula_valid() { +fn test_identifier() { let dataset = vega_json_dataset("penguins"); let tx_spec = IdentifierTransformSpec { as_: "id".to_string(), diff --git a/vegafusion-rt-datafusion/tests/test_transform_impute.rs b/vegafusion-rt-datafusion/tests/test_transform_impute.rs index 81b0da4b6..bf6ef5d3b 100644 --- a/vegafusion-rt-datafusion/tests/test_transform_impute.rs +++ b/vegafusion-rt-datafusion/tests/test_transform_impute.rs @@ -90,4 +90,51 @@ mod test_impute { &eq_config, ); } + + #[test] + fn test_one_groupby_window_frame() { + let dataset = simple_dataset(); + + let transform_specs: Vec = serde_json::from_value(json!( + [ + {"type": "formula", "expr": "toNumber(datum[\"a\"])", "as": "a"}, + { + "type": "impute", + "field": "b", + "key": "a", + "method": "value", + "groupby": ["c"], + "value": null + }, + { + "type": "window", + "as": ["imputed_b_value"], + "ops": ["mean"], + "fields": ["b"], + 
"frame": [-2, 2], + "ignorePeers": false, + "groupby": ["c"] + }, + { + "type": "formula", + "expr": "datum.b === null ? datum.imputed_b_value : datum.b", + "as": "b" + } + ] + )) + .unwrap(); + + let comp_config = Default::default(); + let eq_config = TablesEqualConfig { + row_order: true, + ..Default::default() + }; + + check_transform_evaluation( + &dataset, + transform_specs.as_slice(), + &comp_config, + &eq_config, + ); + } } diff --git a/vegafusion-rt-datafusion/tests/test_transform_window.rs b/vegafusion-rt-datafusion/tests/test_transform_window.rs index d022e19d5..a1986e53a 100644 --- a/vegafusion-rt-datafusion/tests/test_transform_window.rs +++ b/vegafusion-rt-datafusion/tests/test_transform_window.rs @@ -84,7 +84,7 @@ mod test_window_single { let comp_config = Default::default(); let eq_config = TablesEqualConfig { - row_order: false, + row_order: true, ..Default::default() }; diff --git a/vegafusion-rt-datafusion/tests/util/check.rs b/vegafusion-rt-datafusion/tests/util/check.rs index 97636e4e7..cd994b0f6 100644 --- a/vegafusion-rt-datafusion/tests/util/check.rs +++ b/vegafusion-rt-datafusion/tests/util/check.rs @@ -114,6 +114,8 @@ pub fn eval_vegafusion_transforms( transform_specs: &[TransformSpec], compilation_config: &CompilationConfig, ) -> (VegaFusionTable, Vec) { + // add ordering column + let data = data.clone().with_ordering().unwrap(); let pipeline = TransformPipeline::try_from(transform_specs).unwrap(); let sql_df = (*TOKIO_RUNTIME).block_on(data.to_sql_dataframe()).unwrap(); diff --git a/vegafusion-rt-datafusion/tests/util/equality.rs b/vegafusion-rt-datafusion/tests/util/equality.rs index 1545efbc6..cd7135d11 100644 --- a/vegafusion-rt-datafusion/tests/util/equality.rs +++ b/vegafusion-rt-datafusion/tests/util/equality.rs @@ -10,6 +10,7 @@ use datafusion::logical_expr::{expr, Expr}; use std::sync::Arc; use vegafusion_core::data::scalar::DATETIME_PREFIX; use vegafusion_core::data::table::VegaFusionTable; +use vegafusion_core::data::ORDER_COL; 
use vegafusion_rt_datafusion::data::table::VegaFusionTableUtils; use vegafusion_rt_datafusion::expression::compiler::utils::is_numeric_datatype; use vegafusion_rt_datafusion::expression::escape::flat_col; @@ -36,18 +37,30 @@ pub fn assert_tables_equal( rhs: &VegaFusionTable, config: &TablesEqualConfig, ) { - // Check column names + // Check column names (filtering out order col) let lhs_columns: HashSet<_> = lhs .schema .fields() .iter() - .map(|f| f.name().clone()) + .filter_map(|f| { + if f.name() == ORDER_COL { + None + } else { + Some(f.name().clone()) + } + }) .collect(); let rhs_columns: HashSet<_> = rhs .schema .fields() .iter() - .map(|f| f.name().clone()) + .filter_map(|f| { + if f.name() == ORDER_COL { + None + } else { + Some(f.name().clone()) + } + }) .collect(); assert_eq!( lhs_columns, rhs_columns, @@ -70,17 +83,21 @@ pub fn assert_tables_equal( let rhs_rb = rhs.to_record_batch().unwrap(); (lhs_rb, rhs_rb) } else { - // Sort by all columns + // Sort by all columns except ORDER_COL let sort_exprs: Vec<_> = lhs .schema .fields() .iter() - .map(|f| { - Expr::Sort(expr::Sort { - expr: Box::new(flat_col(f.name())), - asc: false, - nulls_first: false, - }) + .filter_map(|f| { + if f.name() == ORDER_COL { + None + } else { + Some(Expr::Sort(expr::Sort { + expr: Box::new(flat_col(f.name())), + asc: false, + nulls_first: false, + })) + } }) .collect(); @@ -144,13 +161,25 @@ fn assert_scalars_almost_equals( let lhs_map: HashMap<_, _> = lhs_fields .iter() .zip(lhs_vals.iter()) - .map(|(field, val)| (field.name().clone(), val.clone())) + .filter_map(|(field, val)| { + if field.name() == ORDER_COL { + None + } else { + Some((field.name().clone(), val.clone())) + } + }) .collect(); let rhs_map: HashMap<_, _> = rhs_fields .iter() .zip(rhs_vals.iter()) - .map(|(field, val)| (field.name().clone(), val.clone())) + .filter_map(|(field, val)| { + if field.name() == ORDER_COL { + None + } else { + Some((field.name().clone(), val.clone())) + } + }) .collect(); // Check 
column names