Skip to content

Commit

Permalink
Prioritize HBO stats during broadcast joins
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavmuk04 committed Jun 24, 2024
1 parent 1e80972 commit 2d8068c
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateJoinCostWithoutOutput;
import static com.facebook.presto.spi.plan.JoinDistributionType.PARTITIONED;
import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.HIGH;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.AUTOMATIC;
import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isBelowBroadcastLimit;
import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isSmallerThanThreshold;
Expand Down Expand Up @@ -124,6 +125,21 @@ private PlanNode getCostBasedJoin(JoinNode joinNode, Context context)
addJoinsWithDifferentDistributions(joinNode, possibleJoinNodes, context);
addJoinsWithDifferentDistributions(joinNode.flipChildren(), possibleJoinNodes, context);

if (isBelowMaxBroadcastSize(joinNode, context) && isBelowMaxBroadcastSize(joinNode.flipChildren(), context)) {
boolean rightIsHBO = isHBO(joinNode.getRight(), context);
boolean leftIsHBO = isHBO(joinNode.getLeft(), context);

if (rightIsHBO && leftIsHBO) {
return chooseSmallerSideForBroadcast(joinNode, context);
}
else if (rightIsHBO) {
return joinNode.withDistributionType(REPLICATED);
}
else if (leftIsHBO) {
return joinNode.flipChildren().withDistributionType(REPLICATED);
}
}

if (possibleJoinNodes.stream().anyMatch(result -> result.getCost().hasUnknownComponents()) || possibleJoinNodes.isEmpty()) {
// TODO: currently this session parameter is added so as to roll out the plan change gradually, after proved to be a better choice, make it default and get rid of the session parameter here.
if (isUseBroadcastJoinWhenBuildSizeSmallProbeSizeUnknownEnabled(context.getSession()) && possibleJoinNodes.stream().anyMatch(result -> ((JoinNode) result.getPlanNode()).getDistributionType().get().equals(REPLICATED))) {
Expand Down Expand Up @@ -236,6 +252,25 @@ private JoinNode getSyntacticOrderJoin(JoinNode joinNode, Context context, JoinD
return joinNode.withDistributionType(REPLICATED);
}

private boolean isHBO(PlanNode planNode, Context context)
{
StatsProvider statsProvider = context.getStatsProvider();
PlanNodeStatsEstimate stats = statsProvider.getStats(planNode);
return stats.confidenceLevel() == HIGH;
}

private JoinNode chooseSmallerSideForBroadcast(JoinNode joinNode, Context context)
{
double rightSize = getSourceTablesSizeInBytes(joinNode.getRight(), context);
double leftSize = getSourceTablesSizeInBytes(joinNode.getLeft(), context);
if (rightSize <= leftSize) {
return joinNode.withDistributionType(REPLICATED);
}
else {
return joinNode.flipChildren().withDistributionType(REPLICATED);
}
}

private boolean mustPartition(JoinNode joinNode)
{
return joinNode.getType().mustPartition();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
Expand Down Expand Up @@ -77,10 +79,12 @@
import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.HIGH;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.AUTOMATIC;
import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique;
import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.getSourceTablesSizeInBytes;
import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.isBelowMaxBroadcastSize;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT;
Expand Down Expand Up @@ -537,6 +541,22 @@ private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode)
if (isAtMostScalar(joinNode.getLeft(), lookup)) {
return createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(REPLICATED));
}

if (isBelowMaxBroadcastSize(joinNode, context) && isBelowMaxBroadcastSize(joinNode.flipChildren(), context)) {
boolean rightIsHBO = isHBO(joinNode.getRight(), context);
boolean leftIsHBO = isHBO(joinNode.getLeft(), context);

if (rightIsHBO && leftIsHBO) {
return chooseSmallerSideForBroadcast(joinNode, context);
}
else if (rightIsHBO) {
return createJoinEnumerationResult(joinNode.withDistributionType(REPLICATED));
}
else if (leftIsHBO) {
return createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(REPLICATED));
}
}

List<JoinEnumerationResult> possibleJoinNodes = getPossibleJoinNodes(joinNode, getJoinDistributionType(session));
verify(!possibleJoinNodes.isEmpty(), "possibleJoinNodes is empty");
if (possibleJoinNodes.stream().anyMatch(UNKNOWN_COST_RESULT::equals)) {
Expand All @@ -545,6 +565,25 @@ private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode)
return resultComparator.min(possibleJoinNodes);
}

private boolean isHBO(PlanNode planNode, Context context)
{
StatsProvider statsProvider = context.getStatsProvider();
PlanNodeStatsEstimate stats = statsProvider.getStats(planNode);
return stats.confidenceLevel() == HIGH;
}

