Skip to content

Commit

Permalink
Add additional type support for KLL Sketches
Browse files Browse the repository at this point in the history
This adds support just for the aggregation functions.
New types supported:

1. real
2. smallint
3. tinyint
4. date
5. time
6. timestamp
7. timestamp with timezone
  • Loading branch information
ZacBlanco committed Jul 16, 2024
1 parent 3d25cd9 commit 9ecfaf2
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.operator.aggregation.sketch.kll;

import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
Expand All @@ -31,36 +32,37 @@ public class KllSketchAggregationFunction
* here.
*/
private static final int DEFAULT_K = 200;

// Private no-op constructor: every member visible in this class is static.
private KllSketchAggregationFunction()
{
}

/**
 * Input function for values of SQL type {@code T} that are passed as a Java {@code long}.
 * Forwards to {@code KllSketchWithKAggregationFunction.input} with {@code DEFAULT_K} as k.
 */
@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") long value)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") long value)
{
KllSketchWithKAggregationFunction.input(state, value, DEFAULT_K);
KllSketchWithKAggregationFunction.input(type, state, value, DEFAULT_K);
}

/**
 * Input function for values of SQL type {@code T} that are passed as a Java {@code double}.
 * Forwards to {@code KllSketchWithKAggregationFunction.input} with {@code DEFAULT_K} as k.
 */
@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") double value)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") double value)
{
KllSketchWithKAggregationFunction.input(state, value, DEFAULT_K);
KllSketchWithKAggregationFunction.input(type, state, value, DEFAULT_K);
}

/**
 * Input function for values of SQL type {@code T} that are passed as a {@code Slice}.
 * Forwards to {@code KllSketchWithKAggregationFunction.input} with {@code DEFAULT_K} as k.
 */
@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") Slice value)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") Slice value)
{
KllSketchWithKAggregationFunction.input(state, value, DEFAULT_K);
KllSketchWithKAggregationFunction.input(type, state, value, DEFAULT_K);
}

/**
 * Input function for values of SQL type {@code T} that are passed as a Java {@code boolean}.
 * Forwards to {@code KllSketchWithKAggregationFunction.input} with {@code DEFAULT_K} as k.
 */
