Skip to content

Commit

Permalink
feat(col): dropFields & withField (#32)
Browse files Browse the repository at this point in the history
* feat(col): dropFields & withField
  • Loading branch information
sjrusso8 authored May 15, 2024
1 parent 7fa53b0 commit eb12b83
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 2 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,11 @@ Spark [Column](https://spark.apache.org/docs/latest/api/python/reference/pyspark
| desc | ![done] | |
| desc_nulls_first | ![done] | |
| desc_nulls_last | ![done] | |
| dropFields | ![open] | |
| dropFields | ![done] | |
| endswith | ![done] | |
| eqNullSafe | ![open] | |
| getField | ![open] | This is deprecated but will need to be implemented |
| getItem | ![open] | This is deprecated but will need to be implemented |
| ilike | ![done] | |
| isNotNull | ![done] | |
| isNull | ![done] | |
Expand All @@ -296,6 +299,7 @@ Spark [Column](https://spark.apache.org/docs/latest/api/python/reference/pyspark
| startswith | ![done] | |
| substr | ![open] | |
| when | ![open] | |
| withField | ![done] | |
| eq `==` | ![done] | Rust does not like when you try to overload `==` and return something other than a `bool`. Currently implemented column equality like `col('name').eq(col('id'))`. Not the best, but it works for now |
| addition `+` | ![done] | |
| subtraction `-` | ![done] | |
Expand Down
46 changes: 46 additions & 0 deletions core/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,51 @@ impl Column {
Column::from(expression)
}

#[allow(non_snake_case)]
pub fn dropFields<'a, I>(self, fieldNames: I) -> Column
where
I: IntoIterator<Item = &'a str>,
{
let mut parent_col = self.expression;

for field in fieldNames {
parent_col = spark::Expression {
expr_type: Some(spark::expression::ExprType::UpdateFields(Box::new(
spark::expression::UpdateFields {
struct_expression: Some(Box::new(parent_col)),
field_name: field.to_string(),
value_expression: None,
},
))),
};
}

Column::from(parent_col)
}

#[allow(non_snake_case)]
pub fn withField(self, fieldName: &str, col: Column) -> Column {
let update_field = spark::Expression {
expr_type: Some(spark::expression::ExprType::UpdateFields(Box::new(
spark::expression::UpdateFields {
struct_expression: Some(Box::new(self.expression)),
field_name: fieldName.to_string(),
value_expression: Some(Box::new(col.to_literal_expr())),
},
))),
};

Column::from(update_field)
}

#[allow(non_snake_case)]
pub fn substr<T: ToExpr>(self, startPos: T, length: T) -> Column {
invoke_func(
"substr",
vec![self.to_expr(), startPos.to_expr(), length.to_expr()],
)
}

/// Casts the column into the Spark type represented as a `&str`
///
/// # Arguments:
Expand Down Expand Up @@ -296,6 +341,7 @@ impl Column {
pub fn or<T: ToExpr>(self, other: T) -> Column {
invoke_func("or", vec![self.to_expr(), other.to_expr()])
}

    /// A filter expression that evaluates to true if the expression is null
#[allow(non_snake_case)]
pub fn isNull(self) -> Column {
Expand Down
63 changes: 62 additions & 1 deletion core/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,6 @@ mod tests {
]));

let expected = RecordBatch::try_from_iter(vec![("struct", struct_array)])?;
println!("{:?}", &res);

assert_eq!(res, expected);
Ok(())
Expand Down Expand Up @@ -940,4 +939,66 @@ mod tests {
assert_eq!(expected, res);
Ok(())
}

#[tokio::test]
async fn test_func_col_drop_fields() -> Result<(), SparkError> {
let spark = setup().await;

let df = spark.range(None, 1, 1, None).select(
named_struct([
lit("a"),
lit(1),
lit("b"),
lit(2),
lit("c"),
lit(3),
lit("d"),
lit(4),
])
.alias("struct_col"),
);

let df = df.select(col("struct_col").dropFields(["b", "c"]).alias("struct_col"));

let res = df.collect().await?;

let a: ArrayRef = Arc::new(Int32Array::from(vec![1]));
let d: ArrayRef = Arc::new(Int32Array::from(vec![4]));

let struct_array: ArrayRef = Arc::new(StructArray::from(vec![
(Arc::new(Field::new("a", DataType::Int32, false)), a),
(Arc::new(Field::new("d", DataType::Int32, false)), d),
]));

let expected = RecordBatch::try_from_iter(vec![("struct_col", struct_array)])?;

assert_eq!(expected, res);
Ok(())
}

#[tokio::test]
async fn test_func_col_with_field() -> Result<(), SparkError> {
let spark = setup().await;

let df = spark
.range(None, 1, 1, None)
.select(named_struct([lit("a"), lit(1), lit("b"), lit(2)]).alias("struct_col"));

let df = df.select(col("struct_col").withField("b", lit(4)).alias("struct_col"));

let res = df.collect().await?;

let a: ArrayRef = Arc::new(Int32Array::from(vec![1]));
let b: ArrayRef = Arc::new(Int32Array::from(vec![4]));

let struct_array: ArrayRef = Arc::new(StructArray::from(vec![
(Arc::new(Field::new("a", DataType::Int32, false)), a),
(Arc::new(Field::new("b", DataType::Int32, false)), b),
]));

let expected = RecordBatch::try_from_iter(vec![("struct_col", struct_array)])?;

assert_eq!(expected, res);
Ok(())
}
}

0 comments on commit eb12b83

Please sign in to comment.