Skip to content

Commit

Permalink
range aggs draft changes
Browse files Browse the repository at this point in the history
Signed-off-by: Sandesh Kumar <sandeshkr419@gmail.com>
  • Loading branch information
sandeshkr419 committed Feb 6, 2025
1 parent b82ed5a commit 4e15156
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
package org.opensearch.search.aggregations.bucket.range;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -43,7 +45,13 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.compositeindex.datacube.MetricStat;
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils;
import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
Expand All @@ -53,12 +61,17 @@
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.NonCollectingAggregator;
import org.opensearch.search.aggregations.StarTreeBucketCollector;
import org.opensearch.search.aggregations.StarTreePreComputeCollector;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
import org.opensearch.search.aggregations.bucket.filterrewrite.FilterRewriteOptimizationContext;
import org.opensearch.search.aggregations.bucket.filterrewrite.RangeAggregatorBridge;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.StarTreeQueryHelper;
import org.opensearch.search.startree.StarTreeTraversalUtil;
import org.opensearch.search.startree.filter.DimensionFilter;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -69,17 +82,18 @@
import java.util.function.Function;

import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.opensearch.search.aggregations.bucket.filterrewrite.AggregatorBridge.segmentMatchAll;
import static org.opensearch.search.startree.StarTreeQueryHelper.getSupportedStarTree;

