Skip to content

Commit

Permalink
Allow arbitrary aggregation functions for statistics collection
Browse files Browse the repository at this point in the history
The `ColumnStatisticType` enum defined which statistics could be gathered
during statistics collection. Although the options look generic, they
match exactly the statistics that the Hive metastore collects. Other
metadata stores may require different statistics to be collected, for
example data sketches with a specific configuration.

This change allows a connector to pick any existing aggregation
function.
  • Loading branch information
findepi committed Sep 22, 2022
1 parent 52f6e01 commit 30b39e1
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.trino.operator.aggregation.MaxDataSizeForStats;
import io.trino.operator.aggregation.SumDataSizeForStats;
import io.trino.spi.TrinoException;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.statistics.ColumnStatisticMetadata;
import io.trino.spi.statistics.ColumnStatisticType;
import io.trino.spi.statistics.TableStatisticType;
Expand All @@ -36,6 +37,7 @@
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.base.Verify.verifyNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -92,13 +94,23 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta

for (ColumnStatisticMetadata columnStatisticMetadata : statisticsMetadata.getColumnStatistics()) {
String columnName = columnStatisticMetadata.getColumnName();
ColumnStatisticType statisticType = columnStatisticMetadata.getStatisticType();
Symbol inputSymbol = columnToSymbolMap.get(columnName);
verifyNotNull(inputSymbol, "inputSymbol is null");
Type inputType = symbolAllocator.getTypes().get(inputSymbol);
verifyNotNull(inputType, "inputType is null for symbol: %s", inputSymbol);
ColumnStatisticsAggregation aggregation = createColumnAggregation(statisticType, inputSymbol, inputType);
Symbol symbol = symbolAllocator.newSymbol(statisticType + ":" + columnName, aggregation.getOutputType());
ColumnStatisticsAggregation aggregation;
String symbolHint;
if (columnStatisticMetadata.getStatisticTypeIfPresent().isPresent()) {
ColumnStatisticType statisticType = columnStatisticMetadata.getStatisticType();
aggregation = createColumnAggregation(statisticType, inputSymbol, inputType);
symbolHint = statisticType + ":" + columnName;
}
else {
FunctionName aggregationName = columnStatisticMetadata.getAggregation();
aggregation = createColumnAggregation(aggregationName, inputSymbol, inputType);
symbolHint = aggregationName.getName() + ":" + columnName;
}
Symbol symbol = symbolAllocator.newSymbol(symbolHint, aggregation.getOutputType());
aggregations.put(symbol, aggregation.getAggregation());
descriptor.addColumnStatistic(columnStatisticMetadata, symbol);
}
Expand All @@ -123,6 +135,12 @@ private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticType
};
}

/**
 * Creates the statistics aggregation for an aggregation function identified by name.
 *
 * @param aggregation name of the aggregation function; must not carry a catalog/schema
 * @param input symbol whose reference is handed to {@code createAggregation}
 * @param inputType type of {@code input}
 * @throws IllegalArgumentException if {@code aggregation} has a catalog/schema set
 */
private ColumnStatisticsAggregation createColumnAggregation(FunctionName aggregation, Symbol input, Type inputType)
{
    // Only unqualified function names are handled here; anything qualified is rejected up front
    boolean unqualified = aggregation.getCatalogSchema().isEmpty();
    checkArgument(unqualified, "Catalog/schema name not supported");

    QualifiedName functionName = QualifiedName.of(aggregation.getName());
    SymbolReference inputReference = input.toSymbolReference();
    return createAggregation(functionName, inputReference, inputType);
}