@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") boolean value)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") boolean value)
{
KllSketchWithKAggregationFunction.input(state, value, DEFAULT_K);
KllSketchWithKAggregationFunction.input(type, state, value, DEFAULT_K);
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package com.facebook.presto.operator.aggregation.sketch.kll;

import com.facebook.presto.common.array.ObjectBigArray;
import com.facebook.presto.common.type.AbstractVarcharType;
import com.facebook.presto.common.type.BigintEnumType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.aggregation.sketch.theta.ThetaSketchStateFactory;
import com.facebook.presto.operator.aggregation.state.AbstractGroupedAccumulatorState;
Expand All @@ -34,10 +36,25 @@

import java.util.Comparator;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DateTimeEncoding.unpackMillisUtc;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.TimeType.TIME;
import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS;
import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH;
import static java.util.Objects.requireNonNull;

@AccumulatorStateMetadata(stateFactoryClass = KllSketchStateFactory.class, stateSerializerClass = KllSketchStateSerializer.class)
Expand Down Expand Up @@ -67,15 +84,26 @@ public interface KllSketchAggregationState

<T> void setSketch(KllItemsSketch<T> sketch);

/**
 * Sets the function that the visible implementations apply to each item before it reaches the sketch.
 */
void setConversion(Function<Object, Object> conversion);

/**
 * Adds an item to the state's sketch; the visible implementations pass it through the
 * conversion function first.
 */
<T> void update(T item);

class SketchParameters<T>
{
private final Comparator<T> comparator;
private final ArrayOfItemsSerDe<T> serde;
private final Function<Object, Object> conversion;

/**
 * Bundles the pieces needed to build and feed a sketch for one type.
 *
 * @param comparator comparator handed to the sketch on creation
 * @param serde item serializer/deserializer handed to the sketch on creation
 * @param conversion function returned by {@code getConversion()}
 */
public SketchParameters(Comparator<T> comparator, ArrayOfItemsSerDe<T> serde)
public SketchParameters(Comparator<T> comparator, ArrayOfItemsSerDe<T> serde, Function<Object, Object> conversion)
{
this.comparator = comparator;
this.serde = serde;
this.conversion = conversion;
}

/**
 * Convenience constructor for types whose values need no conversion: delegates to the
 * three-argument constructor with {@code Function.identity()}.
 */
public SketchParameters(Comparator<T> comparator, ArrayOfItemsSerDe<T> serde)
{
this(comparator, serde, Function.identity());
}

public Comparator<T> getComparator()
Expand All @@ -87,6 +115,11 @@ public ArrayOfItemsSerDe<T> getSerde()
{
return serde;
}

/**
 * @return the conversion function supplied at construction (identity when the two-argument
 * constructor was used)
 */
public Function<Object, Object> getConversion()
{
return conversion;
}
}

class Single
Expand All @@ -96,6 +129,7 @@ class Single
@Nullable
private KllItemsSketch sketch;
private final Type type;
private Function<Object, Object> conversion = Function.identity();

public Single(Type type)
{
Expand All @@ -115,6 +149,18 @@ public <T> void setSketch(KllItemsSketch<T> sketch)
this.sketch = sketch;
}

/**
 * Installs the function applied to every item before it is added to the sketch.
 *
 * @param conversion maps the incoming value to the sketch's item type; must not be null
 * @throws NullPointerException if {@code conversion} is null
 */
@Override
public void setConversion(Function<Object, Object> conversion)
{
    // Reject null here so the failure points at the caller instead of surfacing
    // later as an NPE inside update().
    this.conversion = requireNonNull(conversion, "conversion is null");
}

/**
 * Adds one item to the sketch after passing it through the configured conversion.
 *
 * @param item raw value as received by the aggregation input function
 */
@Override
public <T> void update(T item)
{
    // Convert first, then hand the converted value to the sketch.
    Object converted = conversion.apply(item);
    sketch.update(converted);
}

@Override
public void addMemoryUsage(Supplier<Long> usage)
{
Expand Down Expand Up @@ -149,6 +195,7 @@ class Grouped
private long accumulatedSizeInBytes;

private final Type type;
private Function<Object, Object> conversion = Function.identity();

public Grouped(Type type)
{
Expand Down Expand Up @@ -179,6 +226,18 @@ public <T> void setSketch(KllItemsSketch<T> sketch)
sketches.set(getGroupId(), requireNonNull(sketch, "kll sketch is null"));
}

/**
 * Installs the function applied to every item before it is added to the current group's sketch.
 *
 * @param conversion maps the incoming value to the sketch's item type; must not be null
 * @throws NullPointerException if {@code conversion} is null
 */
@Override
public void setConversion(Function<Object, Object> conversion)
{
    // Same null guard style as setSketch(): fail at the call site, not on the next update().
    this.conversion = requireNonNull(conversion, "conversion is null");
}

/**
 * Adds one item to the sketch returned by {@code getSketch()} after passing it through the
 * configured conversion.
 *
 * @param item raw value as received by the aggregation input function
 */
@Override
public <T> void update(T item)
{
    // Resolve the target sketch before converting, keeping the original evaluation order.
    KllItemsSketch<Object> target = getSketch();
    target.update(conversion.apply(item));
}

@Override
public long getEstimatedSize()
{
Expand Down Expand Up @@ -208,20 +267,39 @@ static long getEstimatedKllInMemorySize(@Nullable KllItemsSketch<?> sketch, Clas

/**
 * Selects the comparator, serde and value conversion used to build a KLL sketch for {@code type}.
 *
 * The type-based branches reject types that are not both orderable and comparable, then map:
 * REAL and DOUBLE to doubles, BOOLEAN to booleans, the integer-like/date/time types to longs,
 * and varchar types to strings.
 *
 * @param type the SQL type of the values being sketched
 * @return the sketch parameters for {@code type}
 * @throws PrestoException with INVALID_ARGUMENTS when no branch matches the type
 */
static SketchParameters<?> getSketchParameters(Type type)
{
if (type.getJavaType().equals(double.class)) {
return new SketchParameters<>(Double::compareTo, new ArrayOfDoublesSerDe());
if (!type.isOrderable() || !type.isComparable()) {
throw new PrestoException(INVALID_ARGUMENTS, type + " does not support comparisons or ordering");
}
else if (type.getJavaType().equals(long.class)) {
return new SketchParameters<>(Long::compareTo, new ArrayOfLongsSerDe());

// REAL: the incoming Long is narrowed to int, reinterpreted as float bits, then widened to double.
if (type.equals(REAL)) {
return new SketchParameters<>(Double::compareTo, new ArrayOfDoublesSerDe(),
(Object intValue) -> (double) Float.intBitsToFloat(((Long) intValue).intValue()));
}
else if (type.getJavaType().equals(Slice.class)) {
return new SketchParameters<>(String::compareTo, new ArrayOfStringsSerDe());
else if (type.equals(DOUBLE)) {
return new SketchParameters<>(Double::compareTo, new ArrayOfDoublesSerDe());
}
else if (type.getJavaType().equals(boolean.class)) {
else if (type.equals(BOOLEAN)) {
return new SketchParameters<>(Boolean::compareTo, new ArrayOfBooleansSerDe());
}
// Zone-carrying types: the packed long is reduced to its UTC millis via unpackMillisUtc.
else if (type.equals(TIMESTAMP_WITH_TIME_ZONE) || type.equals(TIME_WITH_TIME_ZONE)) {
return new SketchParameters<>(Long::compareTo, new ArrayOfLongsSerDe(), (Object packed) -> unpackMillisUtc((Long) packed));
}
// These types are stored as longs with no conversion (identity).
else if (type.equals(TINYINT) ||
type.equals(SMALLINT) ||
type.equals(INTEGER) ||
type.equals(BIGINT) ||
type instanceof BigintEnumType ||
type.equals(TIME) ||
type.equals(TIMESTAMP) ||
type.equals(DATE) ||
type.equals(INTERVAL_YEAR_MONTH)) {
return new SketchParameters<>(Long::compareTo, new ArrayOfLongsSerDe());
}
// Varchar types: the Slice is decoded to a UTF-8 String before it is stored.
else if (type instanceof AbstractVarcharType) {
return new SketchParameters<>(String::compareTo, new ArrayOfStringsSerDe(), (Object slice) -> ((Slice) slice).toStringUtf8());
}
else {
throw new PrestoException(INVALID_ARGUMENTS, "failed to deserialize KLL Sketch. No suitable type found for " + type);
throw new PrestoException(INVALID_ARGUMENTS, "Unsupported type for KLL sketch: " + type);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.TypeParameter;
import io.airlift.slice.Slices;
import org.apache.datasketches.kll.KllItemsSketch;
import org.apache.datasketches.memory.Memory;

Expand Down Expand Up @@ -46,12 +45,7 @@ public Type getSerializedType()
/**
 * Serializes the aggregation state into {@code out}.
 *
 * The inline form appends a null when the state holds no sketch and otherwise writes the
 * sketch's byte array as VARBINARY; the delegating form hands the state and builder to
 * {@code KllSketchWithKAggregationFunction.output}.
 */
@Override
public void serialize(KllSketchAggregationState state, BlockBuilder out)
{
if (state.getSketch() == null) {
out.appendNull();
return;
}

VARBINARY.writeSlice(out, Slices.wrappedBuffer(state.getSketch().toByteArray()));
KllSketchWithKAggregationFunction.output(state, out);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.operator.aggregation.sketch.kll;

import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
Expand All @@ -24,16 +25,8 @@
import com.facebook.presto.spi.function.TypeParameter;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import org.apache.datasketches.common.ArrayOfBooleansSerDe;
import org.apache.datasketches.common.ArrayOfDoublesSerDe;
import org.apache.datasketches.common.ArrayOfItemsSerDe;
import org.apache.datasketches.common.ArrayOfLongsSerDe;
import org.apache.datasketches.common.ArrayOfStringsSerDe;
import org.apache.datasketches.kll.KllItemsSketch;

import java.util.Comparator;
import java.util.function.Supplier;

import static com.facebook.presto.common.type.StandardTypes.BIGINT;
import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.operator.aggregation.sketch.kll.KllSketchAggregationState.getEstimatedKllInMemorySize;
Expand All @@ -50,45 +43,45 @@ private KllSketchWithKAggregationFunction()

/**
 * Adds a value of SQL type {@code T}, passed as a Java {@code long}, to the state's sketch.
 * The sketch is initialized with {@code k} first when the state has none.
 */
@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") long value, @SqlType(BIGINT) long k)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") long value, @SqlType(BIGINT) long k)
{
initializeSketch(state, () -> Long::compareTo, ArrayOfLongsSerDe::new, k);
initializeSketch(state, type, k);
KllItemsSketch<Long> sketch = state.getSketch();
// Report the negated size estimate before the update and the fresh estimate after it.
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, long.class));
state.getSketch().update(value);
state.update(value);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, long.class));
}

/**
 * Adds a value of SQL type {@code T}, passed as a Java {@code double}, to the state's sketch.
 * The sketch is initialized with {@code k} first when the state has none.
 */
@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") double value, @SqlType(BIGINT) long k)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") double value, @SqlType(BIGINT) long k)
{
initializeSketch(state, () -> Double::compareTo, ArrayOfDoublesSerDe::new, k);
initializeSketch(state, type, k);
KllItemsSketch<Double> sketch = state.getSketch();
// Report the negated size estimate before the update and the fresh estimate after it.
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, double.class));
state.getSketch().update(value);
state.update(value);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, double.class));
}

