Skip to content

Commit

Permalink
Prioritize high confidence stats during broadcast joins
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavmuk04 committed Jun 29, 2024
1 parent c3e4c8e commit 0d702bb
Show file tree
Hide file tree
Showing 7 changed files with 377 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ public final class SystemSessionProperties
public static final String OPTIMIZE_METADATA_QUERIES_CALL_THRESHOLD = "optimize_metadata_queries_call_threshold";
public static final String FAST_INEQUALITY_JOINS = "fast_inequality_joins";
public static final String QUERY_PRIORITY = "query_priority";
public static final String CONFIDENCE_BASED_BROADCAST_ENABLED = "confidence_based_broadcast_enabled";
public static final String SPILL_ENABLED = "spill_enabled";
public static final String JOIN_SPILL_ENABLED = "join_spill_enabled";
public static final String AGGREGATION_SPILL_ENABLED = "aggregation_spill_enabled";
Expand Down Expand Up @@ -423,6 +424,11 @@ public SystemSessionProperties(
"Consider source table size when determining join distribution type when CBO fails",
featuresConfig.isSizeBasedJoinDistributionTypeEnabled(),
false),
booleanProperty(
CONFIDENCE_BASED_BROADCAST_ENABLED,
"Enable confidence based broadcasting when enabled",
false,
false),
booleanProperty(
DISTRIBUTED_INDEX_JOIN,
"Distribute index joins on join keys instead of executing inline",
Expand Down Expand Up @@ -2019,6 +2025,11 @@ public static boolean isDistributedIndexJoinEnabled(Session session)
return session.getSystemProperty(DISTRIBUTED_INDEX_JOIN, Boolean.class);
}

public static boolean confidenceBasedBroadcastEnabled(Session session)
{
return session.getSystemProperty(CONFIDENCE_BASED_BROADCAST_ENABLED, Boolean.class);
}

