Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add additional type support for KLL Sketches #23114

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.operator.aggregation.sketch.kll;

import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
Expand All @@ -31,36 +32,37 @@ public class KllSketchAggregationFunction
* here.
*/
private static final int DEFAULT_K = 200;

private KllSketchAggregationFunction()
{
}

@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") long value)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") long value)
{
KllSketchWithKAggregationFunction.input(state, value, DEFAULT_K);
KllSketchWithKAggregationFunction.input(type, state, value, DEFAULT_K);
}

@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") double value)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") double value)
{
KllSketchWithKAggregationFunction.input(state, value, DEFAULT_K);
KllSketchWithKAggregationFunction.input(type, state, value, DEFAULT_K);
}

@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") Slice value)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") Slice value)
{
KllSketchWithKAggregationFunction.input(state, value, DEFAULT_K);
KllSketchWithKAggregationFunction.input(type, state, value, DEFAULT_K);
}

@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") boolean value)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") boolean value)
{
KllSketchWithKAggregationFunction.input(state, value, DEFAULT_K);
KllSketchWithKAggregationFunction.input(type, state, value, DEFAULT_K);
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package com.facebook.presto.operator.aggregation.sketch.kll;

import com.facebook.presto.common.array.ObjectBigArray;
import com.facebook.presto.common.type.AbstractVarcharType;
import com.facebook.presto.common.type.BigintEnumType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.aggregation.sketch.theta.ThetaSketchStateFactory;
import com.facebook.presto.operator.aggregation.state.AbstractGroupedAccumulatorState;
Expand All @@ -34,10 +36,25 @@

import java.util.Comparator;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DateTimeEncoding.unpackMillisUtc;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.TimeType.TIME;
import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS;
import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH;
import static java.util.Objects.requireNonNull;

@AccumulatorStateMetadata(stateFactoryClass = KllSketchStateFactory.class, stateSerializerClass = KllSketchStateSerializer.class)
Expand Down Expand Up @@ -67,15 +84,26 @@ public interface KllSketchAggregationState

<T> void setSketch(KllItemsSketch<T> sketch);

void setConversion(Function<Object, Object> conversion);

<T> void update(T item);

class SketchParameters<T>
{
private final Comparator<T> comparator;
private final ArrayOfItemsSerDe<T> serde;
private final Function<Object, Object> conversion;

public SketchParameters(Comparator<T> comparator, ArrayOfItemsSerDe<T> serde)
public SketchParameters(Comparator<T> comparator, ArrayOfItemsSerDe<T> serde, Function<Object, Object> conversion)
{
this.comparator = comparator;
this.serde = serde;
this.conversion = conversion;
}

public SketchParameters(Comparator<T> comparator, ArrayOfItemsSerDe<T> serde)
{
this(comparator, serde, Function.identity());
}

public Comparator<T> getComparator()
Expand All @@ -87,6 +115,11 @@ public ArrayOfItemsSerDe<T> getSerde()
{
return serde;
}

public Function<Object, Object> getConversion()
{
return conversion;
}
}

class Single
Expand All @@ -96,6 +129,7 @@ class Single
@Nullable
private KllItemsSketch sketch;
private final Type type;
private Function<Object, Object> conversion = Function.identity();

public Single(Type type)
{
Expand All @@ -115,6 +149,18 @@ public <T> void setSketch(KllItemsSketch<T> sketch)
this.sketch = sketch;
}

@Override
public void setConversion(Function<Object, Object> conversion)
{
this.conversion = conversion;
}

@Override
public <T> void update(T item)
{
sketch.update(conversion.apply(item));
}

@Override
public void addMemoryUsage(Supplier<Long> usage)
{
Expand Down Expand Up @@ -149,6 +195,7 @@ class Grouped
private long accumulatedSizeInBytes;

private final Type type;
private Function<Object, Object> conversion = Function.identity();

public Grouped(Type type)
{
Expand Down Expand Up @@ -179,6 +226,18 @@ public <T> void setSketch(KllItemsSketch<T> sketch)
sketches.set(getGroupId(), requireNonNull(sketch, "kll sketch is null"));
}

@Override
public void setConversion(Function<Object, Object> conversion)
{
this.conversion = conversion;
}

@Override
public <T> void update(T item)
{
getSketch().update(conversion.apply(item));
}

@Override
public long getEstimatedSize()
{
Expand Down Expand Up @@ -208,20 +267,39 @@ static long getEstimatedKllInMemorySize(@Nullable KllItemsSketch<?> sketch, Clas

static SketchParameters<?> getSketchParameters(Type type)
{
if (type.getJavaType().equals(double.class)) {
return new SketchParameters<>(Double::compareTo, new ArrayOfDoublesSerDe());
if (!type.isOrderable() || !type.isComparable()) {
throw new PrestoException(INVALID_ARGUMENTS, type + " does not support comparisons or ordering");
}
else if (type.getJavaType().equals(long.class)) {
return new SketchParameters<>(Long::compareTo, new ArrayOfLongsSerDe());

if (type.equals(REAL)) {
return new SketchParameters<>(Double::compareTo, new ArrayOfDoublesSerDe(),
(Object intValue) -> (double) Float.intBitsToFloat(((Long) intValue).intValue()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the physical type for floats is an int?

Suggested change
(Object intValue) -> (double) Float.intBitsToFloat(((Long) intValue).intValue()));
(Object intValue) -> (double) Float.intBitsToFloat((int) intValue));

Copy link
Contributor Author

@ZacBlanco ZacBlanco Jul 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I recall when writing this: for the REAL type, the backing primitive type is a long (RealType extends AbstractIntType), but you can't perform a direct cast from Object to int, so you need to cast it to the boxed type (Long) first.

}
else if (type.getJavaType().equals(Slice.class)) {
return new SketchParameters<>(String::compareTo, new ArrayOfStringsSerDe());
else if (type.equals(DOUBLE)) {
return new SketchParameters<>(Double::compareTo, new ArrayOfDoublesSerDe());
}
else if (type.getJavaType().equals(boolean.class)) {
else if (type.equals(BOOLEAN)) {
return new SketchParameters<>(Boolean::compareTo, new ArrayOfBooleansSerDe());
}
else if (type.equals(TIMESTAMP_WITH_TIME_ZONE) || type.equals(TIME_WITH_TIME_ZONE)) {
return new SketchParameters<>(Long::compareTo, new ArrayOfLongsSerDe(), (Object packed) -> unpackMillisUtc((Long) packed));
}
else if (type.equals(TINYINT) ||
type.equals(SMALLINT) ||
type.equals(INTEGER) ||
type.equals(BIGINT) ||
type instanceof BigintEnumType ||
type.equals(TIME) ||
type.equals(TIMESTAMP) ||
type.equals(DATE) ||
type.equals(INTERVAL_YEAR_MONTH)) {
return new SketchParameters<>(Long::compareTo, new ArrayOfLongsSerDe());
}
else if (type instanceof AbstractVarcharType) {
return new SketchParameters<>(String::compareTo, new ArrayOfStringsSerDe(), (Object slice) -> ((Slice) slice).toStringUtf8());
}
else {
throw new PrestoException(INVALID_ARGUMENTS, "failed to deserialize KLL Sketch. No suitable type found for " + type);
throw new PrestoException(INVALID_ARGUMENTS, "Unsupported type for KLL sketch: " + type);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.TypeParameter;
import io.airlift.slice.Slices;
import org.apache.datasketches.kll.KllItemsSketch;
import org.apache.datasketches.memory.Memory;

Expand Down Expand Up @@ -46,12 +45,7 @@ public Type getSerializedType()
@Override
public void serialize(KllSketchAggregationState state, BlockBuilder out)
{
if (state.getSketch() == null) {
out.appendNull();
return;
}

VARBINARY.writeSlice(out, Slices.wrappedBuffer(state.getSketch().toByteArray()));
KllSketchWithKAggregationFunction.output(state, out);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.operator.aggregation.sketch.kll;

import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
Expand All @@ -24,16 +25,8 @@
import com.facebook.presto.spi.function.TypeParameter;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import org.apache.datasketches.common.ArrayOfBooleansSerDe;
import org.apache.datasketches.common.ArrayOfDoublesSerDe;
import org.apache.datasketches.common.ArrayOfItemsSerDe;
import org.apache.datasketches.common.ArrayOfLongsSerDe;
import org.apache.datasketches.common.ArrayOfStringsSerDe;
import org.apache.datasketches.kll.KllItemsSketch;

import java.util.Comparator;
import java.util.function.Supplier;

import static com.facebook.presto.common.type.StandardTypes.BIGINT;
import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.operator.aggregation.sketch.kll.KllSketchAggregationState.getEstimatedKllInMemorySize;
Expand All @@ -50,45 +43,45 @@ private KllSketchWithKAggregationFunction()

@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") long value, @SqlType(BIGINT) long k)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") long value, @SqlType(BIGINT) long k)
{
initializeSketch(state, () -> Long::compareTo, ArrayOfLongsSerDe::new, k);
initializeSketch(state, type, k);
KllItemsSketch<Long> sketch = state.getSketch();
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, long.class));
state.getSketch().update(value);
state.update(value);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, long.class));
}

@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") double value, @SqlType(BIGINT) long k)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") double value, @SqlType(BIGINT) long k)
{
initializeSketch(state, () -> Double::compareTo, ArrayOfDoublesSerDe::new, k);
initializeSketch(state, type, k);
KllItemsSketch<Double> sketch = state.getSketch();
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, double.class));
state.getSketch().update(value);
state.update(value);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, double.class));
}