/**
 * Adds a value of SQL type {@code T}, passed as a {@code Slice}, to the state's sketch.
 * The sketch is initialized with {@code k} first when the state has none.
 */
@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") Slice value, @SqlType(BIGINT) long k)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") Slice value, @SqlType(BIGINT) long k)
{
initializeSketch(state, () -> String::compareTo, ArrayOfStringsSerDe::new, k);
initializeSketch(state, type, k);
KllItemsSketch sketch = state.getSketch();
// Report the negated size estimate before the update and the fresh estimate after it.
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, Slice.class));
state.getSketch().update(value.toStringUtf8());
state.update(value);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, Slice.class));
}

/**
 * Adds a value of SQL type {@code T}, passed as a Java {@code boolean}, to the state's sketch.
 * The sketch is initialized with {@code k} first when the state has none.
 */
@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") boolean value, @SqlType(BIGINT) long k)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") boolean value, @SqlType(BIGINT) long k)
{
initializeSketch(state, () -> Boolean::compareTo, ArrayOfBooleansSerDe::new, k);
initializeSketch(state, type, k);
KllItemsSketch<Boolean> sketch = state.getSketch();
// Report the negated size estimate before the update and the fresh estimate after it.
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, boolean.class));
state.getSketch().update(value);
state.update(value);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, boolean.class));
}