public static int getHashPartitionCount(Session session)
{
return session.getSystemProperty(HASH_PARTITION_COUNT, Integer.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative;

import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.JoinNode;

import java.util.Optional;

import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED;

public class ConfidenceBasedBroadcastUtil
{
private ConfidenceBasedBroadcastUtil() {};

public static Optional<JoinNode> confidenceBasedBroadcast(JoinNode joinNode, Rule.Context context)
{
boolean rightIsAtleastHighConfidence = isAtleastHighConfidence(joinNode.getRight(), context);
boolean leftIsAtleastHighConfidence = isAtleastHighConfidence(joinNode.getLeft(), context);

if (rightIsAtleastHighConfidence && !leftIsAtleastHighConfidence) {
return Optional.of(joinNode.withDistributionType(REPLICATED));
}
else if (leftIsAtleastHighConfidence && !rightIsAtleastHighConfidence) {
return Optional.of(joinNode.flipChildren().withDistributionType(REPLICATED));
}

return Optional.empty();
}

private static boolean isAtleastHighConfidence(PlanNode planNode, Rule.Context context)
{
StatsProvider statsProvider = context.getStatsProvider();
PlanNodeStatsEstimate stats = statsProvider.getStats(planNode);

return stats.confidenceLevel().getConfidenceOrdinal() > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.confidenceBasedBroadcastEnabled;
import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType;
import static com.facebook.presto.SystemSessionProperties.getJoinMaxBroadcastTableSize;
import static com.facebook.presto.SystemSessionProperties.isSizeBasedJoinDistributionTypeEnabled;
Expand All @@ -49,6 +51,7 @@
import static com.facebook.presto.spi.plan.JoinDistributionType.PARTITIONED;
import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.AUTOMATIC;
import static com.facebook.presto.sql.planner.iterative.ConfidenceBasedBroadcastUtil.confidenceBasedBroadcast;
import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isBelowBroadcastLimit;
import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isSmallerThanThreshold;
import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
Expand Down Expand Up @@ -124,6 +127,13 @@ private PlanNode getCostBasedJoin(JoinNode joinNode, Context context)
addJoinsWithDifferentDistributions(joinNode, possibleJoinNodes, context);
addJoinsWithDifferentDistributions(joinNode.flipChildren(), possibleJoinNodes, context);

if (isBelowMaxBroadcastSize(joinNode, context) && isBelowMaxBroadcastSize(joinNode.flipChildren(), context) && !mustPartition(joinNode) && confidenceBasedBroadcastEnabled(context.getSession())) {
Optional<JoinNode> result = confidenceBasedBroadcast(joinNode, context);
if (result.isPresent()) {
return result.get();
}
}

if (possibleJoinNodes.stream().anyMatch(result -> result.getCost().hasUnknownComponents()) || possibleJoinNodes.isEmpty()) {
// TODO: currently this session parameter is added so as to roll out the plan change gradually, after proved to be a better choice, make it default and get rid of the session parameter here.
if (isUseBroadcastJoinWhenBuildSizeSmallProbeSizeUnknownEnabled(context.getSession()) && possibleJoinNodes.stream().anyMatch(result -> ((JoinNode) result.getPlanNode()).getDistributionType().get().equals(REPLICATED))) {
Expand Down Expand Up @@ -236,7 +246,7 @@ private JoinNode getSyntacticOrderJoin(JoinNode joinNode, Context context, JoinD
return joinNode.withDistributionType(REPLICATED);
}

private boolean mustPartition(JoinNode joinNode)
private static boolean mustPartition(JoinNode joinNode)
{
return joinNode.getType().mustPartition();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import static com.facebook.presto.SystemSessionProperties.confidenceBasedBroadcastEnabled;
import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType;
import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy;
import static com.facebook.presto.SystemSessionProperties.getMaxReorderedJoins;
Expand All @@ -81,6 +82,7 @@
import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique;
import static com.facebook.presto.sql.planner.iterative.ConfidenceBasedBroadcastUtil.confidenceBasedBroadcast;
import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.isBelowMaxBroadcastSize;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT;
Expand Down Expand Up @@ -537,6 +539,14 @@ private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode)
if (isAtMostScalar(joinNode.getLeft(), lookup)) {
return createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(REPLICATED));
}

if (isBelowMaxBroadcastSize(joinNode, context) && isBelowMaxBroadcastSize(joinNode.flipChildren(), context) && confidenceBasedBroadcastEnabled(context.getSession())) {
Optional<JoinNode> result = confidenceBasedBroadcast(joinNode, context);
if (result.isPresent()) {
return createJoinEnumerationResult(result.get());
}
}

List<JoinEnumerationResult> possibleJoinNodes = getPossibleJoinNodes(joinNode, getJoinDistributionType(session));
verify(!possibleJoinNodes.isEmpty(), "possibleJoinNodes is empty");
if (possibleJoinNodes.stream().anyMatch(UNKNOWN_COST_RESULT::equals)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.CONFIDENCE_BASED_BROADCAST_ENABLED;
import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static com.facebook.presto.SystemSessionProperties.JOIN_MAX_BROADCAST_TABLE_SIZE;
import static com.facebook.presto.SystemSessionProperties.USE_BROADCAST_WHEN_BUILDSIZE_SMALL_PROBESIDE_UNKNOWN;
Expand All @@ -54,6 +55,8 @@
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.spi.plan.JoinType.RIGHT;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.HIGH;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.enforceSingleRow;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
Expand Down Expand Up @@ -246,6 +249,146 @@ public void testRetainDistributionType()
.doesNotFire();
}

@Test
public void testHighConfidenceLeftAndLowConfidenceRight()
{
int aRows = 50;
int bRows = 30;
assertDetermineJoinDistributionType()
.setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
.setSystemProperty(CONFIDENCE_BASED_BROADCAST_ENABLED, "TRUE")
.overrideStats("valuesA", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(aRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100)))
.build())
.overrideStats("valuesB", PlanNodeStatsEstimate.builder()
.setConfidence(LOW)
.setOutputRowCount(bRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100)))
.build())
.on(p ->
p.join(
INNER,
p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT), p.variable("A2", BIGINT)),
p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)),
ImmutableList.of(new EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))),
ImmutableList.of(p.variable("A1", BIGINT), p.variable("A2", BIGINT), p.variable("B1", BIGINT)),
Optional.empty()))
.matches(join(
INNER,
ImmutableList.of(equiJoinClause("B1", "A1")),
Optional.empty(),
Optional.of(REPLICATED),
values(ImmutableMap.of("B1", 0)),
values(ImmutableMap.of("A1", 0, "A2", 1))));
}

