Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validations from appsec #562

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524))
- Fix Flaky test reported in #433 ([#533](https://github.com/opensearch-project/neural-search/pull/533))
- Enable support for default model id on HybridQueryBuilder ([#541](https://github.com/opensearch-project/neural-search/pull/541))
- Add vaalidations for reranker requests per #555 ([#562](https://github.com/opensearch-project/neural-search/pull/562))
### Infrastructure
- BWC tests for Neural Search ([#515](https://github.com/opensearch-project/neural-search/pull/515))
- Github action to run integ tests in secure opensearch cluster ([#535](https://github.com/opensearch-project/neural-search/pull/535))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.plugin;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -145,7 +146,7 @@

@Override
public List<Setting<?>> getSettings() {
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED);
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED, RERANKER_MAX_DOC_FIELDS);

Check warning on line 149 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L149

Added line #L149 was not covered by tests
}

@Override
Expand All @@ -159,7 +160,7 @@
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchResponseProcessor>> getResponseProcessors(
Parameters parameters
) {
return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor));
return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor, parameters.env));

Check warning on line 163 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L163

Added line #L163 was not covered by tests
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Set;
import java.util.StringJoiner;

import org.opensearch.env.Environment;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
Expand All @@ -37,6 +38,7 @@ public class RerankProcessorFactory implements Processor.Factory<SearchResponseP
public static final String CONTEXT_CONFIG_FIELD = "context";

private final MLCommonsClientAccessor clientAccessor;
private final Environment environment;
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved

@Override
public SearchResponseProcessor create(
Expand All @@ -49,7 +51,12 @@ public SearchResponseProcessor create(
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(
config,
includeQueryContextFetcher,
tag,
environment
);
switch (type) {
case ML_OPENSEARCH:
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
Expand Down Expand Up @@ -109,22 +116,23 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) {
public static List<ContextSourceFetcher> createFetchers(
Map<String, Object> config,
boolean includeQueryContextFetcher,
String tag
String tag,
final Environment environment
) {
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD);
List<ContextSourceFetcher> fetchers = new ArrayList<>();
for (String key : contextConfig.keySet()) {
Object cfg = contextConfig.get(key);
switch (key) {
case DocumentContextSourceFetcher.NAME:
fetchers.add(DocumentContextSourceFetcher.create(cfg));
fetchers.add(DocumentContextSourceFetcher.create(cfg, environment));
break;
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key));
}
}
if (includeQueryContextFetcher) {
fetchers.add(new QueryContextSourceFetcher());
fetchers.add(new QueryContextSourceFetcher(environment));
}
return fetchers;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ObjectPath;
import org.opensearch.env.Environment;
import org.opensearch.search.SearchHit;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

Expand All @@ -29,8 +32,10 @@ public class DocumentContextSourceFetcher implements ContextSourceFetcher {

public static final String NAME = "document_fields";
public static final String DOCUMENT_CONTEXT_LIST_FIELD = "document_context_list";
public static final int MAX_DOCUMENT_FIELDS = 50;
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved

private final List<String> contextFields;
private final Environment environment;

/**
* Fetch the information needed in order to rerank.
Expand Down Expand Up @@ -87,15 +92,26 @@ public String getName() {
* @param config configuration object grabbed from parsed API request. Should be a list of strings
* @return a new DocumentContextSourceFetcher or throws IllegalArgumentException if config is malformed
*/
public static DocumentContextSourceFetcher create(Object config) {
public static DocumentContextSourceFetcher create(Object config, Environment environment) {
if (!(config instanceof List)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of field names", NAME));
}
List<?> fields = (List<?>) config;
if (fields.size() == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", NAME));
}
if (fields.size() > RERANKER_MAX_DOC_FIELDS.get(environment.settings())) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"%s must not contain more than %d fields. Configure by setting %s",
NAME,
RERANKER_MAX_DOC_FIELDS.get(environment.settings()),
RERANKER_MAX_DOC_FIELDS.getKey()
)
);
}
List<String> fieldsAsStrings = fields.stream().map(field -> (String) field).collect(Collectors.toList());
return new DocumentContextSourceFetcher(fieldsAsStrings);
return new DocumentContextSourceFetcher(fieldsAsStrings, environment);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,27 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.search.SearchExtBuilder;

import lombok.AllArgsConstructor;

/**
* Context Source Fetcher that gets context from the rerank query ext.
*/
@AllArgsConstructor
public class QueryContextSourceFetcher implements ContextSourceFetcher {

public static final String NAME = "query_context";
public static final String QUERY_TEXT_FIELD = "query_text";
public static final String QUERY_TEXT_PATH_FIELD = "query_text_path";

public static final Integer MAX_QUERY_PATH_STRLEN = 1024;

private final Environment environment;

@Override
public void fetchContext(
final SearchRequest searchRequest,
Expand Down Expand Up @@ -65,6 +74,17 @@
} else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) {
// Case "query_text_path": ser/de the query into a map and then find the text at the path specified
String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD);
if (!validatePath(path)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"%s exceeded the maximum path length of %d nested fields or %d characters",
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
QUERY_TEXT_PATH_FIELD,
MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings()),
MAX_QUERY_PATH_STRLEN
)
);
}
Map<String, Object> map = requestToMap(searchRequest);
// Get the text at the path
Object queryText = ObjectPath.eval(path, map);
Expand Down Expand Up @@ -107,4 +127,14 @@
Map<String, Object> map = parser.map();
return map;
}

