diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5091b933b3842..32a60a3c8ad75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -265,11 +265,7 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { false } else { val allCondBooleans = predicates.forall(_.dataType == BooleanType) - val valueTypes = branches.sliding(2, 2).map { - case Seq(_, value) => value.dataType - case Seq(elseVal) => elseVal.dataType - }.toSeq - val dataTypesEqual = valueTypes.distinct.size <= 1 + val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1 allCondBooleans && dataTypesEqual } }