From 8ed3aa81ff27da821aa67c3def6cffa22e8b12b3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 4 Mar 2024 10:17:35 -0800 Subject: [PATCH] For review --- .../apache/comet/serde/QueryPlanSerde.scala | 2130 ++++++++--------- 1 file changed, 1058 insertions(+), 1072 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 661b7f2a1..d39d7bbe6 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -333,16 +333,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } - def exprToProto( - expr: Expression, - input: Seq[Attribute], - binding: Boolean = true): Option[Expr] = { - val conf = SQLConf.get - val newExpr = - DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled) - exprToProtoInternal(newExpr, input, binding) - } - /** * Convert a Spark expression to protobuf. 
* @@ -355,711 +345,712 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { * @return * The protobuf representation of the expression, or None if the expression is not supported */ - def exprToProtoInternal( + def exprToProto( expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - SQLConf.get - expr match { - case a @ Alias(_, _) => - exprToProtoInternal(a.child, inputs, binding) + input: Seq[Attribute], + binding: Boolean = true): Option[Expr] = { - case cast @ Cast(_: Literal, dataType, _, _) => - // This can happen after promoting decimal precisions - val value = cast.eval() - exprToProtoInternal(Literal(value, dataType), inputs, binding) + def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = { + SQLConf.get + expr match { + case a @ Alias(_, _) => + exprToProtoInternal(a.child, inputs) - case Cast(child, dt, timeZoneId, _) => - val childExpr = exprToProtoInternal(child, inputs, binding) - val dataType = serializeDataType(dt) + case cast @ Cast(_: Literal, dataType, _, _) => + // This can happen after promoting decimal precisions + val value = cast.eval() + exprToProtoInternal(Literal(value, dataType), inputs) - if (childExpr.isDefined && dataType.isDefined) { - val castBuilder = ExprOuterClass.Cast.newBuilder() - castBuilder.setChild(childExpr.get) - castBuilder.setDatatype(dataType.get) + case Cast(child, dt, timeZoneId, _) => + val childExpr = exprToProtoInternal(child, inputs) + val dataType = serializeDataType(dt) - val timeZone = timeZoneId.getOrElse("UTC") - castBuilder.setTimezone(timeZone) + if (childExpr.isDefined && dataType.isDefined) { + val castBuilder = ExprOuterClass.Cast.newBuilder() + castBuilder.setChild(childExpr.get) + castBuilder.setDatatype(dataType.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setCast(castBuilder) - .build()) - } else { - None - } + val timeZone = timeZoneId.getOrElse("UTC") + castBuilder.setTimezone(timeZone) - case add @ Add(left, right, _) 
if supportedDataType(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val addBuilder = ExprOuterClass.Add.newBuilder() - addBuilder.setLeft(leftExpr.get) - addBuilder.setRight(rightExpr.get) - addBuilder.setFailOnError(getFailOnError(add)) - serializeDataType(add.dataType).foreach { t => - addBuilder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setCast(castBuilder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setAdd(addBuilder) - .build()) - } else { - None - } + case add @ Add(left, right, _) if supportedDataType(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val addBuilder = ExprOuterClass.Add.newBuilder() + addBuilder.setLeft(leftExpr.get) + addBuilder.setRight(rightExpr.get) + addBuilder.setFailOnError(getFailOnError(add)) + serializeDataType(add.dataType).foreach { t => + addBuilder.setReturnType(t) + } - case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Subtract.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(getFailOnError(sub)) - serializeDataType(sub.dataType).foreach { t => - builder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setAdd(addBuilder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setSubtract(builder) - .build()) - } else { - None - } + case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = 
exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Subtract.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setFailOnError(getFailOnError(sub)) + serializeDataType(sub.dataType).foreach { t => + builder.setReturnType(t) + } - case mul @ Multiply(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Multiply.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(getFailOnError(mul)) - serializeDataType(mul.dataType).foreach { t => - builder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setSubtract(builder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setMultiply(builder) - .build()) - } else { - None - } + case mul @ Multiply(left, right, _) + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Multiply.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setFailOnError(getFailOnError(mul)) + serializeDataType(mul.dataType).foreach { t => + builder.setReturnType(t) + } - case div @ Divide(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - // Datafusion now throws an exception for dividing by zero - // See https://github.com/apache/arrow-datafusion/pull/6792 - // For now, use NullIf to swap zeros with nulls. 
- val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs, binding) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Divide.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(getFailOnError(div)) - serializeDataType(div.dataType).foreach { t => - builder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setMultiply(builder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setDivide(builder) - .build()) - } else { - None - } + case div @ Divide(left, right, _) + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + // Datafusion now throws an exception for dividing by zero + // See https://github.com/apache/arrow-datafusion/pull/6792 + // For now, use NullIf to swap zeros with nulls. + val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Divide.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setFailOnError(getFailOnError(div)) + serializeDataType(div.dataType).foreach { t => + builder.setReturnType(t) + } - case rem @ Remainder(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs, binding) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Remainder.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(getFailOnError(rem)) - serializeDataType(rem.dataType).foreach { t => - builder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setDivide(builder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - 
.setRemainder(builder) - .build()) - } else { - None - } - - case EqualTo(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + case rem @ Remainder(left, right, _) + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Remainder.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setFailOnError(getFailOnError(rem)) + serializeDataType(rem.dataType).foreach { t => + builder.setReturnType(t) + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Equal.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setRemainder(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setEq(builder) - .build()) - } else { - None - } + case EqualTo(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case Not(EqualTo(left, right)) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Equal.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.NotEqual.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setEq(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setNeq(builder) - .build()) - } else { - None - } + case Not(EqualTo(left, right)) => + val 
leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case EqualNullSafe(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.NotEqual.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.EqualNullSafe.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setNeq(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setEqNullSafe(builder) - .build()) - } else { - None - } + case EqualNullSafe(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case Not(EqualNullSafe(left, right)) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.EqualNullSafe.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.NotEqualNullSafe.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setEqNullSafe(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setNeqNullSafe(builder) - .build()) - } else { - None - } + case Not(EqualNullSafe(left, right)) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case GreaterThan(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + if 
(leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.NotEqualNullSafe.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.GreaterThan.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setNeqNullSafe(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setGt(builder) - .build()) - } else { - None - } + case GreaterThan(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case GreaterThanOrEqual(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.GreaterThan.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.GreaterThanEqual.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setGt(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setGtEq(builder) - .build()) - } else { - None - } + case GreaterThanOrEqual(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case LessThan(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.GreaterThanEqual.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.LessThan.newBuilder() - 
builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setGtEq(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setLt(builder) - .build()) - } else { - None - } + case LessThan(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case LessThanOrEqual(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.LessThan.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.LessThanEqual.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setLt(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setLtEq(builder) - .build()) - } else { - None - } + case LessThanOrEqual(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case Literal(value, dataType) if supportedDataType(dataType) => - val exprBuilder = ExprOuterClass.Literal.newBuilder() + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.LessThanEqual.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (value == null) { - exprBuilder.setIsNull(true) - } else { - exprBuilder.setIsNull(false) - dataType match { - case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) - case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) - case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) - case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) - case _: LongType => 
exprBuilder.setLongVal(value.asInstanceOf[Long]) - case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) - case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) - case _: StringType => - exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString) - case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long]) - case _: DecimalType => - // Pass decimal literal as bytes. - val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue - exprBuilder.setDecimalVal( - com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray)) - case _: BinaryType => - val byteStr = - com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) - exprBuilder.setBytesVal(byteStr) - case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) - case dt => - logWarning(s"Unexpected date type '$dt' for literal value '$value'") + Some( + ExprOuterClass.Expr + .newBuilder() + .setLtEq(builder) + .build()) + } else { + None } - } - val dt = serializeDataType(dataType) + case Literal(value, dataType) if supportedDataType(dataType) => + val exprBuilder = ExprOuterClass.Literal.newBuilder() + + if (value == null) { + exprBuilder.setIsNull(true) + } else { + exprBuilder.setIsNull(false) + dataType match { + case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) + case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) + case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) + case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) + case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) + case _: StringType => + exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString) + case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: DecimalType => + 
// Pass decimal literal as bytes. + val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue + exprBuilder.setDecimalVal( + com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray)) + case _: BinaryType => + val byteStr = + com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) + exprBuilder.setBytesVal(byteStr) + case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case dt => + logWarning(s"Unexpected date type '$dt' for literal value '$value'") + } + } - if (dt.isDefined) { - exprBuilder.setDatatype(dt.get) + val dt = serializeDataType(dataType) - Some( - ExprOuterClass.Expr - .newBuilder() - .setLiteral(exprBuilder) - .build()) - } else { - None - } + if (dt.isDefined) { + exprBuilder.setDatatype(dt.get) - case Substring(str, Literal(pos, _), Literal(len, _)) => - val strExpr = exprToProtoInternal(str, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setLiteral(exprBuilder) + .build()) + } else { + None + } - if (strExpr.isDefined) { - val builder = ExprOuterClass.Substring.newBuilder() - builder.setChild(strExpr.get) - builder.setStart(pos.asInstanceOf[Int]) - builder.setLen(len.asInstanceOf[Int]) + case Substring(str, Literal(pos, _), Literal(len, _)) => + val strExpr = exprToProtoInternal(str, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setSubstring(builder) - .build()) - } else { - None - } + if (strExpr.isDefined) { + val builder = ExprOuterClass.Substring.newBuilder() + builder.setChild(strExpr.get) + builder.setStart(pos.asInstanceOf[Int]) + builder.setLen(len.asInstanceOf[Int]) - case Like(left, right, _) => - // TODO escapeChar - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setSubstring(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Like.newBuilder() - 
builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case Like(left, right, _) => + // TODO escapeChar + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setLike(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Like.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - // TODO waiting for arrow-rs update -// case RLike(left, right) => -// val leftExpr = exprToProtoInternal(left, inputs) -// val rightExpr = exprToProtoInternal(right, inputs) -// -// if (leftExpr.isDefined && rightExpr.isDefined) { -// val builder = ExprOuterClass.RLike.newBuilder() -// builder.setLeft(leftExpr.get) -// builder.setRight(rightExpr.get) -// -// Some( -// ExprOuterClass.Expr -// .newBuilder() -// .setRlike(builder) -// .build()) -// } else { -// None -// } - - case StartsWith(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.StartsWith.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setLike(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setStartsWith(builder) - .build()) - } else { - None - } + // TODO waiting for arrow-rs update + // case RLike(left, right) => + // val leftExpr = exprToProtoInternal(left, inputs) + // val rightExpr = exprToProtoInternal(right, inputs) + // + // if (leftExpr.isDefined && rightExpr.isDefined) { + // val builder = ExprOuterClass.RLike.newBuilder() + // builder.setLeft(leftExpr.get) + // builder.setRight(rightExpr.get) + // + // Some( + // ExprOuterClass.Expr + // .newBuilder() + // .setRlike(builder) + // .build()) + // } else { + // None + // } + + 
case StartsWith(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.StartsWith.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case EndsWith(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setStartsWith(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.EndsWith.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case EndsWith(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setEndsWith(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.EndsWith.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case Contains(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setEndsWith(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Contains.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case Contains(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setContains(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Contains.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case 
StringSpace(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setContains(builder) + .build()) + } else { + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.StringSpace.newBuilder() - builder.setChild(childExpr.get) + case StringSpace(child) => + val childExpr = exprToProtoInternal(child, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setStringSpace(builder) - .build()) - } else { - None - } + if (childExpr.isDefined) { + val builder = ExprOuterClass.StringSpace.newBuilder() + builder.setChild(childExpr.get) - case Hour(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setStringSpace(builder) + .build()) + } else { + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.Hour.newBuilder() - builder.setChild(childExpr.get) + case Hour(child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Hour.newBuilder() + builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setHour(builder) - .build()) - } else { - None - } + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - case Minute(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setHour(builder) + .build()) + } else { + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.Minute.newBuilder() - builder.setChild(childExpr.get) + case Minute(child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Minute.newBuilder() + builder.setChild(childExpr.get) - Some( - 
ExprOuterClass.Expr - .newBuilder() - .setMinute(builder) - .build()) - } else { - None - } + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - case TruncDate(child, format) => - val childExpr = exprToProtoInternal(child, inputs, binding) - val formatExpr = exprToProtoInternal(format, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setMinute(builder) + .build()) + } else { + None + } - if (childExpr.isDefined && formatExpr.isDefined) { - val builder = ExprOuterClass.TruncDate.newBuilder() - builder.setChild(childExpr.get) - builder.setFormat(formatExpr.get) + case TruncDate(child, format) => + val childExpr = exprToProtoInternal(child, inputs) + val formatExpr = exprToProtoInternal(format, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setTruncDate(builder) - .build()) - } else { - None - } + if (childExpr.isDefined && formatExpr.isDefined) { + val builder = ExprOuterClass.TruncDate.newBuilder() + builder.setChild(childExpr.get) + builder.setFormat(formatExpr.get) - case TruncTimestamp(format, child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs, binding) - val formatExpr = exprToProtoInternal(format, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setTruncDate(builder) + .build()) + } else { + None + } - if (childExpr.isDefined && formatExpr.isDefined) { - val builder = ExprOuterClass.TruncTimestamp.newBuilder() - builder.setChild(childExpr.get) - builder.setFormat(formatExpr.get) + case TruncTimestamp(format, child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs) + val formatExpr = exprToProtoInternal(format, inputs) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + if (childExpr.isDefined && formatExpr.isDefined) { + val builder = ExprOuterClass.TruncTimestamp.newBuilder() + builder.setChild(childExpr.get) + builder.setFormat(formatExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setTruncTimestamp(builder) - 
.build()) - } else { - None - } + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - case Second(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setTruncTimestamp(builder) + .build()) + } else { + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.Second.newBuilder() - builder.setChild(childExpr.get) + case Second(child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Second.newBuilder() + builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setSecond(builder) - .build()) - } else { - None - } + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - case Year(child) => - val periodType = exprToProtoInternal(Literal("year"), inputs, binding) - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("datepart", Seq(periodType, childExpr): _*) - .map(e => { - Expr - .newBuilder() - .setCast( - ExprOuterClass.Cast - .newBuilder() - .setChild(e) - .setDatatype(serializeDataType(IntegerType).get) - .build()) - .build() - }) + Some( + ExprOuterClass.Expr + .newBuilder() + .setSecond(builder) + .build()) + } else { + None + } - case IsNull(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) + case Year(child) => + val periodType = exprToProtoInternal(Literal("year"), inputs) + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("datepart", Seq(periodType, childExpr): _*) + .map(e => { + Expr + .newBuilder() + .setCast( + ExprOuterClass.Cast + .newBuilder() + .setChild(e) + .setDatatype(serializeDataType(IntegerType).get) + .build()) + .build() + }) + + case IsNull(child) => + val childExpr = exprToProtoInternal(child, inputs) + + if (childExpr.isDefined) { + val castBuilder = 
ExprOuterClass.IsNull.newBuilder() + castBuilder.setChild(childExpr.get) - if (childExpr.isDefined) { - val castBuilder = ExprOuterClass.IsNull.newBuilder() - castBuilder.setChild(childExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIsNull(castBuilder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setIsNull(castBuilder) - .build()) - } else { - None - } + case IsNotNull(child) => + val childExpr = exprToProtoInternal(child, inputs) - case IsNotNull(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) + if (childExpr.isDefined) { + val castBuilder = ExprOuterClass.IsNotNull.newBuilder() + castBuilder.setChild(childExpr.get) - if (childExpr.isDefined) { - val castBuilder = ExprOuterClass.IsNotNull.newBuilder() - castBuilder.setChild(childExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIsNotNull(castBuilder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setIsNotNull(castBuilder) - .build()) - } else { - None - } + case SortOrder(child, direction, nullOrdering, _) => + val childExpr = exprToProtoInternal(child, inputs) - case SortOrder(child, direction, nullOrdering, _) => - val childExpr = exprToProtoInternal(child, inputs, binding) + if (childExpr.isDefined) { + val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder() + sortOrderBuilder.setChild(childExpr.get) - if (childExpr.isDefined) { - val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder() - sortOrderBuilder.setChild(childExpr.get) + direction match { + case Ascending => sortOrderBuilder.setDirectionValue(0) + case Descending => sortOrderBuilder.setDirectionValue(1) + } - direction match { - case Ascending => sortOrderBuilder.setDirectionValue(0) - case Descending => sortOrderBuilder.setDirectionValue(1) - } + nullOrdering match { + case NullsFirst => sortOrderBuilder.setNullOrderingValue(0) + case NullsLast => sortOrderBuilder.setNullOrderingValue(1) + } - nullOrdering match { - 
case NullsFirst => sortOrderBuilder.setNullOrderingValue(0) - case NullsLast => sortOrderBuilder.setNullOrderingValue(1) + Some( + ExprOuterClass.Expr + .newBuilder() + .setSortOrder(sortOrderBuilder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setSortOrder(sortOrderBuilder) - .build()) - } else { - None - } - - case And(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + case And(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.And.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.And.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setAnd(builder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setAnd(builder) + .build()) + } else { + None + } - case Or(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + case Or(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Or.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Or.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setOr(builder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setOr(builder) + .build()) + } else { + None + } - case UnaryExpression(child) if 
expr.prettyName == "promote_precision" => - // `UnaryExpression` includes `PromotePrecision` for Spark 3.2 & 3.3 - // `PromotePrecision` is just a wrapper, don't need to serialize it. - exprToProtoInternal(child, inputs, binding) + case UnaryExpression(child) if expr.prettyName == "promote_precision" => + // `UnaryExpression` includes `PromotePrecision` for Spark 3.2 & 3.3 + // `PromotePrecision` is just a wrapper, don't need to serialize it. + exprToProtoInternal(child, inputs) - case CheckOverflow(child, dt, nullOnOverflow) => - val childExpr = exprToProtoInternal(child, inputs, binding) + case CheckOverflow(child, dt, nullOnOverflow) => + val childExpr = exprToProtoInternal(child, inputs) - if (childExpr.isDefined) { - val builder = ExprOuterClass.CheckOverflow.newBuilder() - builder.setChild(childExpr.get) - builder.setFailOnError(!nullOnOverflow) + if (childExpr.isDefined) { + val builder = ExprOuterClass.CheckOverflow.newBuilder() + builder.setChild(childExpr.get) + builder.setFailOnError(!nullOnOverflow) - // `dataType` must be decimal type - val dataType = serializeDataType(dt) - builder.setDatatype(dataType.get) + // `dataType` must be decimal type + val dataType = serializeDataType(dt) + builder.setDatatype(dataType.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setCheckOverflow(builder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setCheckOverflow(builder) + .build()) + } else { + None + } - case attr: AttributeReference => - val dataType = serializeDataType(attr.dataType) + case attr: AttributeReference => + val dataType = serializeDataType(attr.dataType) - if (dataType.isDefined) { - if (binding) { + if (dataType.isDefined) { val boundRef = BindReferences .bindReference(attr, inputs, allowFailures = false) .asInstanceOf[BoundReference] @@ -1075,535 +1066,530 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setBound(boundExpr) .build()) } else { - val unboundRef = 
ExprOuterClass.UnboundReference - .newBuilder() - .setName(attr.name) - .setDatatype(dataType.get) - .build() + None + } - Some( - ExprOuterClass.Expr + case Abs(child, _) => + exprToProtoInternal(child, inputs).map(childExpr => { + val abs = + ExprOuterClass.Abs .newBuilder() - .setUnbound(unboundRef) - .build()) - } - } else { - None - } + .setChild(childExpr) + .build() + Expr.newBuilder().setAbs(abs).build() + }) - case Abs(child, _) => - exprToProtoInternal(child, inputs, binding).map(childExpr => { - val abs = - ExprOuterClass.Abs - .newBuilder() - .setChild(childExpr) - .build() - Expr.newBuilder().setAbs(abs).build() - }) - - case Acos(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("acos", childExpr) - - case Asin(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("asin", childExpr) - - case Atan(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("atan", childExpr) - - case Atan2(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) - scalarExprToProto("atan2", leftExpr, rightExpr) - - case e @ Ceil(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - child.dataType match { - case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - None - case _ => - scalarExprToProtoWithReturnType("ceil", e.dataType, childExpr) - } + case Acos(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("acos", childExpr) + + case Asin(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("asin", childExpr) + + case Atan(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("atan", childExpr) + + case Atan2(left, right) => + val leftExpr = 
exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + scalarExprToProto("atan2", leftExpr, rightExpr) + + case e @ Ceil(child) => + val childExpr = exprToProtoInternal(child, inputs) + child.dataType match { + case t: DecimalType if t.scale == 0 => // zero scale is no-op + childExpr + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + None + case _ => + scalarExprToProtoWithReturnType("ceil", e.dataType, childExpr) + } - case Cos(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("cos", childExpr) + case Cos(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("cos", childExpr) + + case Exp(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("exp", childExpr) + + case e @ Floor(child) => + val childExpr = exprToProtoInternal(child, inputs) + child.dataType match { + case t: DecimalType if t.scale == 0 => // zero scale is no-op + childExpr + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + None + case _ => + scalarExprToProtoWithReturnType("floor", e.dataType, childExpr) + } - case Exp(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("exp", childExpr) + case Log(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("ln", childExpr) + + case Log10(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("log10", childExpr) + + case Log2(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("log2", childExpr) + + case Pow(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + scalarExprToProto("pow", leftExpr, rightExpr) + + // round function for Spark 3.2 does not allow negative round target scale. In addition, + // it has different result precision/scale for decimals. 
Supporting only 3.3 and above. + case r: Round if !isSpark32 => + // _scale is a constant, copied from Spark's RoundBase because it is a protected val + val scaleV: Any = r.scale.eval(EmptyRow) + val _scale: Int = scaleV.asInstanceOf[Int] + + lazy val childExpr = exprToProtoInternal(r.child, inputs) + r.child.dataType match { + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + None + case _ if scaleV == null => + exprToProtoInternal(Literal(null), inputs) + case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => + childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark + case _: FloatType | DoubleType => + // We cannot properly match with the Spark behavior for floating-point numbers. + // Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a + // double to string internally in order to create its own internal representation. + // The problem is BigDecimal uses java.lang.Double.toString() and it has complicated + // rounding algorithm. E.g. -5.81855622136895E8 is actually + // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of + // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a + // difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be + // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that + // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can + // be rounded up to 6.13171162472835E18 that still represents the same double number. + // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not. + // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead + // of 6.1317116247283999E18.
+ None + case _ => + // `scale` must be Int64 type in DataFusion + val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs) + scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) + } - case e @ Floor(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - child.dataType match { - case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + case Signum(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("signum", childExpr) + + case Sin(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("sin", childExpr) + + case Sqrt(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("sqrt", childExpr) + + case Tan(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("tan", childExpr) + + case Ascii(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("ascii", childExpr) + + case BitLength(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("bit_length", childExpr) + + case If(predicate, trueValue, falseValue) => + val predicateExpr = exprToProtoInternal(predicate, inputs) + val trueExpr = exprToProtoInternal(trueValue, inputs) + val falseExpr = exprToProtoInternal(falseValue, inputs) + if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) { + val builder = ExprOuterClass.IfExpr.newBuilder() + builder.setIfExpr(predicateExpr.get) + builder.setTrueExpr(trueExpr.get) + builder.setFalseExpr(falseExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIf(builder) + .build()) + } else { None - case _ => - scalarExprToProtoWithReturnType("floor", e.dataType, childExpr) - } + } - case Log(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("ln", 
childExpr) - - case Log10(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("log10", childExpr) - - case Log2(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("log2", childExpr) - - case Pow(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) - scalarExprToProto("pow", leftExpr, rightExpr) - - // round function for Spark 3.2 does not allow negative round target scale. In addition, - // it has different result precision/scale for decimals. Supporting only 3.3 and above. - case r: Round if !isSpark32 => - // _scale s a constant, copied from Spark's RoundBase because it is a protected val - val scaleV: Any = r.scale.eval(EmptyRow) - val _scale: Int = scaleV.asInstanceOf[Int] - - lazy val childExpr = exprToProtoInternal(r.child, inputs, binding) - r.child.dataType match { - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - None - case _ if scaleV == null => - exprToProtoInternal(Literal(null), inputs, binding) - case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => - childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark - case _: FloatType | DoubleType => - // We cannot properly match with the Spark behavior for floating-point numbers. - // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a - // double to string internally in order to create its own internal representation. - // The problem is BigDecimal uses java.lang.Double.toString() and it has complicated - // rounding algorithm. E.g. -5.81855622136895E8 is actually - // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of - // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a - // difference when rounding at 5th digit, I.e. 
round(-5.81855622136895E8, 5) should be - // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that - // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can - // be rounded up to 6.13171162472835E18 that still represents the same double number. - // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not. - // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead - // of 6.1317116247283999E18. + case CaseWhen(branches, elseValue) => + val whenSeq = branches.map(elements => exprToProtoInternal(elements._1, inputs)) + val thenSeq = branches.map(elements => exprToProtoInternal(elements._2, inputs)) + assert(whenSeq.length == thenSeq.length) + if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) { + val builder = ExprOuterClass.CaseWhen.newBuilder() + builder.addAllWhen(whenSeq.map(_.get).asJava) + builder.addAllThen(thenSeq.map(_.get).asJava) + if (elseValue.isDefined) { + val elseValueExpr = exprToProtoInternal(elseValue.get, inputs) + if (elseValueExpr.isDefined) { + builder.setElseExpr(elseValueExpr.get) + } else { + return None + } + } + Some( + ExprOuterClass.Expr + .newBuilder() + .setCaseWhen(builder) + .build()) + } else { None - case _ => - // `scale` must be Int64 type in DataFusion - val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding) - scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) - } + } - case Signum(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("signum", childExpr) - - case Sin(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("sin", childExpr) - - case Sqrt(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("sqrt", childExpr) - - case Tan(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("tan", childExpr) 
- - case Ascii(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) - scalarExprToProto("ascii", childExpr) - - case BitLength(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) - scalarExprToProto("bit_length", childExpr) - - case If(predicate, trueValue, falseValue) => - val predicateExpr = exprToProtoInternal(predicate, inputs, binding) - val trueExpr = exprToProtoInternal(trueValue, inputs, binding) - val falseExpr = exprToProtoInternal(falseValue, inputs, binding) - if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) { - val builder = ExprOuterClass.IfExpr.newBuilder() - builder.setIfExpr(predicateExpr.get) - builder.setTrueExpr(trueExpr.get) - builder.setFalseExpr(falseExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIf(builder) - .build()) - } else { - None - } + case ConcatWs(children) => + val exprs = children.map(e => exprToProtoInternal(Cast(e, StringType), inputs)) + scalarExprToProto("concat_ws", exprs: _*) - case CaseWhen(branches, elseValue) => - val whenSeq = branches.map(elements => exprToProtoInternal(elements._1, inputs, binding)) - val thenSeq = branches.map(elements => exprToProtoInternal(elements._2, inputs, binding)) - assert(whenSeq.length == thenSeq.length) - if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) { - val builder = ExprOuterClass.CaseWhen.newBuilder() - builder.addAllWhen(whenSeq.map(_.get).asJava) - builder.addAllThen(thenSeq.map(_.get).asJava) - if (elseValue.isDefined) { - val elseValueExpr = exprToProtoInternal(elseValue.get, inputs, binding) - if (elseValueExpr.isDefined) { - builder.setElseExpr(elseValueExpr.get) - } else { - return None - } - } - Some( - ExprOuterClass.Expr - .newBuilder() - .setCaseWhen(builder) - .build()) - } else { - None - } + case Chr(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("chr", childExpr) - case ConcatWs(children) => - val exprs = 
children.map(e => exprToProtoInternal(Cast(e, StringType), inputs, binding)) - scalarExprToProto("concat_ws", exprs: _*) + case InitCap(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("initcap", childExpr) - case Chr(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProto("chr", childExpr) + case Length(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("length", childExpr) - case InitCap(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) - scalarExprToProto("initcap", childExpr) + case Lower(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("lower", childExpr) - case Length(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) - scalarExprToProto("length", childExpr) + case Md5(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("md5", childExpr) - case Lower(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) - scalarExprToProto("lower", childExpr) + case OctetLength(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("octet_length", childExpr) - case Md5(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) - scalarExprToProto("md5", childExpr) + case Reverse(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("reverse", childExpr) - case OctetLength(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) - scalarExprToProto("octet_length", childExpr) + case StringInstr(str, substr) => + val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) + val rightExpr = exprToProtoInternal(Cast(substr, StringType), inputs) + scalarExprToProto("strpos", 
leftExpr, rightExpr) - case Reverse(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) - scalarExprToProto("reverse", childExpr) + case StringRepeat(str, times) => + val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) + val rightExpr = exprToProtoInternal(Cast(times, LongType), inputs) + scalarExprToProto("repeat", leftExpr, rightExpr) - case StringInstr(str, substr) => - val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs, binding) - val rightExpr = exprToProtoInternal(Cast(substr, StringType), inputs, binding) - scalarExprToProto("strpos", leftExpr, rightExpr) + case StringReplace(src, search, replace) => + val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) + val searchExpr = exprToProtoInternal(Cast(search, StringType), inputs) + val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) + scalarExprToProto("replace", srcExpr, searchExpr, replaceExpr) - case StringRepeat(str, times) => - val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs, binding) - val rightExpr = exprToProtoInternal(Cast(times, LongType), inputs, binding) - scalarExprToProto("repeat", leftExpr, rightExpr) + case StringTranslate(src, matching, replace) => + val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) + val matchingExpr = exprToProtoInternal(Cast(matching, StringType), inputs) + val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) + scalarExprToProto("translate", srcExpr, matchingExpr, replaceExpr) - case StringReplace(src, search, replace) => - val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs, binding) - val searchExpr = exprToProtoInternal(Cast(search, StringType), inputs, binding) - val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs, binding) - scalarExprToProto("replace", srcExpr, searchExpr, replaceExpr) + case StringTrim(srcStr, trimStr) => + trim(srcStr, trimStr, inputs, "trim") - case StringTranslate(src, 
matching, replace) => - val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs, binding) - val matchingExpr = exprToProtoInternal(Cast(matching, StringType), inputs, binding) - val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs, binding) - scalarExprToProto("translate", srcExpr, matchingExpr, replaceExpr) + case StringTrimLeft(srcStr, trimStr) => + trim(srcStr, trimStr, inputs, "ltrim") - case StringTrim(srcStr, trimStr) => - trim(srcStr, trimStr, inputs, "trim", binding) + case StringTrimRight(srcStr, trimStr) => + trim(srcStr, trimStr, inputs, "rtrim") - case StringTrimLeft(srcStr, trimStr) => - trim(srcStr, trimStr, inputs, "ltrim", binding) + case StringTrimBoth(srcStr, trimStr, _) => + trim(srcStr, trimStr, inputs, "btrim") - case StringTrimRight(srcStr, trimStr) => - trim(srcStr, trimStr, inputs, "rtrim", binding) + case Upper(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("upper", childExpr) - case StringTrimBoth(srcStr, trimStr, _) => - trim(srcStr, trimStr, inputs, "btrim", binding) + case BitwiseAnd(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case Upper(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) - scalarExprToProto("upper", childExpr) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseAnd.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case BitwiseAnd(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseAnd(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseAnd.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case 
BitwiseNot(child) => + val childExpr = exprToProtoInternal(child, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseAnd(builder) - .build()) - } else { - None - } + if (childExpr.isDefined) { + val builder = ExprOuterClass.BitwiseNot.newBuilder() + builder.setChild(childExpr.get) - case BitwiseNot(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseNot(builder) + .build()) + } else { + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.BitwiseNot.newBuilder() - builder.setChild(childExpr.get) + case BitwiseOr(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseNot(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseOr.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case BitwiseOr(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseOr(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseOr.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case BitwiseXor(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseOr(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseXor.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case BitwiseXor(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = 
exprToProtoInternal(right, inputs, binding) + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseXor(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseXor.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case ShiftRight(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = if (left.dataType == LongType) { + // DataFusion bitwise shift right expression requires + // same data type between left and right side + exprToProtoInternal(Cast(right, LongType), inputs) + } else { + exprToProtoInternal(right, inputs) + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseXor(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseShiftRight.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case ShiftRight(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = if (left.dataType == LongType) { - // DataFusion bitwise shift right expression requires - // same data type between left and right side - exprToProtoInternal(Cast(right, LongType), inputs, binding) - } else { - exprToProtoInternal(right, inputs, binding) - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseShiftRight(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseShiftRight.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case ShiftLeft(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = if (left.dataType == LongType) { + // DataFusion bitwise shift left expression requires + // same data type between left and right side + exprToProtoInternal(Cast(right, LongType), inputs) + } else { + exprToProtoInternal(right, inputs) + } - Some( - ExprOuterClass.Expr - 
.newBuilder() - .setBitwiseShiftRight(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseShiftLeft.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case ShiftLeft(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = if (left.dataType == LongType) { - // DataFusion bitwise shift left expression requires - // same data type between left and right side - exprToProtoInternal(Cast(right, LongType), inputs, binding) - } else { - exprToProtoInternal(right, inputs, binding) - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseShiftLeft(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseShiftLeft.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case In(value, list) => + in(value, list, inputs, false) + + case InSet(value, hset) => + val valueDataType = value.dataType + val list = hset.map { setVal => + Literal(setVal, valueDataType) + }.toSeq + // Change `InSet` to `In` expression + // We do Spark `InSet` optimization in native (DataFusion) side. 
+ in(value, list, inputs, false) + + case Not(In(value, list)) => + in(value, list, inputs, true) + + case Not(child) => + val childExpr = exprToProtoInternal(child, inputs) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Not.newBuilder() + builder.setChild(childExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setNot(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseShiftLeft(builder) - .build()) - } else { - None - } + case UnaryMinus(child, _) => + val childExpr = exprToProtoInternal(child, inputs) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Negative.newBuilder() + builder.setChild(childExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setNegative(builder) + .build()) + } else { + None + } - case In(value, list) => - in(value, list, inputs, false, binding) - - case InSet(value, hset) => - val valueDataType = value.dataType - val list = hset.map { setVal => - Literal(setVal, valueDataType) - }.toSeq - // Change `InSet` to `In` expression - // We do Spark `InSet` optimization in native (DataFusion) side. - in(value, list, inputs, false, binding) - - case Not(In(value, list)) => - in(value, list, inputs, true, binding) - - case Not(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - if (childExpr.isDefined) { - val builder = ExprOuterClass.Not.newBuilder() - builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setNot(builder) - .build()) - } else { - None - } + case a @ Coalesce(_) => + val exprChildren = a.children.map(exprToProtoInternal(_, inputs)) + scalarExprToProto("coalesce", exprChildren: _*) + + // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for + // char types. Use rpad to achieve the behavior. 
+ // See https://github.com/apache/spark/pull/38151 + case StaticInvoke( + _: Class[CharVarcharCodegenUtils], + _: StringType, + "readSidePadding", + arguments, + _, + true, + false, + true) if arguments.size == 2 => + val argsExpr = Seq( + exprToProtoInternal(Cast(arguments(0), StringType), inputs), + exprToProtoInternal(arguments(1), inputs)) + + if (argsExpr.forall(_.isDefined)) { + val builder = ExprOuterClass.ScalarFunc.newBuilder() + builder.setFunc("rpad") + argsExpr.foreach(arg => builder.addArgs(arg.get)) + + Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) + } else { + None + } - case UnaryMinus(child, _) => - val childExpr = exprToProtoInternal(child, inputs, binding) - if (childExpr.isDefined) { - val builder = ExprOuterClass.Negative.newBuilder() - builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr + case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) => + val dataType = serializeDataType(expr.dataType) + if (dataType.isEmpty) { + return None + } + exprToProtoInternal(expr, inputs).map { child => + val builder = ExprOuterClass.NormalizeNaNAndZero .newBuilder() - .setNegative(builder) - .build()) - } else { - None - } + .setChild(child) + .setDatatype(dataType.get) + ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build() + } - case a @ Coalesce(_) => - val exprChildren = a.children.map(exprToProtoInternal(_, inputs, binding)) - scalarExprToProto("coalesce", exprChildren: _*) - - // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for char - // types. Use rpad to achieve the behavior. 
See https://github.com/apache/spark/pull/38151 - case StaticInvoke( - _: Class[CharVarcharCodegenUtils], - _: StringType, - "readSidePadding", - arguments, - _, - true, - false, - true) if arguments.size == 2 => - val argsExpr = Seq( - exprToProtoInternal(Cast(arguments(0), StringType), inputs, binding), - exprToProtoInternal(arguments(1), inputs, binding)) - - if (argsExpr.forall(_.isDefined)) { - val builder = ExprOuterClass.ScalarFunc.newBuilder() - builder.setFunc("rpad") - argsExpr.foreach(arg => builder.addArgs(arg.get)) - - Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) - } else { - None - } + case s @ execution.ScalarSubquery(_, _) => + val dataType = serializeDataType(s.dataType) + if (dataType.isEmpty) { + return None + } - case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) => - val dataType = serializeDataType(expr.dataType) - if (dataType.isEmpty) { - return None - } - exprToProtoInternal(expr, inputs, binding).map { child => - val builder = ExprOuterClass.NormalizeNaNAndZero + val builder = ExprOuterClass.Subquery .newBuilder() - .setChild(child) + .setId(s.exprId.id) .setDatatype(dataType.get) - ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build() - } - - case s @ execution.ScalarSubquery(_, _) => - val dataType = serializeDataType(s.dataType) - if (dataType.isEmpty) { - return None - } + Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build()) - val builder = ExprOuterClass.Subquery - .newBuilder() - .setId(s.exprId.id) - .setDatatype(dataType.get) - Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build()) + case UnscaledValue(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProtoWithReturnType("unscaled_value", LongType, childExpr) - case UnscaledValue(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProtoWithReturnType("unscaled_value", LongType, childExpr) + case MakeDecimal(child, precision, scale, true) => + 
val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProtoWithReturnType( + "make_decimal", + DecimalType(precision, scale), + childExpr) - case MakeDecimal(child, precision, scale, true) => - val childExpr = exprToProtoInternal(child, inputs, binding) - scalarExprToProtoWithReturnType("make_decimal", DecimalType(precision, scale), childExpr) - - case e => - emitWarning(s"unsupported Spark expression: '$e' of class '${e.getClass.getName}") - None + case e => + emitWarning(s"unsupported Spark expression: '$e' of class '${e.getClass.getName}") + None + } } - } - private def trim( - srcStr: Expression, - trimStr: Option[Expression], - inputs: Seq[Attribute], - trimType: String, - binding: Boolean): Option[Expr] = { - val srcExpr = exprToProtoInternal(Cast(srcStr, StringType), inputs, binding) - if (trimStr.isDefined) { - val trimExpr = exprToProtoInternal(Cast(trimStr.get, StringType), inputs, binding) - scalarExprToProto(trimType, srcExpr, trimExpr) - } else { - scalarExprToProto(trimType, srcExpr) + def trim( + srcStr: Expression, + trimStr: Option[Expression], + inputs: Seq[Attribute], + trimType: String): Option[Expr] = { + val srcExpr = exprToProtoInternal(Cast(srcStr, StringType), inputs) + if (trimStr.isDefined) { + val trimExpr = exprToProtoInternal(Cast(trimStr.get, StringType), inputs) + scalarExprToProto(trimType, srcExpr, trimExpr) + } else { + scalarExprToProto(trimType, srcExpr) + } } - } - private def in( - value: Expression, - list: Seq[Expression], - inputs: Seq[Attribute], - negate: Boolean, - binding: Boolean): Option[Expr] = { - val valueExpr = exprToProtoInternal(value, inputs, binding) - val listExprs = list.map(exprToProtoInternal(_, inputs, binding)) - if (valueExpr.isDefined && listExprs.forall(_.isDefined)) { - val builder = ExprOuterClass.In.newBuilder() - builder.setInValue(valueExpr.get) - builder.addAllLists(listExprs.map(_.get).asJava) - builder.setNegated(negate) - Some( - ExprOuterClass.Expr - .newBuilder() - 
.setIn(builder) - .build()) - } else { - None + def in( + value: Expression, + list: Seq[Expression], + inputs: Seq[Attribute], + negate: Boolean): Option[Expr] = { + val valueExpr = exprToProtoInternal(value, inputs) + val listExprs = list.map(exprToProtoInternal(_, inputs)) + if (valueExpr.isDefined && listExprs.forall(_.isDefined)) { + val builder = ExprOuterClass.In.newBuilder() + builder.setInValue(valueExpr.get) + builder.addAllLists(listExprs.map(_.get).asJava) + builder.setNegated(negate) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIn(builder) + .build()) + } else { + None + } } + + val conf = SQLConf.get + val newExpr = + DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled) + exprToProtoInternal(newExpr, input) } def scalarExprToProtoWithReturnType(