Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BKD Optimization to Range aggregation #47712

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,19 @@ public Number parsePoint(byte[] value) {
return HalfFloatPoint.decodeDimension(value, 0);
}

/**
 * Encodes the given value into the indexed BKD point representation for
 * half_float fields.
 *
 * @param value  the number to encode
 * @param coerce whether to leniently coerce the input, as in {@code parse}
 * @return the encoded point bytes, {@code HalfFloatPoint.BYTES} long
 */
@Override
public byte[] encodePoint(Number value, boolean coerce) {
    float parsedValue = parse(value, coerce);
    // Bug fix: was new byte[Integer.BYTES] (4 bytes). Half-float points are
    // HalfFloatPoint.BYTES wide, and bytesPerEncodedPoint() below reports
    // HalfFloatPoint.BYTES, so the buffer must match that width or consumers
    // comparing encoded points against indexed BKD values will disagree on size.
    byte[] bytes = new byte[HalfFloatPoint.BYTES];
    HalfFloatPoint.encodeDimension(parsedValue, bytes, 0);
    return bytes;
}

// Per-dimension width of an encoded half_float point, as defined by Lucene's
// HalfFloatPoint. Must match the array length produced by encodePoint.
@Override
public int bytesPerEncodedPoint() {
    return HalfFloatPoint.BYTES;
}

