Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ML] improve frequent items runtime #93255

Merged
merged 3 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/93255.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 93255
summary: Improve frequent items runtime
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public abstract class ItemSetMapReduceAggregator<
ReduceContext extends Closeable,
Result extends ToXContent & Writeable> extends AggregatorBase {

private final List<ItemSetMapReduceValueSource> extractors;
private final List<ItemSetMapReduceValueSource> valueSources;
private final Weight weightDocumentFilter;
private final List<Field> fields;
private final AbstractItemSetMapReducer<MapContext, MapFinalContext, ReduceContext, Result> mapReducer;
Expand All @@ -69,7 +69,7 @@ protected ItemSetMapReduceAggregator(
) throws IOException {
super(name, AggregatorFactories.EMPTY, context, parent, CardinalityUpperBound.NONE, metadata);

List<ItemSetMapReduceValueSource> extractors = new ArrayList<>();
List<ItemSetMapReduceValueSource> valueSources = new ArrayList<>();
List<Field> fields = new ArrayList<>();
IndexSearcher contextSearcher = context.searcher();

Expand All @@ -84,11 +84,11 @@ protected ItemSetMapReduceAggregator(
.build(c.v1(), id++, c.v2());
if (e.getField().getName() != null) {
fields.add(e.getField());
extractors.add(e);
valueSources.add(e);
}
}

this.extractors = Collections.unmodifiableList(extractors);
this.valueSources = Collections.unmodifiableList(valueSources);
this.fields = Collections.unmodifiableList(fields);
this.mapReducer = mapReducer;
this.profiling = context.profiling();
Expand Down Expand Up @@ -126,14 +126,19 @@ protected LeafBucketCollector getLeafCollector(AggregationExecutionContext ctx,
)
: null;

List<ItemSetMapReduceValueSource.ValueCollector> valueCollectors = new ArrayList<>(valueSources.size());
for (ItemSetMapReduceValueSource valueSource : valueSources) {
valueCollectors.add(valueSource.getValueCollector(ctx.getLeafReaderContext()));
}

return new LeafBucketCollectorBase(sub, null) {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
SetOnce<IOException> firstException = new SetOnce<>();
if (bits == null || bits.get(doc)) {
mapReducer.map(extractors.stream().map(extractor -> {
mapReducer.map(valueCollectors.stream().map(c -> {
try {
return extractor.collect(ctx.getLeafReaderContext(), doc);
return c.collect(doc);
} catch (IOException e) {
firstException.trySet(e);
// ignored in AbstractMapReducer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,21 @@
*/
public abstract class ItemSetMapReduceValueSource {

/**
* Interface to hook value collection into the {@link org.elasticsearch.search.aggregations.support.ValuesSourceRegistry}.
*
* An implementation builds the {@link ItemSetMapReduceValueSource} for a given values source configuration.
*/
@FunctionalInterface
public interface ValueSourceSupplier {
/**
* Builds a value source.
*
* @param config the values source configuration to build the value source from
* @param id numeric id of the value source; the aggregator in this change passes an incrementing counter
* @param includeExclude include/exclude filter for the collected values
*        (NOTE(review): presumably may be null when no filter is configured -- confirm against callers)
* @return the value source built for the given configuration
*/
ItemSetMapReduceValueSource build(ValuesSourceConfig config, int id, IncludeExclude includeExclude);
}

/**
* Internal interface for collecting the values of a single document.
*
* Collectors are obtained from {@code getValueCollector(LeafReaderContext)}, so the doc values
* lookup for a leaf happens when the collector is created rather than on every collected document.
*/
interface ValueCollector {
/**
* Collects the values of the given document.
*
* @param doc the document id; the implementations in this file pass it to {@code advanceExact} on the leaf's doc values
* @return a tuple of the field and its collected values; the implementations in this file return
*         an empty list when the document has no values
* @throws IOException if reading the underlying doc values fails
*/
Tuple<Field, List<Object>> collect(int doc) throws IOException;
}

enum ValueFormatter {
BYTES_REF {
@Override
Expand Down Expand Up @@ -125,7 +135,7 @@ public int hashCode() {

private final Field field;

abstract Tuple<Field, List<Object>> collect(LeafReaderContext ctx, int doc) throws IOException;
abstract ValueCollector getValueCollector(LeafReaderContext ctx) throws IOException;

ItemSetMapReduceValueSource(ValuesSourceConfig config, int id, ValueFormatter valueFormatter) {
String fieldName = config.fieldContext() != null ? config.fieldContext().field() : null;
Expand All @@ -151,22 +161,31 @@ public KeywordValueSource(ValuesSourceConfig config, int id, IncludeExclude incl
}

@Override
public Tuple<Field, List<Object>> collect(LeafReaderContext ctx, int doc) throws IOException {
SortedBinaryDocValues values = source.bytesValues(ctx);
ValueCollector getValueCollector(LeafReaderContext ctx) throws IOException {
final SortedBinaryDocValues values = source.bytesValues(ctx);
final Field field = getField();
final Tuple<Field, List<Object>> empty = new Tuple<>(field, Collections.emptyList());

return doc -> {
if (values.advanceExact(doc)) {
int valuesCount = values.docValueCount();

if (valuesCount == 0) {
return empty;
}

if (values.advanceExact(doc)) {
int valuesCount = values.docValueCount();
List<Object> objects = new ArrayList<>(valuesCount);
List<Object> objects = new ArrayList<>(valuesCount);

for (int i = 0; i < valuesCount; ++i) {
BytesRef v = values.nextValue();
if (stringFilter == null || stringFilter.accept(v)) {
objects.add(BytesRef.deepCopyOf(v));
for (int i = 0; i < valuesCount; ++i) {
BytesRef v = values.nextValue();
if (stringFilter == null || stringFilter.accept(v)) {
objects.add(BytesRef.deepCopyOf(v));
}
}
return new Tuple<>(field, objects);
}
return new Tuple<>(getField(), objects);
}
return new Tuple<>(getField(), Collections.emptyList());
return empty;
};
}

}
Expand All @@ -182,22 +201,31 @@ public NumericValueSource(ValuesSourceConfig config, int id, IncludeExclude incl
}

@Override
public Tuple<Field, List<Object>> collect(LeafReaderContext ctx, int doc) throws IOException {
SortedNumericDocValues values = source.longValues(ctx);
ValueCollector getValueCollector(LeafReaderContext ctx) throws IOException {
final SortedNumericDocValues values = source.longValues(ctx);
final Field field = getField();
final Tuple<Field, List<Object>> empty = new Tuple<>(field, Collections.emptyList());

if (values.advanceExact(doc)) {
int valuesCount = values.docValueCount();
List<Object> objects = new ArrayList<>(valuesCount);
return doc -> {
if (values.advanceExact(doc)) {
int valuesCount = values.docValueCount();

for (int i = 0; i < valuesCount; ++i) {
long v = values.nextValue();
if (longFilter == null || longFilter.accept(v)) {
objects.add(v);
if (valuesCount == 0) {
return empty;
}

List<Object> objects = new ArrayList<>(valuesCount);

for (int i = 0; i < valuesCount; ++i) {
long v = values.nextValue();
if (longFilter == null || longFilter.accept(v)) {
objects.add(v);
}
}
return new Tuple<>(field, objects);
}
return new Tuple<>(getField(), objects);
}
return new Tuple<>(getField(), Collections.emptyList());
return empty;
};
}

}
Expand Down