Skip to content

Commit

Permalink
Fix column names for boolean pivot field (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonmmease authored Dec 6, 2022
1 parent b61ae4a commit 2c9c7ec
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 12 deletions.
24 changes: 22 additions & 2 deletions vegafusion-rt-datafusion/src/transform/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::transform::utils::RecordBatchUtils;
use crate::transform::TransformTrait;
use async_trait::async_trait;
use datafusion::prelude::Column;
use datafusion_expr::{coalesce, lit, min, BuiltInWindowFunction, Expr, WindowFunction};
use datafusion_expr::{coalesce, col, lit, min, when, BuiltInWindowFunction, Expr, WindowFunction};
use sqlgen::dialect::DialectDisplay;
use std::sync::Arc;
use vegafusion_core::arrow::array::StringArray;
Expand All @@ -27,7 +27,26 @@ impl TransformTrait for Pivot {
) -> Result<(Arc<SqlDataFrame>, Vec<TaskValue>)> {
// Make sure the pivot column is a string
let pivot_dtype = data_type(&unescaped_col(&self.field), &dataframe.schema_df())?;
let dataframe = if !is_string_datatype(&pivot_dtype) {
let dataframe = if matches!(pivot_dtype, DataType::Boolean) {
// Boolean column type. For consistency with vega, replace 0 with "false" and 1 with "true"
let select_exprs: Vec<_> = dataframe
.schema()
.fields
.iter()
.map(|field| {
if field.name() == &self.field {
Ok(when(col(&self.field).eq(lit(true)), lit("true"))
.otherwise(lit("false"))
.expect("Failed to construct Case expression")
.alias(&self.field))
} else {
Ok(unescaped_col(field.name()))
}
})
.collect::<Result<Vec<_>>>()?;
dataframe.select(select_exprs)?
} else if !is_string_datatype(&pivot_dtype) {
// Column type is not string, so cast values to strings
let select_exprs: Vec<_> = dataframe
.schema()
.fields
Expand All @@ -47,6 +66,7 @@ impl TransformTrait for Pivot {
.collect::<Result<Vec<_>>>()?;
dataframe.select(select_exprs)?
} else {
// Column type is string
dataframe
};

Expand Down
69 changes: 59 additions & 10 deletions vegafusion-rt-datafusion/tests/test_transform_pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ use vegafusion_core::data::table::VegaFusionTable;
fn medals() -> VegaFusionTable {
VegaFusionTable::from_json(
&json!([
{"country": "Germany", "type": "gold", "count": 14},
{"country": "Norway", "type": "gold", "count": 14},
{"country": "Norway", "type": "silver", "count": 14},
{"country": "Canada", "type": "copper", "count": 10},
{"country": "Norway", "type": "bronze", "count": 11},
{"country": "Germany", "type": "silver", "count": 10},
{"country": "Germany", "type": "bronze", "count": 7},
{"country": "Canada", "type": "gold", "count": 11},
{"country": "Canada", "type": "silver", "count": 8},
{"country": "Canada", "type": "bronze", "count": 10},
{"country": "Germany", "type": "gold", "count": 14, "is_gold": true},
{"country": "Norway", "type": "gold", "count": 14, "is_gold": true},
{"country": "Norway", "type": "silver", "count": 14, "is_gold": false},
{"country": "Canada", "type": "copper", "count": 10, "is_gold": false},
{"country": "Norway", "type": "bronze", "count": 11, "is_gold": false},
{"country": "Germany", "type": "silver", "count": 10, "is_gold": false},
{"country": "Germany", "type": "bronze", "count": 7, "is_gold": false},
{"country": "Canada", "type": "gold", "count": 11, "is_gold": true},
{"country": "Canada", "type": "silver", "count": 8, "is_gold": false},
{"country": "Canada", "type": "bronze", "count": 10, "is_gold": false},
]),
1024,
)
Expand Down Expand Up @@ -129,3 +129,52 @@ mod test_pivot_no_group {
);
}
}

#[cfg(test)]
mod test_pivot_no_group_boolean {
use crate::medals;
use crate::util::check::check_transform_evaluation;
use rstest::rstest;
use vegafusion_core::spec::transform::aggregate::AggregateOpSpec;
use vegafusion_core::spec::transform::pivot::PivotTransformSpec;
use vegafusion_core::spec::transform::TransformSpec;

#[rstest(
op,
limit,
case(None, None),
case(Some(AggregateOpSpec::Sum), None),
case(Some(AggregateOpSpec::Sum), Some(2)),
case(Some(AggregateOpSpec::Count), None),
case(Some(AggregateOpSpec::Count), Some(3)),
case(Some(AggregateOpSpec::Mean), None),
case(Some(AggregateOpSpec::Mean), Some(4)),
case(Some(AggregateOpSpec::Max), None),
case(Some(AggregateOpSpec::Max), Some(10)),
case(Some(AggregateOpSpec::Min), None),
case(Some(AggregateOpSpec::Min), Some(0))
)]
fn test(op: Option<AggregateOpSpec>, limit: Option<i32>) {
let dataset = medals();

let pivot_spec = PivotTransformSpec {
field: "is_gold".to_string(),
value: "count".to_string(),
groupby: None,
limit,
op,
extra: Default::default(),
};
let transform_specs = vec![TransformSpec::Pivot(pivot_spec)];

let comp_config = Default::default();
let eq_config = Default::default();

check_transform_evaluation(
&dataset,
transform_specs.as_slice(),
&comp_config,
&eq_config,
);
}
}

0 comments on commit 2c9c7ec

Please sign in to comment.