/**
* Aggregate all docs that match given ranges.
*
* @opensearch.internal
*/
public class RangeAggregator extends BucketsAggregator {
public class RangeAggregator extends BucketsAggregator implements StarTreePreComputeCollector {

public static final ParseField RANGES_FIELD = new ParseField("ranges");
public static final ParseField KEYED_FIELD = new ParseField("keyed");
public final String fieldName;

/**
* Range for the range aggregator
Expand Down Expand Up @@ -298,6 +312,9 @@ protected Function<Object, Long> bucketOrdProducer() {
}
};
filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context);
this.fieldName = (valuesSource instanceof ValuesSource.Numeric.FieldData)
? ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName()
: null;
}

@Override
Expand All @@ -310,8 +327,15 @@ public ScoreMode scoreMode() {

@Override
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
if (segmentMatchAll(context, ctx)) {
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
// TODO: commenting out match all otpimization only for testing - will restore later
// if (segmentMatchAll(context, ctx)) {
// return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
// }
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
if (supportedStarTree != null) {
if (preComputeWithStarTree(ctx, supportedStarTree) == true) {
return true;
}
}
return false;
}
Expand Down Expand Up @@ -383,6 +407,141 @@ private int collect(int doc, double value, long owningBucketOrdinal, int lowBoun
};
}

private boolean preComputeWithStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
StarTreeBucketCollector starTreeBucketCollector = getStarTreeBucketCollector(ctx, starTree, null);
FixedBitSet matchingDocsBitSet = starTreeBucketCollector.getMatchingDocsBitSet();

int numBits = matchingDocsBitSet.length();

if (numBits > 0) {
for (int bit = matchingDocsBitSet.nextSetBit(0); bit != DocIdSetIterator.NO_MORE_DOCS; bit = (bit + 1 < numBits)
? matchingDocsBitSet.nextSetBit(bit + 1)
: DocIdSetIterator.NO_MORE_DOCS) {
starTreeBucketCollector.collectStarTreeEntry(bit, 0);
}
}
return true;
}

@Override
public StarTreeBucketCollector getStarTreeBucketCollector(
LeafReaderContext ctx,
CompositeIndexFieldInfo starTree,
StarTreeBucketCollector parentCollector
) throws IOException {
assert parentCollector == null;
StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree);
return new StarTreeBucketCollector(
starTreeValues,
StarTreeTraversalUtil.getStarTreeResult(
starTreeValues,
StarTreeQueryHelper.mergeDimensionFilterIfNotExists(
context.getQueryShardContext().getStarTreeQueryContext().getBaseQueryStarTreeFilter(),
fieldName,
List.of(DimensionFilter.MATCH_ALL_DEFAULT)
),
context
)
) {
@Override
public void setSubCollectors() throws IOException {
for (Aggregator aggregator : subAggregators) {
this.subCollectors.add(((StarTreePreComputeCollector) aggregator).getStarTreeBucketCollector(ctx, starTree, this));
}
}

SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
.getDimensionValuesIterator(fieldName);

String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(
starTree.getField(),
"_doc_count",
MetricStat.DOC_COUNT.getTypeName()
);

SortedNumericStarTreeValuesIterator docCountsIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
.getMetricValuesIterator(metricName);

@Override
public void collectStarTreeEntry(int starTreeEntry, long owningBucketOrd) throws IOException {
if (valuesIterator.advanceExact(starTreeEntry) == false) {
return;
}

for (int i = 0, count = valuesIterator.entryValueCount(); i < count; i++) {
long dimensionLongValue = valuesIterator.nextValue();
double dimensionValue;

// Only numeric & floating points are supported as of now in star-tree
// TODO: Add support for isBigInteger() when it gets supported in star-tree
if (valuesSource.isFloatingPoint()) {
dimensionValue = ((NumberFieldMapper.NumberFieldType) context.mapperService().fieldType(fieldName)).toDoubleValue(
dimensionLongValue
);
} else {
dimensionValue = dimensionLongValue;
}

// The core logic remains largely the same as the original collect method,
// but adapted for star-tree entry processing.
int lo = 0, hi = ranges.length - 1;
int mid = (lo + hi) >>> 1;

while (lo <= hi) {
if (dimensionValue < ranges[mid].from) {
hi = mid - 1;
} else if (dimensionValue >= maxTo[mid]) {
lo = mid + 1;
} else {
break;
}
mid = (lo + hi) >>> 1;
}

if (lo > hi) continue; // No matching range

// binary search the lower bound
int startLo = lo, startHi = mid;
while (startLo <= startHi) {
final int startMid = (startLo + startHi) >>> 1;
if (dimensionValue >= maxTo[startMid]) {
startLo = startMid + 1;
} else {
startHi = startMid - 1;
}
}

// binary search the upper bound
int endLo = mid, endHi = hi;
while (endLo <= endHi) {
final int endMid = (endLo + endHi) >>> 1;
if (dimensionValue < ranges[endMid].from) {
endHi = endMid - 1;
} else {
endLo = endMid + 1;
}
}

if (docCountsIterator.advanceExact(starTreeEntry)) {
long metricValue = docCountsIterator.nextValue();
for (int j = startLo; j <= endHi; ++j) {
if (ranges[j].matches(dimensionValue)) {
long bucketOrd = subBucketOrdinal(owningBucketOrd, j);
if (bucketOrd < 0) {
bucketOrd = -1 - bucketOrd;
collectStarTreeBucket(this, metricValue, bucketOrd, starTreeEntry);
} else {
grow(bucketOrd + 1);
collectStarTreeBucket(this, metricValue, bucketOrd, starTreeEntry);
}
}
}
}
}
}
};
}

private long subBucketOrdinal(long owningBucketOrdinal, int rangeOrd) {
return owningBucketOrdinal * ranges.length + rangeOrd;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregatorFactory;
import org.opensearch.search.aggregations.bucket.range.RangeAggregatorFactory;
import org.opensearch.search.aggregations.metrics.MetricAggregatorFactory;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.filter.StarTreeFilter;
Expand Down Expand Up @@ -113,6 +114,11 @@ public boolean consolidateAllFilters(SearchContext context) {
if (validateDateHistogramSupport(compositeMappedFieldType, aggregatorFactory)) {
continue;
}

// test for range aggregation
if (aggregatorFactory instanceof RangeAggregatorFactory) {
continue;
}
return false;
}

Expand Down

0 comments on commit 4e15156

Please sign in to comment.