add model input validation for local models in ml processor (#2610) (#2615)

* add model input validation for local models in ml processor

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>

---------

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
(cherry picked from commit 2b953cd)

Co-authored-by: Bhavana Ramaram <rbhavna@amazon.com>
(cherry picked from commit 6a250dd)
opensearch-trigger-bot[bot] authored and github-actions[bot] committed Oct 2, 2024
1 parent a7ac777 commit d8debca
Showing 2 changed files with 70 additions and 4 deletions.
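
In short, the factory now distinguishes remote from local models when reading model_input: remote models keep the previous default template, while local models must set model_input explicitly or processor creation fails. A minimal sketch of what the rule means for pipeline authors, written as Java text blocks holding illustrative ingest pipeline bodies (the remote model ID is hypothetical; the local values mirror the new tests below):

    // Sketch only: illustrative pipeline bodies for the new model_input rule.
    public class ModelInputExamples {

        // Remote model (function_name defaults to "remote"): model_input may be
        // omitted and falls back to "{ \"parameters\": ${ml_inference.parameters} }".
        static final String REMOTE_PIPELINE = """
            {
              "processors": [
                { "ml_inference": { "model_id": "my-remote-model-id" } }
              ]
            }
            """;

        // Local model: model_input is now mandatory; omitting it fails with
        // "Please provide model input when using a local model in ML Inference Processor".
        static final String LOCAL_PIPELINE = """
            {
              "processors": [
                {
                  "ml_inference": {
                    "model_id": "model2",
                    "function_name": "text_embedding",
                    "model_config": { "return_number": true },
                    "model_input": "{ \\"text_docs\\": ${ml_inference.text_docs} }",
                    "input_map": [ { "text_docs": "chunks.*.chunk.text.*.context" } ],
                    "output_map": [ { "chunks.*.chunk.text.*.embedding": "$.inference_results.*.output[2].data" } ]
                  }
                }
              ]
            }
            """;
    }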
MLInferenceIngestProcessor.java

@@ -72,6 +72,7 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
     // At default, ml inference processor allows maximum 10 prediction tasks running in parallel
     // it can be overwritten using max_prediction_tasks when creating processor
     public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
+    public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
     private final NamedXContentRegistry xContentRegistry;

     private Configuration suppressExceptionConfiguration = Configuration
@@ -489,10 +490,19 @@ public MLInferenceIngestProcessor create(
         boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false);
         String functionName = ConfigurationUtils
             .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
-        String modelInput = ConfigurationUtils
-            .readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }");
-        boolean defaultValue = !functionName.equalsIgnoreCase("remote");
-        boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue);
+        String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT);
+
+        // if model input is not provided for remote models, use default value
+        if (functionName.equalsIgnoreCase("remote")) {
+            modelInput = (modelInput != null) ? modelInput : DEFAULT_MODEl_INPUT;
+        } else if (modelInput == null) {
+            // if model input is not provided for local models, throw exception since it is mandatory here
+            throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
+        }
+        boolean defaultFullResponsePath = !functionName.equalsIgnoreCase(FunctionName.REMOTE.name());
+        boolean fullResponsePath = ConfigurationUtils
+            .readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultFullResponsePath);

         boolean ignoreFailure = ConfigurationUtils
             .readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false);
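
Read on its own, the new resolution rule is small enough to restate as a self-contained sketch (no OpenSearch dependencies; the constant value and error message are copied from the diff above, but this is an illustration, not the actual implementation):

    // Standalone sketch of the model_input resolution rule introduced above.
    public class ModelInputResolutionSketch {

        static final String DEFAULT_MODEL_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";

        // Remote models fall back to the default template; local models must
        // supply model_input explicitly, otherwise creation fails.
        static String resolveModelInput(String functionName, String modelInput) {
            if (functionName.equalsIgnoreCase("remote")) {
                return modelInput != null ? modelInput : DEFAULT_MODEL_INPUT;
            }
            if (modelInput == null) {
                throw new IllegalArgumentException(
                    "Please provide model input when using a local model in ML Inference Processor"
                );
            }
            return modelInput;
        }

        public static void main(String[] args) {
            System.out.println(resolveModelInput("remote", null));  // prints the default template
            System.out.println(resolveModelInput("text_embedding", "{ \"text_docs\": ${ml_inference.text_docs} }"));
            try {
                resolveModelInput("text_embedding", null);          // rejected for local models
            } catch (IllegalArgumentException e) {
                System.out.println(e.getMessage());
            }
        }
    }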
Factory tests for MLInferenceIngestProcessor

@@ -174,4 +174,60 @@ public void testCreateOptionalFields() throws Exception {
         assertEquals(mLInferenceIngestProcessor.getTag(), processorTag);
         assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE);
     }
+
+    public void testLocalModel() throws Exception {
+        Map<String, Processor.Factory> registry = new HashMap<>();
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        config.put(FUNCTION_NAME, "text_embedding");
+        Map<String, Object> model_config = new HashMap<>();
+        model_config.put("return_number", true);
+        config.put(MODEL_CONFIG, model_config);
+        config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }");
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input = new HashMap<>();
+        input.put("text_docs", "chunks.*.chunk.text.*.context");
+        inputMap.add(input);
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output = new HashMap<>();
+        output.put("chunks.*.chunk.text.*.embedding", "$.inference_results.*.output[2].data");
+        outputMap.add(output);
+        config.put(INPUT_MAP, inputMap);
+        config.put(OUTPUT_MAP, outputMap);
+        config.put(MAX_PREDICTION_TASKS, 5);
+        String processorTag = randomAlphaOfLength(10);
+
+        MLInferenceIngestProcessor mLInferenceIngestProcessor = factory.create(registry, processorTag, null, config);
+        assertNotNull(mLInferenceIngestProcessor);
+        assertEquals(mLInferenceIngestProcessor.getTag(), processorTag);
+        assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE);
+    }
+
+    public void testModelInputIsNullForLocalModels() throws Exception {
+        Map<String, Processor.Factory> registry = new HashMap<>();
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        config.put(FUNCTION_NAME, "text_embedding");
+        Map<String, Object> model_config = new HashMap<>();
+        model_config.put("return_number", true);
+        config.put(MODEL_CONFIG, model_config);
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input = new HashMap<>();
+        input.put("text_docs", "chunks.*.chunk.text.*.context");
+        inputMap.add(input);
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output = new HashMap<>();
+        output.put("chunks.*.chunk.text.*.embedding", "$.inference_results.*.output[2].data");
+        outputMap.add(output);
+        config.put(INPUT_MAP, inputMap);
+        config.put(OUTPUT_MAP, outputMap);
+        config.put(MAX_PREDICTION_TASKS, 5);
+        String processorTag = randomAlphaOfLength(10);
+
+        try {
+            factory.create(registry, processorTag, null, config);
+            fail("expected IllegalArgumentException when model_input is missing for a local model");
+        } catch (IllegalArgumentException e) {
+            assertEquals("Please provide model input when using a local model in ML Inference Processor", e.getMessage());
+        }
+    }
 }
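
The tests pass model_input as a template whose ${ml_inference.<key>} placeholders are filled from the document fields selected by input_map. A hedged sketch of that contract, assuming plain string substitution (the real substitution happens elsewhere in ml-commons, not in this factory):

    import java.util.Map;

    // Illustration of the model_input template contract only; simple string
    // replacement is an assumption, not the actual ml-commons implementation.
    public class ModelInputTemplateSketch {

        static String fill(String template, Map<String, String> params) {
            String result = template;
            for (Map.Entry<String, String> entry : params.entrySet()) {
                // Replace ${ml_inference.<key>} with the value extracted for that key.
                result = result.replace("${ml_inference." + entry.getKey() + "}", entry.getValue());
            }
            return result;
        }

        public static void main(String[] args) {
            // Template from testLocalModel; the value stands in for text extracted
            // from chunks.*.chunk.text.*.context via input_map.
            String template = "{ \"text_docs\": ${ml_inference.text_docs} }";
            System.out.println(fill(template, Map.of("text_docs", "[\"hello world\"]")));
            // prints: { "text_docs": ["hello world"] }
        }
    }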
