Skip to content

Commit

Permalink
add rrf retriever parsing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jdconrad committed Feb 29, 2024
1 parent cdf1b5d commit baf0c0c
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ public KnnRetrieverBuilder(

// ---- FOR TESTING XCONTENT PARSING ----

@Override
public String getName() {
return NAME;
}

@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(field, queryVector, queryVectorBuilder, k, numCands, similarity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ public List<QueryBuilder> getPreFilterQueryBuilders() {

// ---- FOR TESTING XCONTENT PARSING ----

public abstract String getName();

@Override
public final XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
Expand All @@ -199,7 +201,7 @@ public final XContentBuilder toXContent(XContentBuilder builder, ToXContent.Para
return builder;
}

public abstract void doToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException;
protected abstract void doToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException;

@Override
public boolean isFragment() {
Expand All @@ -214,14 +216,14 @@ public final boolean equals(Object o) {
return Objects.equals(preFilterQueryBuilders, that.preFilterQueryBuilders) && doEquals(o);
}

public abstract boolean doEquals(Object o);
protected abstract boolean doEquals(Object o);

@Override
public final int hashCode() {
return Objects.hash(getClass(), preFilterQueryBuilders, doHashCode());
}

public abstract int doHashCode();
protected abstract int doHashCode();

@Override
public String toString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder

// ---- FOR TESTING XCONTENT PARSING ----

@Override
public String getName() {
return NAME;
}

@Override
public void doToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
if (queryBuilder != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,29 @@
package org.elasticsearch.search.retriever;

import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

import static org.elasticsearch.search.vectors.KnnSearchBuilderTests.randomVector;

public class KnnRetrieverParsingTests extends AbstractXContentTestCase<KnnRetrieverBuilder> {
public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase<KnnRetrieverBuilder> {

/**
* Creates a random {@link KnnRetrieverBuilder}. The created instance
* is not guaranteed to pass {@link SearchRequest} validation. This is purely
* for x-content testing.
*/
public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder(BiFunction<XContent, BytesReference, XContentParser> createParser) {
public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() {
String field = randomAlphaOfLength(6);
int dim = randomIntBetween(2, 30);
float[] vector = randomBoolean() ? null : randomVector(dim);
Expand All @@ -60,13 +56,7 @@ public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder(BiFunction<XCo

@Override
protected KnnRetrieverBuilder createTestInstance() {
return createRandomKnnRetrieverBuilder((xContent, data) -> {
try {
return createParser(xContent, data);
} catch (IOException ioe) {
throw new UncheckedIOException(ioe);
}
});
return createRandomKnnRetrieverBuilder();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import java.util.List;
import java.util.function.BiFunction;

public class StandardRetrieverParsingTests extends AbstractXContentTestCase<StandardRetrieverBuilder> {
public class StandardRetrieverBuilderParsingTests extends AbstractXContentTestCase<StandardRetrieverBuilder> {

/**
* Creates a random {@link StandardRetrieverBuilder}. The created instance
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.retriever;

import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

/**
* Test retriever is used to test parsing of retrievers in plugins where
* generation of other random retrievers are not easily accessible through test code.
*/
public class TestRetrieverBuilder extends RetrieverBuilder {

/**
* Creates a random {@link TestRetrieverBuilder}. The created instance
* is not guaranteed to pass {@link SearchRequest} validation. This is purely
* for x-content testing.
*/
public static TestRetrieverBuilder createRandomTestRetrieverBuilder() {
return new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10));
}

public static final String NAME = "test";
public static final ParseField TEST_FIELD = new ParseField(NAME);
public static final SearchPlugin.RetrieverSpec<TestRetrieverBuilder> TEST_SPEC = new SearchPlugin.RetrieverSpec<>(
TEST_FIELD,
TestRetrieverBuilder::fromXContent
);

public static final ParseField VALUE_FIELD = new ParseField("value");

public static final ConstructingObjectParser<TestRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
NAME,
args -> new TestRetrieverBuilder((String) args[0])
);

static {
PARSER.declareString(constructorArg(), VALUE_FIELD);
}

public static TestRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) {
return PARSER.apply(parser, context);
}

private final String value;

public TestRetrieverBuilder(String value) {
this.value = value;
}

@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
throw new UnsupportedOperationException("only used for parsing tests");
}

@Override
public String getName() {
return NAME;
}

@Override
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(VALUE_FIELD.getPreferredName(), value);
}

@Override
public boolean doEquals(Object o) {
TestRetrieverBuilder that = (TestRetrieverBuilder) o;
return Objects.equals(value, that.value);
}

@Override
public int doHashCode() {
return Objects.hash(value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.rank.rrf.RRFRankPlugin.NAME;

/**
* An rrf retriever is used to represent an rrf rank element, but
* as a tree-like structure. This retriever is a compound retriever
Expand All @@ -40,7 +42,7 @@ public final class RRFRetrieverBuilder extends RetrieverBuilder {
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");

public static final ObjectParser<RRFRetrieverBuilder, RetrieverParserContext> PARSER = new ObjectParser<>(
RRFRankPlugin.NAME,
NAME,
RRFRetrieverBuilder::new
);

Expand All @@ -55,22 +57,22 @@ public final class RRFRetrieverBuilder extends RetrieverBuilder {
PARSER.declareInt((r, v) -> r.windowSize = v, WINDOW_SIZE_FIELD);
PARSER.declareInt((r, v) -> r.rankConstant = v, RANK_CONSTANT_FIELD);

RetrieverBuilder.declareBaseParserFields(RRFRankPlugin.NAME, PARSER);
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
}

public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
if (context.clusterSupportsFeature(RRF_RETRIEVER_SUPPORTED) == false) {
throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + RRFRankPlugin.NAME + "]");
throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]");
}
if (RRFRankPlugin.RANK_RRF_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
throw LicenseUtils.newComplianceException("Reciprocal Rank Fusion (RRF)");
}
return PARSER.apply(parser, context);
}

private List<RetrieverBuilder> retrieverBuilders = Collections.emptyList();
private int windowSize = RRFRankBuilder.DEFAULT_WINDOW_SIZE;
private int rankConstant = RRFRankBuilder.DEFAULT_RANK_CONSTANT;
List<RetrieverBuilder> retrieverBuilders = Collections.emptyList();
int windowSize = RRFRankBuilder.DEFAULT_WINDOW_SIZE;
int rankConstant = RRFRankBuilder.DEFAULT_RANK_CONSTANT;

@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
Expand All @@ -91,10 +93,24 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder

// ---- FOR TESTING XCONTENT PARSING ----

@Override
public String getName() {
return NAME;
}

@Override
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
if (retrieverBuilders.isEmpty() == false) {
builder.field(RETRIEVERS_FIELD.getPreferredName(), retrieverBuilders);
builder.startArray(RETRIEVERS_FIELD.getPreferredName());

for (RetrieverBuilder retrieverBuilder : retrieverBuilders) {
builder.startObject();
builder.field(retrieverBuilder.getName());
retrieverBuilder.toXContent(builder, params);
builder.endObject();
}

builder.endArray();
}

builder.field(WINDOW_SIZE_FIELD.getPreferredName(), windowSize);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.xpack.rank.rrf;

import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RRFRetrieverBuilder> {

/**
* Creates a random {@link RRFRetrieverBuilder}. The created instance
* is not guaranteed to pass {@link SearchRequest} validation. This is purely
* for x-content testing.
*/
public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() {
RRFRetrieverBuilder rrfRetrieverBuilder = new RRFRetrieverBuilder();

if (randomBoolean()) {
rrfRetrieverBuilder.windowSize = randomIntBetween(1, 10000);
}

if (randomBoolean()) {
rrfRetrieverBuilder.rankConstant = randomIntBetween(1, 1000000);
}

int retrieverCount = randomIntBetween(2, 50);
rrfRetrieverBuilder.retrieverBuilders = new ArrayList<>(retrieverCount);

while (retrieverCount > 0) {
rrfRetrieverBuilder.retrieverBuilders.add(TestRetrieverBuilder.createRandomTestRetrieverBuilder());
--retrieverCount;
}

return rrfRetrieverBuilder;
}

@Override
protected RRFRetrieverBuilder createTestInstance() {
return createRandomRRFRetrieverBuilder();
}

@Override
protected RRFRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
return RRFRetrieverBuilder.PARSER.apply(parser, new RetrieverParserContext(new SearchUsage(), nf -> true));
}

@Override
protected boolean supportsUnknownFields() {
return false;
}

@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> entries = new ArrayList<>();
entries.add(
new NamedXContentRegistry.Entry(
RetrieverBuilder.class,
TestRetrieverBuilder.TEST_SPEC.getName(),
(p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c),
TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion()
)
);
entries.add(
new NamedXContentRegistry.Entry(
RetrieverBuilder.class,
new ParseField(RRFRankPlugin.NAME),
(p, c) -> RRFRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
)
);
return new NamedXContentRegistry(entries);
}
}

0 comments on commit baf0c0c

Please sign in to comment.