@Test
public void testLowConfidenceLeftAndHighConfidenceRight()
{
int aRows = 50;
int bRows = 90;
assertDetermineJoinDistributionType()
.setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
.setSystemProperty(CONFIDENCE_BASED_BROADCAST_ENABLED, "TRUE")
.overrideStats("valuesA", PlanNodeStatsEstimate.builder()
.setConfidence(LOW)
.setOutputRowCount(aRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100)))
.build())
.overrideStats("valuesB", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(bRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100)))
.build())
.on(p ->
p.join(
INNER,
p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT), p.variable("A2", BIGINT)),
p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)),
ImmutableList.of(new EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))),
ImmutableList.of(p.variable("A1", BIGINT), p.variable("A2", BIGINT), p.variable("B1", BIGINT)),
Optional.empty()))
.matches(join(
INNER,
ImmutableList.of(equiJoinClause("A1", "B1")),
Optional.empty(),
Optional.of(REPLICATED),
values(ImmutableMap.of("A1", 0, "A2", 1)),
values(ImmutableMap.of("B1", 0))));
}

@Test
public void testLeftAndRightHighConfidenceRightSmaller()
{
int aRows = 90;
int bRows = 50;
assertDetermineJoinDistributionType()
.setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
.setSystemProperty(CONFIDENCE_BASED_BROADCAST_ENABLED, "TRUE")
.overrideStats("valuesA", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(aRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100)))
.build())
.overrideStats("valuesB", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(bRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640, 100)))
.build())
.on(p ->
p.join(
INNER,
p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT), p.variable("A2", BIGINT)),
p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)),
ImmutableList.of(new EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))),
ImmutableList.of(p.variable("A1", BIGINT), p.variable("A2", BIGINT), p.variable("B1", BIGINT)),
Optional.empty()))
.matches(join(
INNER,
ImmutableList.of(equiJoinClause("A1", "B1")),
Optional.empty(),
Optional.of(REPLICATED),
values(ImmutableMap.of("A1", 0, "A2", 1)),
values(ImmutableMap.of("B1", 0))));
}

@Test
public void testLeftAndRightHighConfidenceLeftSmaller()
{
int aRows = 50;
int bRows = 90;
assertDetermineJoinDistributionType()
.setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
.setSystemProperty(CONFIDENCE_BASED_BROADCAST_ENABLED, "TRUE")
.overrideStats("valuesA", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(aRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100)))
.build())
.overrideStats("valuesB", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(bRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100)))
.build())
.on(p ->
p.join(
INNER,
p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT), p.variable("A2", BIGINT)),
p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)),
ImmutableList.of(new EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))),
ImmutableList.of(p.variable("A1", BIGINT), p.variable("A2", BIGINT), p.variable("B1", BIGINT)),
Optional.empty()))
.matches(join(
INNER,
ImmutableList.of(equiJoinClause("B1", "A1")),
Optional.empty(),
Optional.of(REPLICATED),
values(ImmutableMap.of("B1", 0)),
values(ImmutableMap.of("A1", 0, "A2", 1))));
}

@Test
public void testFlipAndReplicateWhenOneTableMuchSmaller()
{
Expand Down
Loading

0 comments on commit 0d702bb

Please sign in to comment.