private boolean validatePath(final String path) {
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
if (path == null || path.isEmpty()) {
return true;

Check warning on line 133 in src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java#L133

Added line #L133 was not covered by tests
}
if (path.length() > MAX_QUERY_PATH_STRLEN) {
return false;
}
return path.split("\\.").length <= MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,13 @@ public final class NeuralSearchSettings {
false,
Setting.Property.NodeScope
);

/**
* Limits the number of document fields that can be passed to the reranker.
*/
public static final Setting<Integer> RERANKER_MAX_DOC_FIELDS = Setting.intSetting(
"plugins.neural_search.reranker_max_document_fields",
50,
Setting.Property.NodeScope
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
Expand All @@ -15,6 +18,8 @@
import org.junit.Before;
import org.mockito.Mock;
import org.opensearch.OpenSearchParseException;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
Expand All @@ -37,11 +42,16 @@ public class RerankProcessorFactoryTests extends OpenSearchTestCase {
@Mock
private PipelineContext pipelineContext;

@Mock
private Environment environment;

@Before
public void setup() {
environment = mock(Environment.class);
pipelineContext = mock(PipelineContext.class);
clientAccessor = mock(MLCommonsClientAccessor.class);
factory = new RerankProcessorFactory(clientAccessor);
factory = new RerankProcessorFactory(clientAccessor, environment);
doReturn(Settings.EMPTY).when(environment).settings();
}

public void testRerankProcessorFactory_whenEmptyConfig_thenFail() {
Expand Down Expand Up @@ -187,4 +197,26 @@ public void testCrossEncoder_whenEmptyContextDocField_thenFail() {
);
}

public void testCrossEncoder_whenTooManyDocFields_ThenFail() {
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")),
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, Collections.nCopies(75, "field")))
)
);
assertThrows(
String.format(
Locale.ROOT,
"%s must not contain more than %d fields. Configure by setting %s",
DocumentContextSourceFetcher.NAME,
RERANKER_MAX_DOC_FIELDS.get(environment.settings()),
RERANKER_MAX_DOC_FIELDS.getKey()
),
IllegalArgumentException.class,
() -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext)
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.common.document.DocumentField;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
Expand Down Expand Up @@ -65,14 +68,18 @@ public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase {
@Mock
private PipelineProcessingContext ppctx;

@Mock
private Environment environment;

private RerankProcessorFactory factory;

private MLOpenSearchRerankProcessor processor;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
factory = new RerankProcessorFactory(mlCommonsClientAccessor);
doReturn(Settings.EMPTY).when(environment).settings();
factory = new RerankProcessorFactory(mlCommonsClientAccessor, environment);
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
Expand Down Expand Up @@ -223,6 +230,51 @@ public void testRerankContext_whenQueryTextPathIsBadPointer_thenFail() throws IO
.equals(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + " must point to a string field"));
}

public void testRerankContext_whenQueryTextPathIsExceeedinglyManyCharacters_thenFail() throws IOException {
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
// "eighteencharacters" * 60 = 1080 character string > max len of 1024
setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "eighteencharacters".repeat(60)));
setupSearchResults();
@SuppressWarnings("unchecked")
ActionListener<Map<String, Object>> listener = mock(ActionListener.class);
processor.generateRerankingContext(request, response, listener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue() instanceof IllegalArgumentException);
assert (argCaptor.getValue()
.getMessage()
.equals(
String.format(
Locale.ROOT,
"%s exceeded the maximum path length of %d nested fields or %d characters",
QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD,
MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings()),
QueryContextSourceFetcher.MAX_QUERY_PATH_STRLEN
)
));
}

public void textRerankContext_whenQueryTextPathIsExceeedinglyDeeplyNested_ThenFail() throws IOException {
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.w.x.y.z"));
setupSearchResults();
@SuppressWarnings("unchecked")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this annotation? I don't see ActionListener<Map<String, Object>> listener = mock(ActionListener.class); is causing any warning in my local.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

¯_(ツ)_/¯. I get a warning when I remove it. Makes sense since it's a typecast from generic action listener to action listener for that map thingy.

ActionListener<Map<String, Object>> listener = mock(ActionListener.class);
processor.generateRerankingContext(request, response, listener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue() instanceof IllegalArgumentException);
assert (argCaptor.getValue()
.getMessage()
.equals(
String.format(
Locale.ROOT,
"%s exceeded the maximum path length of %d nested fields or %d characters",
QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD,
MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings()),
QueryContextSourceFetcher.MAX_QUERY_PATH_STRLEN
)
));
}

public void testRescoreSearchResponse_HappyPath() throws IOException {
setupSimilarityRescoring();
setupSearchResults();
Expand Down
Loading