private ColumnStatisticsAggregation createAggregation(QualifiedName functionName, SymbolReference input, Type inputType)
{
ResolvedFunction resolvedFunction = metadata.resolveFunction(session, functionName, fromTypes(inputType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.NullableValue;
import io.trino.spi.predicate.Range;
Expand Down Expand Up @@ -1458,9 +1459,22 @@ private void printStatisticAggregationsInfo(
}

for (Map.Entry<ColumnStatisticMetadata, Symbol> columnStatistic : columnStatistics.entrySet()) {
String aggregationName;
if (columnStatistic.getKey().getStatisticTypeIfPresent().isPresent()) {
aggregationName = columnStatistic.getKey().getStatisticType().name();
}
else {
FunctionName aggregation = columnStatistic.getKey().getAggregation();
if (aggregation.getCatalogSchema().isPresent()) {
aggregationName = aggregation.getCatalogSchema().get() + "." + aggregation.getName();
}
else {
aggregationName = aggregation.getName();
}
}
nodeOutput.appendDetails(
indentString(1) + "%s[%s] => [%s := %s]",
columnStatistic.getKey().getStatisticType(),
aggregationName,
anonymizer.anonymizeColumn(columnStatistic.getKey().getColumnName()),
anonymizer.anonymize(columnStatistic.getValue()),
formatAggregation(anonymizer, aggregations.get(columnStatistic.getValue())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.reflect.TypeToken;
import io.airlift.json.JsonCodec;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.statistics.ColumnStatisticMetadata;
import io.trino.spi.statistics.ColumnStatisticType;
import io.trino.sql.planner.Symbol;
Expand Down Expand Up @@ -51,6 +52,8 @@ private static StatisticAggregationsDescriptor<Symbol> createTestDescriptor()
for (ColumnStatisticType type : ColumnStatisticType.values()) {
builder.addColumnStatistic(new ColumnStatisticMetadata(column, type), testSymbol(symbolAllocator));
}
builder.addColumnStatistic(new ColumnStatisticMetadata(column, new FunctionName("count")), testSymbol(symbolAllocator));
builder.addColumnStatistic(new ColumnStatisticMetadata(column, new FunctionName("count_if")), testSymbol(symbolAllocator));
builder.addGrouping(column, testSymbol(symbolAllocator));
}
builder.addTableStatistic(ROW_COUNT, testSymbol(symbolAllocator));
Expand Down
28 changes: 28 additions & 0 deletions core/trino-spi/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,34 @@
<code>java.method.removed</code>
<old>method io.trino.spi.type.TypeSignature io.trino.spi.type.VarcharType::getParametrizedVarcharSignature(java.lang.String)</old>
</item>
<item>
<ignore>true</ignore>
<code>java.annotation.removed</code>
<old>parameter void io.trino.spi.statistics.ColumnStatisticMetadata::&lt;init&gt;(===java.lang.String===, io.trino.spi.statistics.ColumnStatisticType)</old>
<new>parameter void io.trino.spi.statistics.ColumnStatisticMetadata::&lt;init&gt;(===java.lang.String===, io.trino.spi.statistics.ColumnStatisticType)</new>
<annotation>@com.fasterxml.jackson.annotation.JsonProperty("columnName")</annotation>
</item>
<item>
<ignore>true</ignore>
<code>java.annotation.removed</code>
<old>parameter void io.trino.spi.statistics.ColumnStatisticMetadata::&lt;init&gt;(java.lang.String, ===io.trino.spi.statistics.ColumnStatisticType===)</old>
<new>parameter void io.trino.spi.statistics.ColumnStatisticMetadata::&lt;init&gt;(java.lang.String, ===io.trino.spi.statistics.ColumnStatisticType===)</new>
<annotation>@com.fasterxml.jackson.annotation.JsonProperty("statisticType")</annotation>
</item>
<item>
<ignore>true</ignore>
<code>java.annotation.removed</code>
<old>method void io.trino.spi.statistics.ColumnStatisticMetadata::&lt;init&gt;(java.lang.String, io.trino.spi.statistics.ColumnStatisticType)</old>
<new>method void io.trino.spi.statistics.ColumnStatisticMetadata::&lt;init&gt;(java.lang.String, io.trino.spi.statistics.ColumnStatisticType)</new>
<annotation>@com.fasterxml.jackson.annotation.JsonCreator</annotation>
</item>
<item>
<ignore>true</ignore>
<code>java.annotation.removed</code>
<old>method io.trino.spi.statistics.ColumnStatisticType io.trino.spi.statistics.ColumnStatisticMetadata::getStatisticType()</old>
<new>method io.trino.spi.statistics.ColumnStatisticType io.trino.spi.statistics.ColumnStatisticMetadata::getStatisticType()</new>
<annotation>@com.fasterxml.jackson.annotation.JsonProperty</annotation>
</item>
<item>
<code>java.method.visibilityReduced</code>
<old>method void io.trino.spi.block.DictionaryBlock::&lt;init&gt;(int, int, io.trino.spi.block.Block, int[], boolean, boolean, io.trino.spi.block.DictionaryId)</old>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
*/
package io.trino.spi.connector;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.Objects;

import static java.util.Locale.ENGLISH;
Expand All @@ -22,17 +25,22 @@ public final class CatalogSchemaName
private final String catalogName;
private final String schemaName;

public CatalogSchemaName(String catalogName, String schemaName)
@JsonCreator
public CatalogSchemaName(
@JsonProperty("catalogName") String catalogName,
@JsonProperty("schemaName") String schemaName)
{
this.catalogName = catalogName.toLowerCase(ENGLISH);
this.schemaName = schemaName.toLowerCase(ENGLISH);
}

/**
 * Returns the catalog name as stored by the constructor, which lower-cases it
 * using the {@code ENGLISH} locale. Exposed as a JSON property.
 */
@JsonProperty
public String getCatalogName()
{
return catalogName;
}

@JsonProperty
public String getSchemaName()
{
return schemaName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package io.trino.spi.expression;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.trino.spi.connector.CatalogSchemaName;

import java.util.Objects;
Expand All @@ -32,7 +34,10 @@ public FunctionName(String name)
this(Optional.empty(), name);
}

public FunctionName(Optional<CatalogSchemaName> catalogSchema, String name)
@JsonCreator
public FunctionName(
@JsonProperty("catalogSchema") Optional<CatalogSchemaName> catalogSchema,
@JsonProperty("name") String name)
{
this.catalogSchema = requireNonNull(catalogSchema, "catalogSchema is null");
this.name = requireNonNull(name, "name is null");
Expand All @@ -41,6 +46,7 @@ public FunctionName(Optional<CatalogSchemaName> catalogSchema, String name)
/**
* @return the catalog and schema of this function, or {@link Optional#empty()} if this is a built-in function
*/
@JsonProperty
public Optional<CatalogSchemaName> getCatalogSchema()
{
return catalogSchema;
Expand All @@ -49,6 +55,7 @@ public Optional<CatalogSchemaName> getCatalogSchema()
/**
* @return the function's name
*/
@JsonProperty
public String getName()
{
return name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,59 @@
package io.trino.spi.statistics;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.trino.spi.Experimental;
import io.trino.spi.expression.FunctionName;

import java.util.Objects;
import java.util.Optional;
import java.util.StringJoiner;

import static java.util.Objects.requireNonNull;

public class ColumnStatisticMetadata
{
private final String columnName;
private final ColumnStatisticType statisticType;
private final Optional<ColumnStatisticType> statisticType;
private final Optional<FunctionName> aggregation;

@JsonCreator
public ColumnStatisticMetadata(
@JsonProperty("columnName") String columnName,
@JsonProperty("statisticType") ColumnStatisticType statisticType)
String columnName,
ColumnStatisticType statisticType)
{
this(columnName, Optional.of(statisticType), Optional.empty());
}

/**
 * Creates metadata for a column statistic that is computed by a named aggregation
 * function rather than by a predefined {@link ColumnStatisticType}.
 *
 * @param columnName name of the column the statistic refers to
 * @param aggregation name of the aggregation function to apply
 * @throws NullPointerException if either argument is null
 */
@Experimental(eta = "2023-01-31")
public ColumnStatisticMetadata(
        String columnName,
        FunctionName aggregation)
{
    // Check explicitly so a null aggregation fails with a descriptive message
    // instead of the message-less NullPointerException raised by Optional.of
    this(columnName, Optional.empty(), Optional.of(requireNonNull(aggregation, "aggregation is null")));
}

private ColumnStatisticMetadata(
String columnName,
Optional<ColumnStatisticType> statisticType,
Optional<FunctionName> aggregation)
{
this.columnName = requireNonNull(columnName, "columnName is null");
this.statisticType = requireNonNull(statisticType, "statisticType is null");
this.aggregation = requireNonNull(aggregation, "aggregation is null");
if (statisticType.isPresent() == aggregation.isPresent()) {
throw new IllegalArgumentException("Exactly one of statisticType and aggregation should be set");
}
}

/**
 * Static factory registered as the Jackson creator. It forwards the three JSON
 * properties unchanged to the private constructor, which rejects input where
 * both or neither of {@code statisticType} and {@code aggregation} are present.
 *
 * @deprecated intended for JSON deserialization only
 */
@Deprecated // For JSON deserialization only
@JsonCreator
public static ColumnStatisticMetadata fromJson(
@JsonProperty("columnName") String columnName,
@JsonProperty("statisticType") Optional<ColumnStatisticType> statisticType,
@JsonProperty("aggregation") Optional<FunctionName> aggregation)
{
return new ColumnStatisticMetadata(columnName, statisticType, aggregation);
}

@JsonProperty
Expand All @@ -40,12 +75,33 @@ public String getColumnName()
return columnName;
}

@JsonProperty
@JsonIgnore
public ColumnStatisticType getStatisticType()
{
return statisticType.orElseThrow();
}

/**
 * Returns the predefined statistic type, or an empty {@link Optional} when this
 * metadata was created with an aggregation function instead. Serialized as the
 * {@code statisticType} JSON property.
 */
@Experimental(eta = "2023-01-31")
@JsonProperty("statisticType")
public Optional<ColumnStatisticType> getStatisticTypeIfPresent()
{
return statisticType;
}

/**
 * Returns the aggregation function name. Excluded from JSON; the value is
 * serialized through {@code getAggregationIfPresent()} instead.
 *
 * @throws java.util.NoSuchElementException if this metadata holds a
 *         {@link ColumnStatisticType} rather than an aggregation function
 */
@Experimental(eta = "2023-01-31")
@JsonIgnore
public FunctionName getAggregation()
{
return aggregation.orElseThrow();
}

/**
 * Returns the aggregation function name, or an empty {@link Optional} when this
 * metadata was created with a {@link ColumnStatisticType} instead. Serialized as
 * the {@code aggregation} JSON property.
 */
@Experimental(eta = "2023-01-31")
@JsonProperty("aggregation")
public Optional<FunctionName> getAggregationIfPresent()
{
return aggregation;
}

@Override
public boolean equals(Object o)
{
Expand All @@ -57,21 +113,23 @@ public boolean equals(Object o)
}
ColumnStatisticMetadata that = (ColumnStatisticMetadata) o;
return Objects.equals(columnName, that.columnName) &&
statisticType == that.statisticType;
Objects.equals(statisticType, that.statisticType) &&
Objects.equals(aggregation, that.aggregation);
}

@Override
public int hashCode()
{
return Objects.hash(columnName, statisticType);
return Objects.hash(columnName, statisticType, aggregation);
}

@Override
public String toString()
{
return "ColumnStatisticMetadata{" +
"columnName='" + columnName + '\'' +
", statisticType=" + statisticType +
'}';
return new StringJoiner(", ", ColumnStatisticMetadata.class.getSimpleName() + "[", "]")
.add("columnName='" + columnName + "'")
.add("statisticType=" + statisticType)
.add("aggregation=" + aggregation)
.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import io.trino.spi.connector.TableColumnsMetadata;
import io.trino.spi.connector.TableNotFoundException;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.NullableValue;
Expand Down Expand Up @@ -224,7 +225,6 @@
import static io.trino.spi.connector.RetryMode.NO_RETRIES;
import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW;
import static io.trino.spi.predicate.Utils.blockToNativeValue;
import static io.trino.spi.statistics.ColumnStatisticType.NUMBER_OF_DISTINCT_VALUES;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
import static io.trino.spi.type.UuidType.UUID;
Expand Down Expand Up @@ -256,6 +256,8 @@ public class IcebergMetadata
public static final String ORC_BLOOM_FILTER_COLUMNS_KEY = "orc.bloom.filter.columns";
public static final String ORC_BLOOM_FILTER_FPP_KEY = "orc.bloom.filter.fpp";

// Aggregation function whose result finishStatisticsCollection reads as a BIGINT number-of-distinct-values (NDV) count
private static final FunctionName NUMBER_OF_DISTINCT_VALUES = new FunctionName("approx_distinct");

private final TypeManager typeManager;
private final TypeOperators typeOperators;
private final JsonCodec<CommitTaskData> commitTaskCodec;
Expand Down Expand Up @@ -1526,7 +1528,7 @@ public void finishStatisticsCollection(ConnectorSession session, ConnectorTableH
verify(computedStatistic.getTableStatistics().isEmpty(), "Unexpected table statistics");
for (Map.Entry<ColumnStatisticMetadata, Block> entry : computedStatistic.getColumnStatistics().entrySet()) {
ColumnStatisticMetadata statisticMetadata = entry.getKey();
if (statisticMetadata.getStatisticType() == NUMBER_OF_DISTINCT_VALUES) {
if (statisticMetadata.getAggregation().equals(NUMBER_OF_DISTINCT_VALUES)) {
long ndv = (long) blockToNativeValue(BIGINT, entry.getValue());
Integer columnId = verifyNotNull(
columnNameToId.get(statisticMetadata.getColumnName()),
Expand Down

0 comments on commit 30b39e1

Please sign in to comment.