@Override
public Float parse(XContentParser parser, boolean coerce) throws IOException {
float parsed = parser.floatValue(coerce);
Expand Down Expand Up @@ -290,6 +303,19 @@ public Number parsePoint(byte[] value) {
return FloatPoint.decodeDimension(value, 0);
}

/**
 * Encodes the given value into the indexed BKD point representation for
 * float fields.
 *
 * @param value  the number to encode
 * @param coerce whether to leniently coerce the input, as in {@code parse}
 * @return the encoded point bytes, {@code Float.BYTES} long
 */
@Override
public byte[] encodePoint(Number value, boolean coerce) {
    float parsedValue = parse(value, coerce);
    // Float.BYTES == Integer.BYTES == 4, but naming the float constant makes
    // the buffer size self-documenting for a float point encoding.
    byte[] bytes = new byte[Float.BYTES];
    FloatPoint.encodeDimension(parsedValue, bytes, 0);
    return bytes;
}

// Per-dimension width of an encoded float point. Float.BYTES is identical in
// value to the original Integer.BYTES (4) but states the intent directly.
@Override
public int bytesPerEncodedPoint() {
    return Float.BYTES;
}

@Override
public Float parse(XContentParser parser, boolean coerce) throws IOException {
float parsed = parser.floatValue(coerce);
Expand Down Expand Up @@ -376,6 +402,19 @@ public Number parsePoint(byte[] value) {
return DoublePoint.decodeDimension(value, 0);
}

/**
 * Encodes the given value into the indexed BKD point representation for
 * double fields.
 *
 * @param value  the number to encode
 * @param coerce whether to leniently coerce the input, as in {@code parse}
 * @return the encoded point bytes, {@code Double.BYTES} long
 */
@Override
public byte[] encodePoint(Number value, boolean coerce) {
    double parsedValue = parse(value, coerce);
    // Double.BYTES == Long.BYTES == 8, but naming the double constant makes
    // the buffer size self-documenting for a double point encoding.
    byte[] bytes = new byte[Double.BYTES];
    DoublePoint.encodeDimension(parsedValue, bytes, 0);
    return bytes;
}

// Per-dimension width of an encoded double point. Double.BYTES is identical in
// value to the original Long.BYTES (8) but states the intent directly.
@Override
public int bytesPerEncodedPoint() {
    return Double.BYTES;
}

@Override
public Double parse(XContentParser parser, boolean coerce) throws IOException {
double parsed = parser.doubleValue(coerce);
Expand Down Expand Up @@ -473,6 +512,21 @@ public Number parsePoint(byte[] value) {
return INTEGER.parsePoint(value).byteValue();
}

/**
 * Encodes the given value into the indexed BKD point representation for
 * byte fields. Byte fields are indexed as integer points, so the value is
 * widened to an int and encoded via {@code IntPoint}.
 *
 * @param value  the number to encode
 * @param coerce whether to leniently coerce the input, as in {@code parse}
 * @return the encoded point bytes
 */
@Override
public byte[] encodePoint(Number value, boolean coerce) {
    // Widen to int and encode exactly as the INTEGER type does.
    int asInt = parse(value, coerce);
    byte[] encoded = new byte[Integer.BYTES];
    IntPoint.encodeDimension(asInt, encoded, 0);
    return encoded;
}

// Byte fields are encoded as integer points (see encodePoint above), so one
// encoded dimension is Integer.BYTES wide.
@Override
public int bytesPerEncodedPoint() {
    return Integer.BYTES;
}

@Override
public Short parse(XContentParser parser, boolean coerce) throws IOException {
int value = parser.intValue(coerce);
Expand Down Expand Up @@ -534,6 +588,21 @@ public Number parsePoint(byte[] value) {
return INTEGER.parsePoint(value).shortValue();
}

/**
 * Encodes the given value into the indexed BKD point representation for
 * short fields. Short fields are indexed as integer points, so the value is
 * widened to an int and encoded via {@code IntPoint}.
 *
 * @param value  the number to encode
 * @param coerce whether to leniently coerce the input, as in {@code parse}
 * @return the encoded point bytes
 */
@Override
public byte[] encodePoint(Number value, boolean coerce) {
    // Widen to int and encode exactly as the INTEGER type does.
    int asInt = parse(value, coerce);
    byte[] encoded = new byte[Integer.BYTES];
    IntPoint.encodeDimension(asInt, encoded, 0);
    return encoded;
}

// Short fields are encoded as integer points (see encodePoint above), so one
// encoded dimension is Integer.BYTES wide.
@Override
public int bytesPerEncodedPoint() {
    return Integer.BYTES;
}

@Override
public Short parse(XContentParser parser, boolean coerce) throws IOException {
return parser.shortValue(coerce);
Expand Down Expand Up @@ -591,6 +660,19 @@ public Number parsePoint(byte[] value) {
return IntPoint.decodeDimension(value, 0);
}

/**
 * Encodes the given value into the indexed BKD point representation for
 * integer fields.
 *
 * @param value  the number to encode
 * @param coerce whether to leniently coerce the input, as in {@code parse}
 * @return the encoded point bytes, {@code Integer.BYTES} long
 */
@Override
public byte[] encodePoint(Number value, boolean coerce) {
    byte[] encoded = new byte[Integer.BYTES];
    IntPoint.encodeDimension(parse(value, coerce), encoded, 0);
    return encoded;
}

// Per-dimension width of an encoded integer point; matches the array length
// produced by encodePoint.
@Override
public int bytesPerEncodedPoint() {
    return Integer.BYTES;
}

@Override
public Integer parse(XContentParser parser, boolean coerce) throws IOException {
return parser.intValue(coerce);
Expand Down Expand Up @@ -710,6 +792,19 @@ public Number parsePoint(byte[] value) {
return LongPoint.decodeDimension(value, 0);
}

/**
 * Encodes the given value into the indexed BKD point representation for
 * long fields.
 *
 * @param value  the number to encode
 * @param coerce whether to leniently coerce the input, as in {@code parse}
 * @return the encoded point bytes, {@code Long.BYTES} long
 */
@Override
public byte[] encodePoint(Number value, boolean coerce) {
    byte[] encoded = new byte[Long.BYTES];
    LongPoint.encodeDimension(parse(value, coerce), encoded, 0);
    return encoded;
}

// Per-dimension width of an encoded long point; matches the array length
// produced by encodePoint.
@Override
public int bytesPerEncodedPoint() {
    return Long.BYTES;
}

@Override
public Long parse(XContentParser parser, boolean coerce) throws IOException {
return parser.longValue(coerce);
Expand Down Expand Up @@ -827,6 +922,8 @@ public abstract Query rangeQuery(String field, Object lowerTerm, Object upperTer
public abstract Number parse(XContentParser parser, boolean coerce) throws IOException;
public abstract Number parse(Object value, boolean coerce);
public abstract Number parsePoint(byte[] value);
/** Encodes the given value into this type's indexed BKD point representation. */
public abstract byte[] encodePoint(Number value, boolean coerce);
/** Width in bytes of one encoded point dimension for this type. */
public abstract int bytesPerEncodedPoint();
public abstract List<Field> createFields(String name, Number value, boolean indexed,
boolean docValued, boolean stored);
Number valueForSearch(Number value) {
Expand Down Expand Up @@ -979,6 +1076,14 @@ public Number parsePoint(byte[] value) {
return type.parsePoint(value);
}

/** Delegates to the field's {@code NumberType} to encode a value as point bytes. */
public byte[] encodePoint(Number value, boolean coerce) {
    return type.encodePoint(value, coerce);
}

/** Delegates to the field's {@code NumberType} to report the encoded point width. */
public int bytesPerEncodedPoint() {
    return type.bytesPerEncodedPoint();
}

@Override
public boolean equals(Object o) {
if (super.equals(o) == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,26 @@ public AggParseContext(String name) {

public static final AggregatorFactories EMPTY = new AggregatorFactories(new AggregatorFactory[0], new ArrayList<>());

private AggregatorFactory[] factories;
protected AggregatorFactory[] factories;
private List<PipelineAggregationBuilder> pipelineAggregatorFactories;

public static Builder builder() {
return new Builder();
}

private AggregatorFactories(AggregatorFactory[] factories, List<PipelineAggregationBuilder> pipelineAggregators) {
protected AggregatorFactories(AggregatorFactory[] factories, List<PipelineAggregationBuilder> pipelineAggregators) {
this.factories = factories;
this.pipelineAggregatorFactories = pipelineAggregators;
}

// Accessor for the (non-pipeline) aggregator factories.
// NOTE(review): returns the internal array without a defensive copy — callers
// must treat it as read-only.
public AggregatorFactory[] getFactories() {
    return factories;
}

// Accessor for the pipeline aggregation builders.
// NOTE(review): returns the internal list without wrapping it unmodifiable —
// callers must treat it as read-only.
public List<PipelineAggregationBuilder> getPipelineAggregatorFactories() {
    return pipelineAggregatorFactories;
}

public List<PipelineAggregator> createPipelineAggregators() {
List<PipelineAggregator> pipelineAggregators = new ArrayList<>(this.pipelineAggregatorFactories.size());
for (PipelineAggregationBuilder factory : this.pipelineAggregatorFactories) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@
public abstract class AggregatorFactory {

public static final class MultiBucketAggregatorWrapper extends Aggregator {
private final BigArrays bigArrays;
private final Aggregator parent;
private final AggregatorFactory factory;
protected final BigArrays bigArrays;
protected final AggregatorFactory factory;
protected ObjectArray<Aggregator> aggregators;
protected ObjectArray<LeafBucketCollector> collectors;
protected final Aggregator parent;
private final Aggregator first;
ObjectArray<Aggregator> aggregators;
ObjectArray<LeafBucketCollector> collectors;

MultiBucketAggregatorWrapper(BigArrays bigArrays, SearchContext context,
Aggregator parent, AggregatorFactory factory, Aggregator first) {
Aggregator parent, AggregatorFactory factory, Aggregator first) {
this.bigArrays = bigArrays;
this.parent = parent;
this.factory = factory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@

package org.elasticsearch.search.aggregations.bucket.range;

import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.elasticsearch.index.mapper.DateFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
Expand All @@ -31,26 +36,29 @@
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.profile.aggregation.ProfilingAggregator;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

public class AbstractRangeAggregatorFactory<R extends Range> extends ValuesSourceAggregatorFactory<ValuesSource.Numeric> {

private final InternalRange.Factory<?, ?> rangeFactory;
private final R[] ranges;
private final boolean keyed;

public AbstractRangeAggregatorFactory(String name,
ValuesSourceConfig<Numeric> config,
R[] ranges,
boolean keyed,
InternalRange.Factory<?, ?> rangeFactory,
QueryShardContext queryShardContext,
AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
AbstractRangeAggregatorFactory(String name,
ValuesSourceConfig<Numeric> config,
R[] ranges,
boolean keyed,
InternalRange.Factory<?, ?> rangeFactory,
QueryShardContext queryShardContext,
AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
super(name, config, queryShardContext, parent, subFactoriesBuilder, metaData);
this.ranges = ranges;
this.keyed = keyed;
Expand All @@ -59,22 +67,92 @@ public AbstractRangeAggregatorFactory(String name,

@Override
protected Aggregator createUnmapped(SearchContext searchContext,
Aggregator parent,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
Aggregator parent,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
return new Unmapped<>(name, ranges, keyed, config.format(), searchContext, parent, rangeFactory, pipelineAggregators, metaData);
}

@Override
protected Aggregator doCreateInternal(Numeric valuesSource,
SearchContext searchContext,
Aggregator parent,
boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
return new RangeAggregator(name, factories, valuesSource, config.format(), rangeFactory, ranges, keyed, searchContext, parent,
pipelineAggregators, metaData);
SearchContext searchContext,
Aggregator parent,
boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {

// If we don't have a parent, the range agg can potentially optimize by using the BKD tree. But BKD
// traversal is per-range, which means that docs are potentially called out-of-order across multiple
// ranges. To prevent this from causing problems, we create a special AggregatorFactories that
// wraps all the sub-aggs with a MultiBucketAggregatorWrapper. This effectively creates a new agg
// sub-tree for each range and prevents out-of-order problems
BiFunction<Number, Boolean, byte[]> pointEncoder = configurePointEncoder(searchContext, parent, config);
AggregatorFactories wrappedFactories = factories;
if (pointEncoder != null) {
wrappedFactories = wrapSubAggsAsMultiBucket(factories);
}
polyfractal marked this conversation as resolved.
Show resolved Hide resolved

return new RangeAggregator(name, wrappedFactories, valuesSource, config, rangeFactory, ranges, keyed, searchContext, parent,
pipelineAggregators, metaData, pointEncoder);
}

/**
 * Returns an encoder from parsed numbers to BKD point bytes when the BKD-based
 * range optimization is applicable, or <code>null</code> otherwise. The
 * optimization is only attempted when all of the following hold:
 * <ul>
 *   <li>the top-level query is absent or a match_all query</li>
 *   <li>this aggregation has no parent aggregator</li>
 *   <li>the values source is a plain field (no script, no missing value)</li>
 *   <li>the field is a number field, or a plain date field, with indexed points</li>
 * </ul>
 *
 * @param context The {@link SearchContext} of the aggregation.
 * @param parent  The parent aggregator, or {@code null} at the top level.
 * @param config  The config for the values source.
 */
private BiFunction<Number, Boolean, byte[]> configurePointEncoder(SearchContext context, Aggregator parent,
                                                                  ValuesSourceConfig<?> config) {
    // A filtering (non-match_all) query means not every indexed point is a hit,
    // so counting points straight off the BKD tree would be wrong.
    if (context.query() != null && context.query().getClass() != MatchAllDocsQuery.class) {
        return null;
    }
    // Under a parent agg, the per-bucket doc set differs from the whole index.
    if (parent != null) {
        return null;
    }
    // Only plain field accesses qualify: scripts and missing-value substitution
    // change the values seen per document.
    if (config.fieldContext() == null || config.script() != null || config.missing() != null) {
        return null;
    }
    MappedFieldType fieldType = config.fieldContext().fieldType();
    if (fieldType == null || fieldType.indexOptions() == IndexOptions.NONE) {
        return null;
    }
    if (fieldType instanceof NumberFieldMapper.NumberFieldType) {
        return ((NumberFieldMapper.NumberFieldType) fieldType)::encodePoint;
    }
    if (fieldType.getClass() == DateFieldMapper.DateFieldType.class) {
        // Exact class check (not instanceof) so subclasses are excluded;
        // plain date fields index their values as longs.
        return NumberFieldMapper.NumberType.LONG::encodePoint;
    }
    return null;
}

/**
 * Creates a new {@link AggregatorFactories} object so that sub-aggs are automatically
 * wrapped with a {@link org.elasticsearch.search.aggregations.AggregatorFactory.MultiBucketAggregatorWrapper}.
 * This allows each sub-agg to execute in its own isolated sub tree.
 */
private static AggregatorFactories wrapSubAggsAsMultiBucket(AggregatorFactories factories) {
    return new AggregatorFactories(factories.getFactories(), factories.getPipelineAggregatorFactories()) {
        @Override
        public Aggregator[] createSubAggregators(SearchContext searchContext, Aggregator parent) throws IOException {
            Aggregator[] subAggregators = new Aggregator[countAggregators()];
            for (int i = 0; i < this.factories.length; ++i) {
                // Wrap each sub-agg so it collects into its own bucket tree.
                Aggregator wrapped = asMultiBucketAggregator(this.factories[i], searchContext, parent);
                // Keep profiling behavior consistent with the default implementation.
                Profilers profilers = wrapped.context().getProfilers();
                if (profilers != null) {
                    wrapped = new ProfilingAggregator(wrapped, profilers.getAggregationProfiler());
                }
                subAggregators[i] = wrapped;
            }
            return subAggregators;
        }
    };
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ protected Aggregator doCreateInternal(final ValuesSource.GeoPoint valuesSource,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
DistanceSource distanceSource = new DistanceSource(valuesSource, distanceType, origin, unit);
return new RangeAggregator(name, factories, distanceSource, config.format(), rangeFactory, ranges, keyed, searchContext,
parent,
pipelineAggregators, metaData);
return new RangeAggregator(name, factories, distanceSource, config, rangeFactory, ranges, keyed, searchContext,
parent, pipelineAggregators, metaData, null);
}

private static class DistanceSource extends ValuesSource.Numeric {
Expand Down
Loading