Skip to content

Commit

Permalink
Added support for msearch API to pass search pipeline name (#15923)
Browse files Browse the repository at this point in the history
* Added support for search pipeline name in multi search API

Signed-off-by: Owais <owaiskazi19@gmail.com>

* Updated CHANGELOG

Signed-off-by: Owais <owaiskazi19@gmail.com>

* Pulled search pipeline in MultiSearchRequest and updated test

Signed-off-by: Owais <owaiskazi19@gmail.com>

* Updated test

Signed-off-by: Owais <owaiskazi19@gmail.com>

* Updated SearchRequest with search pipeline from source

Signed-off-by: Owais <owaiskazi19@gmail.com>

* Added tests for parseSearchRequest

Signed-off-by: Owais <owaiskazi19@gmail.com>

* Guard serialization with version check

Signed-off-by: Owais <owaiskazi19@gmail.com>

* Updated version and added another test for serialization

Signed-off-by: Owais <owaiskazi19@gmail.com>

---------

Signed-off-by: Owais <owaiskazi19@gmail.com>
  • Loading branch information
owaiskazi19 authored Sep 26, 2024
1 parent a42e51d commit daf1669
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Implement WithFieldName interface in ValuesSourceAggregationBuilder & FieldSortBuilder ([#15916](https://github.com/opensearch-project/OpenSearch/pull/15916))
- Add successfulSearchShardIndices in searchRequestContext ([#15967](https://github.com/opensearch-project/OpenSearch/pull/15967))
- Remove identity-related feature flagged code from the RestController ([#15430](https://github.com/opensearch-project/OpenSearch/pull/15430))
- Add support for msearch API to pass search pipeline name - ([#15923](https://github.com/opensearch-project/OpenSearch/pull/15923))

### Dependencies
- Bump `com.azure:azure-identity` from 1.13.0 to 1.13.2 ([#15578](https://github.com/opensearch-project/OpenSearch/pull/15578))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ public static void readMultiLineFormat(
) {
consumer.accept(searchRequest, parser);
}

if (searchRequest.source() != null && searchRequest.source().pipeline() != null) {
searchRequest.pipeline(searchRequest.source().pipeline());
}
// move pointers
from = nextMarker + 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ public static void parseSearchRequest(
searchRequest.routing(request.param("routing"));
searchRequest.preference(request.param("preference"));
searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions()));
searchRequest.pipeline(request.param("search_pipeline"));
searchRequest.pipeline(request.param("search_pipeline", searchRequest.source().pipeline()));

checkRestTotalHits(request, searchRequest);
request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ public static HighlightBuilder highlight() {

private Map<String, Object> searchPipelineSource = null;

private String searchPipeline;

/**
* Constructs a new search source builder.
*/
Expand Down Expand Up @@ -297,6 +299,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException {
derivedFields = in.readList(DerivedField::new);
}
}
if (in.getVersion().onOrAfter(Version.V_3_0_0)) {
searchPipeline = in.readOptionalString();
}
}

@Override
Expand Down Expand Up @@ -377,6 +382,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeList(derivedFields);
}
}
if (out.getVersion().onOrAfter(Version.V_3_0_0)) {
out.writeOptionalString(searchPipeline);
}
}

/**
Expand Down Expand Up @@ -1111,6 +1119,13 @@ public Map<String, Object> searchPipelineSource() {
return searchPipelineSource;
}

/**
* @return a search pipeline name defined within the search source (see {@link org.opensearch.search.pipeline.SearchPipelineService})
*/
public String pipeline() {
return searchPipeline;
}

/**
* Define a search pipeline to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}.
*/
Expand All @@ -1119,6 +1134,14 @@ public SearchSourceBuilder searchPipelineSource(Map<String, Object> searchPipeli
return this;
}

/**
* Define a search pipeline name to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}.
*/
public SearchSourceBuilder pipeline(String searchPipeline) {
this.searchPipeline = searchPipeline;
return this;
}

/**
* Rewrites this search source builder into its primitive form. e.g. by
* rewriting the QueryBuilder. If the builder did not change the identity
Expand Down Expand Up @@ -1216,6 +1239,7 @@ private SearchSourceBuilder shallowCopy(
rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder;
rewrittenBuilder.derivedFieldsObject = derivedFieldsObject;
rewrittenBuilder.derivedFields = derivedFields;
rewrittenBuilder.searchPipeline = searchPipeline;
return rewrittenBuilder;
}

Expand Down Expand Up @@ -1283,6 +1307,8 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th
sort(parser.text());
} else if (PROFILE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
profile = parser.booleanValue();
} else if (SEARCH_PIPELINE.match(currentFieldName, parser.getDeprecationHandler())) {
searchPipeline = parser.text();
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -1612,6 +1638,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t

}

if (searchPipeline != null) {
builder.field(SEARCH_PIPELINE.getPreferredName(), searchPipeline);
}

return builder;
}

Expand Down Expand Up @@ -1889,7 +1919,8 @@ public int hashCode() {
trackTotalHitsUpTo,
pointInTimeBuilder,
derivedFieldsObject,
derivedFields
derivedFields,
searchPipeline
);
}

Expand Down Expand Up @@ -1934,7 +1965,8 @@ public boolean equals(Object obj) {
&& Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo)
&& Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder)
&& Objects.equals(derivedFieldsObject, other.derivedFieldsObject)
&& Objects.equals(derivedFields, other.derivedFields);
&& Objects.equals(derivedFields, other.derivedFields)
&& Objects.equals(searchPipeline, other.searchPipeline);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
import org.opensearch.geometry.LinearRing;
import org.opensearch.index.query.GeoShapeQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.search.RestSearchAction;
import org.opensearch.search.AbstractSearchTestCase;
import org.opensearch.search.Scroll;
import org.opensearch.search.builder.PointInTimeBuilder;
Expand All @@ -50,14 +52,18 @@
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.VersionUtils;
import org.opensearch.test.rest.FakeRestRequest;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntConsumer;

import static java.util.Collections.emptyMap;
import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
import static org.opensearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;

public class SearchRequestTests extends AbstractSearchTestCase {

Expand Down Expand Up @@ -242,6 +248,19 @@ public void testCopyConstructor() throws IOException {
assertNotSame(deserializedRequest, searchRequest);
}

public void testParseSearchRequestWithUnsupportedSearchType() throws IOException {
RestRequest restRequest = new FakeRestRequest();
SearchRequest searchRequest = createSearchRequest();
IntConsumer setSize = mock(IntConsumer.class);
restRequest.params().put("search_type", "query_and_fetch");

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize)
);
assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage());
}

public void testEqualsAndHashcode() throws IOException {
checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate);
}
Expand All @@ -268,10 +287,7 @@ private SearchRequest mutate(SearchRequest searchRequest) {
);
mutators.add(
() -> mutation.searchType(
randomValueOtherThan(
searchRequest.searchType(),
() -> randomFrom(SearchType.DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH)
)
randomValueOtherThan(searchRequest.searchType(), () -> randomFrom(DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH))
)
);
mutators.add(() -> mutation.source(randomValueOtherThan(searchRequest.source(), this::createSearchSourceBuilder)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,27 @@ public void testDerivedFieldsParsingAndSerializationObjectType() throws IOExcept
}
}

public void testSearchPipelineParsingAndSerialization() throws IOException {
String restContent = "{ \"query\": { \"match_all\": {} }, \"from\": 0, \"size\": 10, \"search_pipeline\": \"my_pipeline\" }";
String expectedContent = "{\"from\":0,\"size\":10,\"query\":{\"match_all\":{\"boost\":1.0}},\"search_pipeline\":\"my_pipeline\"}";

try (XContentParser parser = createParser(JsonXContent.jsonXContent, restContent)) {
SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.fromXContent(parser);
searchSourceBuilder = rewrite(searchSourceBuilder);

try (BytesStreamOutput output = new BytesStreamOutput()) {
searchSourceBuilder.writeTo(output);
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry)) {
SearchSourceBuilder deserializedBuilder = new SearchSourceBuilder(in);
String actualContent = deserializedBuilder.toString();
assertEquals(expectedContent, actualContent);
assertEquals(searchSourceBuilder.hashCode(), deserializedBuilder.hashCode());
assertNotSame(searchSourceBuilder, deserializedBuilder);
}
}
}
}

public void testAggsParsing() throws IOException {
{
String restContent = "{\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,64 @@ public void testInlinePipeline() throws Exception {
}
}

public void testInlineDefinedPipeline() throws Exception {
SearchPipelineService searchPipelineService = createWithProcessors();

SearchPipelineMetadata metadata = new SearchPipelineMetadata(
Map.of(
"p1",
new PipelineConfiguration(
"p1",
new BytesArray(
"{"
+ "\"request_processors\": [{ \"scale_request_size\": { \"scale\" : 2 } }],"
+ "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]"
+ "}"
),
MediaTypeRegistry.JSON
)
)

);
ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build();
ClusterState previousState = clusterState;
clusterState = ClusterState.builder(clusterState)
.metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata))
.build();
searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState));

SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().size(100).pipeline("p1");
SearchRequest searchRequest = new SearchRequest().source(sourceBuilder);
searchRequest.pipeline(searchRequest.source().pipeline());

// Verify pipeline
PipelinedRequest pipelinedRequest = syncTransformRequest(
searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver)
);
Pipeline pipeline = pipelinedRequest.getPipeline();
assertEquals("p1", pipeline.getId());
assertEquals(1, pipeline.getSearchRequestProcessors().size());
assertEquals(1, pipeline.getSearchResponseProcessors().size());

// Verify that pipeline transforms request
assertEquals(200, pipelinedRequest.source().size());

int size = 10;
SearchHit[] hits = new SearchHit[size];
for (int i = 0; i < size; i++) {
hits[i] = new SearchHit(i, "doc" + i, Collections.emptyMap(), Collections.emptyMap());
hits[i].score(i);
}
SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size);
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);

SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse);
for (int i = 0; i < size; i++) {
assertEquals(2.0, transformedResponse.getHits().getHits()[i].getScore(), 0.0001);
}
}

public void testInfo() {
SearchPipelineService searchPipelineService = createWithProcessors();
SearchPipelineInfo info = searchPipelineService.info();
Expand Down

0 comments on commit daf1669

Please sign in to comment.