@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") Slice value, @SqlType(BIGINT) long k)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") Slice value, @SqlType(BIGINT) long k)
{
initializeSketch(state, () -> String::compareTo, ArrayOfStringsSerDe::new, k);
initializeSketch(state, type, k);
KllItemsSketch sketch = state.getSketch();
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, Slice.class));
state.getSketch().update(value.toStringUtf8());
state.update(value);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, Slice.class));
}

@InputFunction
@TypeParameter("T")
public static void input(@AggregationState KllSketchAggregationState state, @SqlType("T") boolean value, @SqlType(BIGINT) long k)
public static void input(@TypeParameter("T") Type type, @AggregationState KllSketchAggregationState state, @SqlType("T") boolean value, @SqlType(BIGINT) long k)
{
initializeSketch(state, () -> Boolean::compareTo, ArrayOfBooleansSerDe::new, k);
initializeSketch(state, type, k);
KllItemsSketch<Boolean> sketch = state.getSketch();
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, boolean.class));
state.getSketch().update(value);
state.update(value);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, boolean.class));
}

Expand Down Expand Up @@ -117,15 +110,22 @@ public static void output(@AggregationState KllSketchAggregationState state, Blo
VARBINARY.writeSlice(out, Slices.wrappedBuffer(state.getSketch().toByteArray()));
}

private static <T> void initializeSketch(KllSketchAggregationState state, Supplier<Comparator<T>> comparator, Supplier<ArrayOfItemsSerDe<T>> serdeSupplier, long k)
@SuppressWarnings({"rawtypes", "unchecked"})
private static void initializeSketch(KllSketchAggregationState state, Type type, long k)
{
if (state.getSketch() != null) {
return;
}

if (k < 8 || k > MAX_K) {
throw new PrestoException(INVALID_ARGUMENTS, format("k value must satisfy 8 <= k <= %d: %d", MAX_K, k));
}
if (state.getSketch() == null) {
KllItemsSketch<T> sketch = KllItemsSketch.newHeapInstance((int) k, comparator.get(), serdeSupplier.get());
state.setSketch(sketch);
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, state.getType().getJavaType()));
}

KllSketchAggregationState.SketchParameters parameters = KllSketchAggregationState.getSketchParameters(type);
KllItemsSketch sketch = KllItemsSketch.newHeapInstance((int) k, parameters.getComparator(), parameters.getSerde());

state.setSketch(sketch);
state.setConversion(parameters.getConversion());
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, state.getType().getJavaType()));
}
}
Loading
Loading