private JoinEnumerationResult chooseSmallerSideForBroadcast(JoinNode joinNode, Context context)
{
double rightSize = getSourceTablesSizeInBytes(joinNode.getRight(), context);
double leftSize = getSourceTablesSizeInBytes(joinNode.getLeft(), context);
if (rightSize <= leftSize) {
return createJoinEnumerationResult(joinNode.withDistributionType(REPLICATED));
}
else {
return createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(REPLICATED));
}
}

private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, JoinDistributionType distributionType)
{
checkArgument(joinNode.getType() == INNER, "unexpected join node type: %s", joinNode.getType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.spi.plan.JoinType.RIGHT;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.HIGH;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.enforceSingleRow;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
Expand Down Expand Up @@ -246,6 +247,140 @@ public void testRetainDistributionType()
.doesNotFire();
}

@Test
public void testHBOLeftAndRightNot()
{
int aRows = 100;
int bRows = 100;
assertDetermineJoinDistributionType()
.setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
.overrideStats("valuesA", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(aRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100)))
.build())
.overrideStats("valuesB", PlanNodeStatsEstimate.builder()
.setOutputRowCount(bRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100)))
.build())
.on(p ->
p.join(
INNER,
p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)),
p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)),
ImmutableList.of(new EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))),
ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)),
Optional.empty()))
.matches(join(
INNER,
ImmutableList.of(equiJoinClause("A1", "B1")),
Optional.empty(),
Optional.of(REPLICATED),
values(ImmutableMap.of("A1", 0)),
values(ImmutableMap.of("B1", 0))));
}

@Test
public void testLeftNotAndRightHBO()
{
int aRows = 100;
int bRows = 100;
assertDetermineJoinDistributionType()
.setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
.overrideStats("valuesA", PlanNodeStatsEstimate.builder()
.setOutputRowCount(aRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100)))
.build())
.overrideStats("valuesB", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(bRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100)))
.build())
.on(p ->
p.join(
INNER,
p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)),
p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)),
ImmutableList.of(new EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))),
ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)),
Optional.empty()))
.matches(join(
INNER,
ImmutableList.of(equiJoinClause("B1", "A1")),
Optional.empty(),
Optional.of(REPLICATED),
values(ImmutableMap.of("B1", 0)),
values(ImmutableMap.of("A1", 0))));
}

@Test
public void testBothHBOLeftSmaller()
{
int aRows = 10;
int bRows = 100;
assertDetermineJoinDistributionType()
.setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
.overrideStats("valuesA", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(aRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100)))
.build())
.overrideStats("valuesB", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(bRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100)))
.build())
.on(p ->
p.join(
INNER,
p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)),
p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)),
ImmutableList.of(new EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))),
ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)),
Optional.empty()))
.matches(join(
INNER,
ImmutableList.of(equiJoinClause("B1", "A1")),
Optional.empty(),
Optional.of(REPLICATED),
values(ImmutableMap.of("B1", 0)),
values(ImmutableMap.of("A1", 0))));
}

@Test
public void testBothHBORightSmaller()
{
int aRows = 100;
int bRows = 10;
assertDetermineJoinDistributionType()
.setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
.overrideStats("valuesA", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(aRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100)))
.build())
.overrideStats("valuesB", PlanNodeStatsEstimate.builder()
.setConfidence(HIGH)
.setOutputRowCount(bRows)
.addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100)))
.build())
.on(p ->
p.join(
INNER,
p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)),
p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)),
ImmutableList.of(new EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))),
ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)),
Optional.empty()))
.matches(join(
INNER,
ImmutableList.of(equiJoinClause("B1", "A1")),
Optional.empty(),
Optional.of(REPLICATED),
values(ImmutableMap.of("B1", 0)),
values(ImmutableMap.of("A1", 0))));
}

@Test
public void testFlipAndReplicateWhenOneTableMuchSmaller()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,24 @@ public void testBroadcastJoin()
anyTree(any())));
}

@Test
public void testBroadcastHBO()
{
Session broadcastSession = Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(JOIN_DISTRIBUTION_TYPE, "AUTOMATIC")
.setSystemProperty(JOIN_REORDERING_STRATEGY, "NONE")
.setSystemProperty("prefer_partial_aggregation", "false")
.build();
String sql = "SELECT COUNT(*) FROM lineitem l JOIN supplier s ON l.suppkey = s.suppkey";

executeAndTrackHistory(sql, broadcastSession);
assertPlan(
broadcastSession,
sql,
anyTree(
node(AggregationNode.class, anyTree(any())).withOutputRowCount(1)));
}

private void executeAndTrackHistory(String sql)
{
getQueryRunner().execute(sql);
Expand Down

0 comments on commit 2d8068c

Please sign in to comment.