Skip to content

Commit

Permalink
Fix testcases
Browse files Browse the repository at this point in the history
  • Loading branch information
Surbhi-Vijay committed Feb 13, 2024
1 parent f2a60a4 commit 7583869
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
right: SparkPlan,
buildSide: BuildSide,
joinType: JoinType,
condition: Option[Expression]): BroadcastNestedLoopJoinTransformer =
GlutenBroadcastNestedLoopJoinTransformer(
condition: Option[Expression]): BroadcastNestedLoopJoinExecTransformer =
GlutenBroadcastNestedLoopJoinExecTransformer(
left,
right,
buildSide,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.BuildSideRelation

case class GlutenBroadcastNestedLoopJoinTransformer(
case class GlutenBroadcastNestedLoopJoinExecTransformer(
left: SparkPlan,
right: SparkPlan,
buildSide: BuildSide,
joinType: JoinType,
condition: Option[Expression])
extends BroadcastNestedLoopJoinTransformer(
extends BroadcastNestedLoopJoinExecTransformer(
left,
right,
buildSide,
Expand All @@ -43,7 +43,7 @@ case class GlutenBroadcastNestedLoopJoinTransformer(

override protected def withNewChildrenInternal(
newLeft: SparkPlan,
newRight: SparkPlan): GlutenBroadcastNestedLoopJoinTransformer =
newRight: SparkPlan): GlutenBroadcastNestedLoopJoinExecTransformer =
copy(left = newLeft, right = newRight)

}
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
|select * from t1 cross join t2;
|""".stripMargin
) {
checkOperatorMatch[BroadcastNestedLoopJoinTransformer]
checkOperatorMatch[BroadcastNestedLoopJoinExecTransformer]
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ trait SparkPlanExecApi {
right: SparkPlan,
buildSide: BuildSide,
joinType: JoinType,
condition: Option[Expression]): BroadcastNestedLoopJoinTransformer
condition: Option[Expression]): BroadcastNestedLoopJoinExecTransformer

def genAliasTransformer(
substraitExprName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

import io.substrait.proto.CrossRel

abstract class BroadcastNestedLoopJoinTransformer(
abstract class BroadcastNestedLoopJoinExecTransformer(
left: SparkPlan,
right: SparkPlan,
buildSide: BuildSide,
Expand All @@ -44,6 +44,8 @@ abstract class BroadcastNestedLoopJoinTransformer(
extends BaseJoinExec
with TransformSupport {

def joinBuildSide: BuildSide = buildSide

override def leftKeys: Seq[Expression] = Nil
override def rightKeys: Seq[Expression] = Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -133,6 +133,13 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
}

override protected def doValidateInternal(): ValidationResult = {
// CH backend does not support IdentityBroadcastMode used in BNLJ
if (
mode == IdentityBroadcastMode && !BackendsApiManager.getSettings
.supportBroadcastNestedLoopJoinExec()
) {
return ValidationResult.notOk("This backend does not support IdentityBroadcastMode and BNLJ")
}
BackendsApiManager.getValidatorApiInstance
.doSchemaValidate(schema)
.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.joins

import io.glutenproject.GlutenConfig
import io.glutenproject.execution.{BroadcastHashJoinExecTransformer, ColumnarToRowExecBase, WholeStageTransformer}
import io.glutenproject.execution.{BroadcastHashJoinExecTransformer, BroadcastNestedLoopJoinExecTransformer, ColumnarToRowExecBase, WholeStageTransformer}
import io.glutenproject.utils.{BackendTestUtils, SystemParameters}

import org.apache.spark.sql.{GlutenTestsCommonTrait, SparkSession}
Expand All @@ -43,10 +43,16 @@ class GlutenBroadcastJoinSuite extends BroadcastJoinSuite with GlutenTestsCommon

private val EnsureRequirements = new EnsureRequirements()

private val isVeloxBackend = BackendTestUtils.isVeloxBackendLoaded()

// BroadcastHashJoinExecTransformer is not case class, can't call toString method,
// let's put constant string here.
private val bh = "BroadcastHashJoinExecTransformer"
private val bl = BroadcastNestedLoopJoinExec.toString
private val bl = if (isVeloxBackend) {
"BroadcastNestedLoopJoinExecTransformer"
} else {
BroadcastNestedLoopJoinExec.toString
}

override def beforeAll(): Unit = {
super.beforeAll()
Expand Down Expand Up @@ -249,10 +255,15 @@ class GlutenBroadcastJoinSuite extends BroadcastJoinSuite with GlutenTestsCommon
c2r.child match {
case w: WholeStageTransformer =>
val join = w.child match {
case b: BroadcastHashJoinExecTransformer => b
case b: BroadcastHashJoinExecTransformer =>
b
assert(join.getClass.getSimpleName.endsWith(joinMethod))
assert(join.joinBuildSide == buildSide)
case b: BroadcastNestedLoopJoinExecTransformer =>
b
assert(join.getClass.getSimpleName.endsWith(joinMethod))
assert(join.joinBuildSide == buildSide)
}
assert(join.getClass.getSimpleName.endsWith(joinMethod))
assert(join.joinBuildSide == buildSide)
}
case b: BroadcastNestedLoopJoinExec =>
assert(b.getClass.getSimpleName === joinMethod)
Expand Down

0 comments on commit 7583869

Please sign in to comment.