Skip to content

Commit

Permalink
Fix wrong results bug with count over mixed aggregation
Browse files Browse the repository at this point in the history
Fix a wrong results bug with count over an aggregation that had a mix of
global and non-global grouping sets.

We were not checking for single global aggregations correctly in
QueryCardinalityUtil, so we would return that the plan was scalar if
there were *any* empty grouping sets rather than if the empty grouping
set was the only grouping set.  we have now fixed this to return that
if all of the grouping sets are global, then the cardinality will be the
number of grouping sets, and otherwise it is at least the number of
global grouping sets.

This fixes queries like the following:
SELECT COUNT(*) FROM (SELECT count(*) FROM tpch.sf1.nation GROUP BY GROUPING SETS (nationkey, ()));

previously we would incorrectly return 1. And now we return 26.

This change may also fix bugs with correlated subqueries, as those also use
the isScalar() utility function.
  • Loading branch information
rschlussel committed Jun 18, 2024
1 parent 1286ae4 commit 7727beb
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ public Range<Long> visitEnforceSingleRow(EnforceSingleRowNode node, Void context
@Override
public Range<Long> visitAggregation(AggregationNode node, Void context)
{
if (node.hasEmptyGroupingSet()) {
return Range.singleton(1L);
if (!node.hasNonEmptyGroupingSet()) {
// if there are no non-empty grouping sets, then the number of rows returned will be the number of
// non-empty (i.e. global) grouping sets.
return Range.singleton((long) node.getGlobalGroupingSets().size());
}
return Range.atLeast(0L);
return Range.atLeast((long) node.getGlobalGroupingSets().size());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.GroupingSetDescriptor;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.testing.TestingTransactionHandle;
import com.facebook.presto.tpch.TpchColumnHandle;
import com.facebook.presto.tpch.TpchTableHandle;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;

import java.util.Optional;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions;
Expand Down Expand Up @@ -118,7 +119,8 @@ public void testDoesNotFireOnNestedCountAggregateWithNonEmptyGroupBy()
.source(
p.aggregation(aggregationBuilder -> {
aggregationBuilder
.source(p.tableScan(ImmutableList.of(), ImmutableMap.of())).groupingSets(singleGroupingSet(ImmutableList.of(p.variable("orderkey"))));
.source(p.tableScan(ImmutableList.of(), ImmutableMap.of())).groupingSets(
new GroupingSetDescriptor(ImmutableList.of(p.variable("orderkey")), 2, ImmutableSet.of(1)));
aggregationBuilder
.source(p.tableScan(ImmutableList.of(), ImmutableMap.of()));
}))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,30 @@
*/
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.plan.AggregationNode.GroupingSetDescriptor;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Range;
import org.testng.annotations.Test;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.metadata.AbstractMockMetadata.dummyMetadata;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality;
import static java.util.Collections.emptyList;
import static org.testng.Assert.assertEquals;

public class TestCardinalityExtractorPlanVisitor
{
private static final Metadata METADATA = MetadataManager.createTestMetadataManager();

@Test
public void testLimitOnTopOfValues()
{
PlanBuilder planBuilder = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), dummyMetadata());
PlanBuilder planBuilder = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), METADATA);

assertEquals(
extractCardinality(planBuilder.limit(3, planBuilder.values(emptyList(), ImmutableList.of(emptyList())))),
Expand All @@ -40,4 +46,52 @@ public void testLimitOnTopOfValues()
extractCardinality(planBuilder.limit(3, planBuilder.values(emptyList(), ImmutableList.of(emptyList(), emptyList(), emptyList(), emptyList())))),
Range.singleton(3L));
}

@Test
public void testGlobalAggregation()
{
PlanBuilder planBuilder = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), METADATA);
assertEquals(
extractCardinality(planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
.addAggregation(planBuilder.variable("count", BIGINT), planBuilder.rowExpression("count()"))
.globalGrouping()
.source(planBuilder.values(planBuilder.variable("x", BIGINT), planBuilder.variable("y", BIGINT), planBuilder.variable("z", BIGINT))))),
Range.singleton(1L));
}

@Test
public void testSimpleGroupedAggregation()
{
PlanBuilder planBuilder = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), METADATA);
assertEquals(
extractCardinality(planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
.addAggregation(planBuilder.variable("count", BIGINT), planBuilder.rowExpression("count()"))
.singleGroupingSet(planBuilder.variable("y", BIGINT), planBuilder.variable("z", BIGINT))
.source(planBuilder.values(planBuilder.variable("x", BIGINT), planBuilder.variable("y", BIGINT), planBuilder.variable("z", BIGINT))))),
Range.atLeast(0L));
}

@Test
public void testMultipleGlobalGroupingSets()
{
PlanBuilder planBuilder = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), METADATA);
assertEquals(
extractCardinality(planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
.addAggregation(planBuilder.variable("count", BIGINT), planBuilder.rowExpression("count()"))
.groupingSets(new GroupingSetDescriptor(ImmutableList.of(), 2, ImmutableSet.of(0, 1)))
.source(planBuilder.values(planBuilder.variable("x", BIGINT), planBuilder.variable("y", BIGINT), planBuilder.variable("z", BIGINT))))),
Range.singleton(2L));
}

@Test
public void testEmptyAndNonEmptyGroupingSets()
{
PlanBuilder planBuilder = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), METADATA);
assertEquals(
extractCardinality(planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
.addAggregation(planBuilder.variable("count", BIGINT), planBuilder.rowExpression("count()"))
.groupingSets(new GroupingSetDescriptor(ImmutableList.of(planBuilder.variable("y", BIGINT)), 2, ImmutableSet.of(0)))
.source(planBuilder.values(planBuilder.variable("x", BIGINT), planBuilder.variable("y", BIGINT), planBuilder.variable("z", BIGINT))))),
Range.atLeast(1L));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,18 @@ public void testCountColumn()
assertQuery("SELECT COUNT(CAST(NULL AS BIGINT)) FROM orders"); // todo: make COUNT(null) work
}

@Test
public void testCountOverGlobalAggregation()
{
assertQuery("SELECT COUNT(*) FROM (SELECT COUNT(*) FROM nation)");
}

@Test
public void testCountOverGroupedAggregation()
{
assertQuery("SELECT COUNT(*) FROM (SELECT COUNT(*) FROM nation GROUP BY GROUPING SETS (nationkey, ()))", "SELECT 26");
}

@Test
public void testWildcard()
{
Expand Down

0 comments on commit 7727beb

Please sign in to comment.