Skip to content

Commit

Permalink
[FLINK-36629][table-planner] Modify AdaptiveJoinProcessor to support adaptive skewed join optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
noorall authored and zhuzhurk committed Jan 14, 2025
1 parent c78097a commit ab1649e
Show file tree
Hide file tree
Showing 15 changed files with 689 additions and 231 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ private boolean isAdaptiveJoinEnabled(ProcessorContext context) {
!= OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE
&& !TableConfigUtils.isOperatorDisabled(
tableConfig, OperatorType.BroadcastHashJoin);
isAdaptiveJoinEnabled |=
tableConfig.get(
OptimizerConfigOptions
.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_STRATEGY)
!= OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.NONE;
JobManagerOptions.SchedulerType schedulerType =
context.getPlanner()
.getExecEnv()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.flink.table.planner.delegation
import org.apache.flink.api.common.RuntimeExecutionMode
import org.apache.flink.api.dag.Transformation
import org.apache.flink.configuration.ExecutionOptions
import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy
import org.apache.flink.runtime.scheduler.adaptivebatch.{BatchExecutionOptionsInternal, StreamGraphOptimizationStrategy}
import org.apache.flink.table.api._
import org.apache.flink.table.api.config.OptimizerConfigOptions
import org.apache.flink.table.catalog.{CatalogManager, FunctionCatalog}
Expand All @@ -35,7 +35,7 @@ import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodePlanDumper
import org.apache.flink.table.planner.plan.optimize.{BatchCommonSubGraphBasedOptimizer, Optimizer}
import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil
import org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment
import org.apache.flink.table.runtime.strategy.{AdaptiveBroadcastJoinOptimizationStrategy, PostProcessAdaptiveJoinStrategy}
import org.apache.flink.table.runtime.strategy.{AdaptiveBroadcastJoinOptimizationStrategy, AdaptiveSkewedJoinOptimizationStrategy, PostProcessAdaptiveJoinStrategy}

import org.apache.calcite.plan.{ConventionTraitDef, RelTrait, RelTraitDef}
import org.apache.calcite.rel.RelCollationTraitDef
Expand Down Expand Up @@ -106,11 +106,27 @@ class BatchPlanner(
super.afterTranslation()
val configuration = getTableConfig
val optimizationStrategies = new util.ArrayList[String]()
if (
configuration.get(OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY)
!= OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE
) {
val isAdaptiveBroadcastJoinEnabled = configuration.get(
OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY) != OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE
val isAdaptiveSkewedJoinEnabled = configuration.get(
OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_STRATEGY) != OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.NONE
if (isAdaptiveBroadcastJoinEnabled) {
optimizationStrategies.add(classOf[AdaptiveBroadcastJoinOptimizationStrategy].getName)
}
if (isAdaptiveSkewedJoinEnabled) {
optimizationStrategies.add(classOf[AdaptiveSkewedJoinOptimizationStrategy].getName)
configuration.set(
BatchExecutionOptionsInternal.ADAPTIVE_SKEWED_OPTIMIZATION_SKEWED_FACTOR,
configuration.get(
OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_SKEWED_FACTOR)
)
configuration.set(
BatchExecutionOptionsInternal.ADAPTIVE_SKEWED_OPTIMIZATION_SKEWED_THRESHOLD,
configuration.get(
OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_SKEWED_THRESHOLD)
)
}
if (isAdaptiveBroadcastJoinEnabled || isAdaptiveSkewedJoinEnabled) {
optimizationStrategies.add(classOf[PostProcessAdaptiveJoinStrategy].getName)
}
configuration.set(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.runtime.utils;

import org.apache.flink.configuration.BatchExecutionOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.test.junit5.MiniClusterExtension;

import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Base class for adaptive batch tests. Registers a shared {@link MiniClusterExtension} via
 * {@link RegisterExtension} so subclasses run against a single-TaskManager mini cluster.
 */
public class AdaptiveBatchAbstractTestBase {
    protected static final int DEFAULT_PARALLELISM = 3;

    @RegisterExtension
    private static final MiniClusterExtension MINI_CLUSTER_EXTENSION =
            new MiniClusterExtension(
                    new MiniClusterResourceConfiguration.Builder()
                            .setConfiguration(createClusterConfiguration())
                            .setNumberTaskManagers(1)
                            .setNumberSlotsPerTaskManager(DEFAULT_PARALLELISM)
                            .build());

    // Builds the cluster configuration: 100m of managed memory and a small (100k)
    // average data volume per task for the adaptive auto-parallelism feature.
    private static Configuration createClusterConfiguration() {
        final Configuration clusterConfig = new Configuration();
        clusterConfig.set(TaskManagerOptions.MANAGED_MEMORY_SIZE, MemorySize.parse("100m"));
        clusterConfig.set(
                BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_AVG_DATA_VOLUME_PER_TASK,
                MemorySize.parse("100k"));
        return clusterConfig;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,29 @@ LogicalProject(a1=[$0], b1=[$1], c1=[$2], d1=[$3], a2=[$4], b2=[$5], c2=[$6], d2
HashJoin(joinType=[InnerJoin], where=[(a1 = a2)], select=[a1, b1, c1, d1, a2, b2, c2, d2], build=[left])
:- Exchange(distribution=[hash[a1]])
: +- TableSourceScan(table=[[default_catalog, default_database, T1]], fields=[a1, b1, c1, d1])
+- Exchange(distribution=[hash[a2]])
+- TableSourceScan(table=[[default_catalog, default_database, T2]], fields=[a2, b2, c2, d2])
]]>
</Resource>
</TestCase>
<TestCase name="testAdaptiveJoinWithAdaptiveSkewedOptimizationEnabled">
<Resource name="sql">
<![CDATA[SELECT * FROM T1, T2 WHERE a1 = a2]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(a1=[$0], b1=[$1], c1=[$2], d1=[$3], a2=[$4], b2=[$5], c2=[$6], d2=[$7])
+- LogicalFilter(condition=[=($0, $4)])
+- LogicalJoin(condition=[true], joinType=[inner])
:- LogicalTableScan(table=[[default_catalog, default_database, T1]])
+- LogicalTableScan(table=[[default_catalog, default_database, T2]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
AdaptiveJoin(originalJoin=[ShuffleHashJoin], joinType=[InnerJoin], where=[(a1 = a2)], select=[a1, b1, c1, d1, a2, b2, c2, d2], build=[left])
:- Exchange(distribution=[hash[a1]])
: +- TableSourceScan(table=[[default_catalog, default_database, T1]], fields=[a1, b1, c1, d1])
+- Exchange(distribution=[hash[a2]])
+- TableSourceScan(table=[[default_catalog, default_database, T2]], fields=[a2, b2, c2, d2])
]]>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,16 @@ class AdaptiveJoinTest extends TableTestBase {
val sql = "SELECT * FROM T1, T2 WHERE a1 = a2"
util.verifyExecPlan(sql)
}

@Test
def testAdaptiveJoinWithAdaptiveSkewedOptimizationEnabled(): Unit = {
  // Turn the adaptive broadcast-join strategy off and the adaptive skewed-join
  // strategy to AUTO, so the plan reflects only the skewed-join optimization.
  util.tableConfig.set(
    OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY,
    OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE)
  util.tableConfig.set(
    OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_STRATEGY,
    OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.AUTO)
  util.verifyExecPlan("SELECT * FROM T1, T2 WHERE a1 = a2")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,136 +17,48 @@
*/
package org.apache.flink.table.planner.runtime.batch.sql.adaptive

import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRING_TYPE_INFO}
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions}
import org.apache.flink.table.planner.runtime.utils.BatchTestBase
import org.apache.flink.types.Row
import org.apache.flink.table.api.config.OptimizerConfigOptions

import org.junit.jupiter.api.{BeforeEach, Test}

import scala.collection.JavaConversions._
import scala.util.Random
import org.junit.jupiter.api.BeforeEach

/** IT cases for adaptive broadcast join. */
class AdaptiveBroadcastJoinITCase extends BatchTestBase {

class AdaptiveBroadcastJoinITCase extends AdaptiveJoinITCase {
@BeforeEach
override def before(): Unit = {
super.before()

tEnv.getConfig
.set(
OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_STRATEGY,
OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.NONE)

registerCollection(
"T",
AdaptiveBroadcastJoinITCase.dataT,
AdaptiveBroadcastJoinITCase.rowType,
AdaptiveJoinITCase.generateRandomData,
AdaptiveJoinITCase.rowType,
"a, b, c, d",
AdaptiveBroadcastJoinITCase.nullables)
AdaptiveJoinITCase.nullables)
registerCollection(
"T1",
AdaptiveBroadcastJoinITCase.dataT1,
AdaptiveBroadcastJoinITCase.rowType,
AdaptiveJoinITCase.generateRandomData,
AdaptiveJoinITCase.rowType,
"a1, b1, c1, d1",
AdaptiveBroadcastJoinITCase.nullables)
AdaptiveJoinITCase.nullables)
registerCollection(
"T2",
AdaptiveBroadcastJoinITCase.dataT2,
AdaptiveBroadcastJoinITCase.rowType,
AdaptiveJoinITCase.generateRandomData,
AdaptiveJoinITCase.rowType,
"a2, b2, c2, d2",
AdaptiveBroadcastJoinITCase.nullables)
AdaptiveJoinITCase.nullables)
registerCollection(
"T3",
AdaptiveBroadcastJoinITCase.dataT3,
AdaptiveBroadcastJoinITCase.rowType,
AdaptiveJoinITCase.generateRandomData,
AdaptiveJoinITCase.rowType,
"a3, b3, c3, d3",
AdaptiveBroadcastJoinITCase.nullables)
}

@Test
def testWithShuffleHashJoin(): Unit = {
tEnv.getConfig
.set(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin")
testSimpleJoin()
AdaptiveJoinITCase.nullables)
}

@Test
def testWithShuffleMergeJoin(): Unit = {
  // Force sort-merge joins by disabling the nested-loop and shuffle-hash operators.
  val config = tEnv.getConfig
  config.set(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,ShuffleHashJoin")
  testSimpleJoin()
}

@Test
def testWithBroadcastJoin(): Unit = {
  // Disable the non-hash join operators and raise the broadcast-join threshold
  // to its maximum value before running the simple-join suite.
  val config = tEnv.getConfig
  config.set(
    ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS,
    "SortMergeJoin,NestedLoopJoin")
  config.set(
    OptimizerConfigOptions.TABLE_OPTIMIZER_BROADCAST_JOIN_THRESHOLD,
    Long.box(Long.MaxValue))
  testSimpleJoin()
}

@Test
def testShuffleJoinWithForwardForConsecutiveHash(): Unit = {
  // Disable multiple-input fusion so consecutive hash-distributed operators
  // stay as separate nodes in the plan.
  tEnv.getConfig.set(
    OptimizerConfigOptions.TABLE_OPTIMIZER_MULTIPLE_INPUT_ENABLED,
    Boolean.box(false))
  checkResult(
    """
      |WITH
      | r AS (SELECT * FROM T1, T2, T3 WHERE a1 = a2 and a1 = a3)
      |SELECT sum(b1) FROM r group by a1
      |""".stripMargin)
}

@Test
def testJoinWithUnionInput(): Unit = {
  // Left join whose probe side is a UNION ALL of two table scans.
  val query =
    """
      |SELECT * FROM
      | (SELECT a FROM (SELECT a1 as a FROM T1) UNION ALL (SELECT a2 as a FROM T2)) Y
      | LEFT JOIN T ON T.a = Y.a
      |""".stripMargin
  checkResult(query)
}

@Test
def testJoinWithMultipleInput(): Unit = {
  // Join whose two inputs are themselves join sub-queries.
  val query =
    """
      |SELECT * FROM
      | (SELECT a FROM T1 JOIN T ON a = a1) t1
      | INNER JOIN
      | (SELECT d2 FROM T JOIN T2 ON d2 = a) t2
      |ON t1.a = t2.d2
      |""".stripMargin
  checkResult(query)
}

/** Runs one representative query per join flavour through [[checkResult]]. */
def testSimpleJoin(): Unit = {
  val queries = Seq(
    // inner join
    "SELECT * FROM T1, T2 WHERE a1 = a2",
    // left join
    "SELECT * FROM T1 LEFT JOIN T2 on a1 = a2",
    // right join
    "SELECT * FROM T1 RIGHT JOIN T2 on a1 = a2",
    // semi join
    "SELECT * FROM T1 WHERE a1 IN (SELECT a2 FROM T2)",
    // anti join
    "SELECT * FROM T1 WHERE a1 NOT IN (SELECT a2 FROM T2 where a2 = a1)"
  )
  queries.foreach(checkResult)
}

def checkResult(sql: String): Unit = {
override def checkResult(sql: String): Unit = {
tEnv.getConfig
.set(
OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY,
Expand All @@ -164,26 +76,3 @@ class AdaptiveBroadcastJoinITCase extends BatchTestBase {
checkResult(sql, expected)
}
}

object AdaptiveBroadcastJoinITCase {

  /**
   * Builds a randomly sized batch (0 to 29 rows) of four-column rows:
   * (long, long, string, long), where the string is one of three fixed words.
   */
  def generateRandomData(): Seq[Row] = {
    val rows = new java.util.ArrayList[Row]()
    val rowCount = Random.nextInt(30)
    lazy val words = Seq("adaptive", "join", "itcase")
    var i = 0
    while (i < rowCount) {
      rows.add(
        BatchTestBase.row(i.toLong, Random.nextLong(), words(Random.nextInt(3)), Random.nextLong()))
      i += 1
    }
    rows
  }

  // Shared row schema for all test tables: (long, long, string, long).
  lazy val rowType =
    new RowTypeInfo(LONG_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO, LONG_TYPE_INFO)
  // Every column is declared nullable.
  lazy val nullables: Array[Boolean] = Array(true, true, true, true)

  // Independently generated datasets backing tables T, T1, T2 and T3.
  lazy val dataT: Seq[Row] = generateRandomData()
  lazy val dataT1: Seq[Row] = generateRandomData()
  lazy val dataT2: Seq[Row] = generateRandomData()
  lazy val dataT3: Seq[Row] = generateRandomData()
}
Loading

0 comments on commit ab1649e

Please sign in to comment.