Expand Down Expand Up @@ -117,15 +110,22 @@ public static void output(@AggregationState KllSketchAggregationState state, Blo
VARBINARY.writeSlice(out, Slices.wrappedBuffer(state.getSketch().toByteArray()));
}

/**
 * Creates a heap {@code KllItemsSketch} with parameter {@code k} and installs it on the state
 * when the state does not already hold one. In the type-based form the comparator, serde and
 * conversion come from {@code KllSketchAggregationState.getSketchParameters(type)}.
 *
 * @throws PrestoException with INVALID_ARGUMENTS when {@code k} is below 8 or above {@code MAX_K}
 */
private static <T> void initializeSketch(KllSketchAggregationState state, Supplier<Comparator<T>> comparator, Supplier<ArrayOfItemsSerDe<T>> serdeSupplier, long k)
@SuppressWarnings({"rawtypes", "unchecked"})
private static void initializeSketch(KllSketchAggregationState state, Type type, long k)
{
// Already initialized: return before k is validated.
if (state.getSketch() != null) {
return;
}

if (k < 8 || k > MAX_K) {
throw new PrestoException(INVALID_ARGUMENTS, format("k value must satisfy 8 <= k <= %d: %d", MAX_K, k));
}
if (state.getSketch() == null) {
KllItemsSketch<T> sketch = KllItemsSketch.newHeapInstance((int) k, comparator.get(), serdeSupplier.get());
state.setSketch(sketch);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, state.getType().getJavaType()));
}

KllSketchAggregationState.SketchParameters parameters = KllSketchAggregationState.getSketchParameters(type);
KllItemsSketch sketch = KllItemsSketch.newHeapInstance((int) k, parameters.getComparator(), parameters.getSerde());

state.setSketch(sketch);
// The conversion is stored on the state; the visible update() implementations apply it to each item.
state.setConversion(parameters.getConversion());
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, state.getType().getJavaType()));
}
}
Loading

0 comments on commit 9ecfaf2

Please sign in to comment.