Skip to content

Commit

Permalink
Add support for CreateNamedStruct
Browse files Browse the repository at this point in the history
  • Loading branch information
eejbyfeldt committed Jul 1, 2024
1 parent 42ed636 commit 5cf3264
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 1 deletion.
127 changes: 127 additions & 0 deletions core/src/execution/datafusion/expressions/create_named_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{
any::Any,
fmt::{Display, Formatter},
hash::{Hash, Hasher},
sync::Arc,
};

use arrow::record_batch::RecordBatch;
use arrow_array::StructArray;
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_physical_expr::PhysicalExpr;

use crate::execution::datafusion::expressions::utils::down_cast_any_ref;

#[derive(Debug, Hash)]
pub struct CreateNamedStruct {
values: Vec<Arc<dyn PhysicalExpr>>,
data_type: DataType,
}

impl CreateNamedStruct {
pub fn new(values: Vec<Arc<dyn PhysicalExpr>>, data_type: DataType) -> Self {
Self { values, data_type }
}
}

impl PhysicalExpr for CreateNamedStruct {
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
Ok(self.data_type.clone())
}

fn nullable(&self, _input_schema: &Schema) -> DataFusionResult<bool> {
Ok(false)
}

fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
let values = self
.values
.iter()
.map(|expr| expr.evaluate(batch))
.collect::<datafusion_common::Result<Vec<_>>>()?;
let arrays = ColumnarValue::values_to_arrays(&values)?;
let fields = match &self.data_type {
DataType::Struct(fields) => fields,
_ => {
return Err(DataFusionError::Internal(format!(
"Expected struct data type, got {:?}",
self.data_type
)))
}
};
Ok(ColumnarValue::Array(Arc::new(StructArray::new(
fields.clone(),
arrays,
None,
))))
}

fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
self.values.iter().collect()
}

fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(CreateNamedStruct::new(
children.clone(),
self.data_type.clone(),
)))
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.values.hash(&mut s);
self.data_type.hash(&mut s);
self.hash(&mut s);
}
}

impl Display for CreateNamedStruct {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"CreateNamedStruct [values: {:?}, data_type: {:?}]",
self.values, self.data_type
)
}
}

impl PartialEq<dyn Any> for CreateNamedStruct {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.values
.iter()
.zip(x.values.iter())
.all(|(a, b)| a.eq(b))
&& self.data_type.eq(&x.data_type)
})
.unwrap_or(false)
}
}
1 change: 1 addition & 0 deletions core/src/execution/datafusion/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub mod avg_decimal;
pub mod bloom_filter_might_contain;
pub mod correlation;
pub mod covariance;
pub mod create_named_struct;
pub mod negative;
pub mod stats;
pub mod stddev;
Expand Down
11 changes: 10 additions & 1 deletion core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ use crate::{
},
};

use super::expressions::{abs::CometAbsFunc, EvalMode};
use super::expressions::{abs::CometAbsFunc, create_named_struct::CreateNamedStruct, EvalMode};

// For clippy error on type_complexity.
type ExecResult<T> = Result<T, ExecutionError>;
Expand Down Expand Up @@ -584,6 +584,15 @@ impl PhysicalPlanner {
value_expr,
)?))
}
ExprStruct::CreateNamedStruct(expr) => {
let values = expr
.values
.iter()
.map(|expr| self.create_expr(expr, input_schema.clone()))
.collect::<Result<Vec<_>, _>>()?;
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(CreateNamedStruct::new(values, data_type)))
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
6 changes: 6 additions & 0 deletions core/src/execution/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ message Expr {
Subquery subquery = 50;
UnboundReference unbound = 51;
BloomFilterMightContain bloom_filter_might_contain = 52;
CreateNamedStruct create_named_struct = 53;
}
}

Expand Down Expand Up @@ -486,6 +487,11 @@ message BloomFilterMightContain {
Expr value = 2;
}

message CreateNamedStruct {
repeated Expr values = 1;
DataType datatype = 2;
}

enum SortDirection {
Ascending = 0;
Descending = 1;
Expand Down
19 changes: 19 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2141,6 +2141,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
scalarExprToProtoWithReturnType(algorithm, StringType, childExpr)
}

case struct @ CreateNamedStruct(_) =>
val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding))
val dataType = serializeDataType(struct.dataType)

if (valExprs.forall(_.isDefined) && dataType.isDefined) {
val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
structBuilder.addAllValues(valExprs.map(_.get).asJava)
structBuilder.setDatatype(dataType.get)

Some(
ExprOuterClass.Expr
.newBuilder()
.setCreateNamedStruct(structBuilder)
.build())
} else {
withInfo(expr, struct.valExprs: _*)
None
}

case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
Expand Down
14 changes: 14 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1719,4 +1719,18 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}
test("named_struct") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
withParquetTable(path.toString, "tbl") {
checkSparkAnswerAndOperator("SELECT named_struct('a', _1, 'b', _2) FROM tbl")
checkSparkAnswerAndOperator("SELECT named_struct('a', _1, 'b', 2) FROM tbl")
checkSparkAnswerAndOperator(
"SELECT named_struct('a', named_struct('b', _1, 'c', _2)) FROM tbl")
}
}
}
}
}

0 comments on commit 5cf3264

Please sign in to comment.