add model input validation for local models in ml processor #2610

Merged · 2 commits · Jul 5, 2024
@@ -72,6 +72,7 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
// it can be overwritten using max_prediction_tasks when creating processor
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
+    public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
private final NamedXContentRegistry xContentRegistry;

private Configuration suppressExceptionConfiguration = Configuration
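The added DEFAULT_MODEl_INPUT constant is the fallback input template for remote models: at prediction time the ${ml_inference.parameters} placeholder is filled with the parameters gathered via the processor's input_map. For instance, assuming the mappings produced {"inputs": "some text"}, the substituted model input would be { "parameters": {"inputs": "some text"} } (illustrative values, not taken from this diff).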
@@ -489,10 +490,19 @@ public MLInferenceIngestProcessor create(
boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false);
String functionName = ConfigurationUtils
.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
-        String modelInput = ConfigurationUtils
-            .readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }");
-        boolean defaultValue = !functionName.equalsIgnoreCase("remote");
-        boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue);

+        String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT);
+
+        // if model input is not provided for remote models, use default value
+        if (functionName.equalsIgnoreCase("remote")) {
+            modelInput = (modelInput != null) ? modelInput : DEFAULT_MODEl_INPUT;
+        } else if (modelInput == null) {
+            // if model input is not provided for local models, throw exception since it is mandatory here
+            throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
+        }
+        boolean defaultFullResponsePath = !functionName.equalsIgnoreCase(FunctionName.REMOTE.name());
+        boolean fullResponsePath = ConfigurationUtils
+            .readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultFullResponsePath);

boolean ignoreFailure = ConfigurationUtils
.readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false);
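Taken together, the new factory logic enforces one rule: remote models may omit model_input (the default template is substituted), while local models must provide it explicitly. A minimal self-contained sketch of that rule, for readers skimming the diff (resolveModelInput is a hypothetical helper for illustration, not part of the PR):

    // Hypothetical helper mirroring the validation added in create() above.
    static String resolveModelInput(String functionName, String modelInput) {
        if ("remote".equalsIgnoreCase(functionName)) {
            // remote models fall back to the default parameters template
            return modelInput != null ? modelInput : "{ \"parameters\": ${ml_inference.parameters} }";
        }
        if (modelInput == null) {
            // local models (e.g. text_embedding) must declare their own input format
            throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
        }
        return modelInput;
    }

Called with ("text_embedding", null) this throws; with ("remote", null) it returns the default template; any explicitly supplied model_input is passed through unchanged.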
@@ -174,4 +174,60 @@ public void testCreateOptionalFields() throws Exception {
assertEquals(mLInferenceIngestProcessor.getTag(), processorTag);
assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE);
}

+    public void testLocalModel() throws Exception {
+        Map<String, Processor.Factory> registry = new HashMap<>();
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        config.put(FUNCTION_NAME, "text_embedding");
+        Map<String, Object> model_config = new HashMap<>();
+        model_config.put("return_number", true);
+        config.put(MODEL_CONFIG, model_config);
+        config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }");
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input = new HashMap<>();
+        input.put("text_docs", "chunks.*.chunk.text.*.context");
+        inputMap.add(input);
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output = new HashMap<>();
+        output.put("chunks.*.chunk.text.*.embedding", "$.inference_results.*.output[2].data");
+        outputMap.add(output);
+        config.put(INPUT_MAP, inputMap);
+        config.put(OUTPUT_MAP, outputMap);
+        config.put(MAX_PREDICTION_TASKS, 5);
+        String processorTag = randomAlphaOfLength(10);
+
+        MLInferenceIngestProcessor mLInferenceIngestProcessor = factory.create(registry, processorTag, null, config);
+        assertNotNull(mLInferenceIngestProcessor);
+        assertEquals(mLInferenceIngestProcessor.getTag(), processorTag);
+        assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE);
+    }

+    public void testModelInputIsNullForLocalModels() throws Exception {
+        Map<String, Processor.Factory> registry = new HashMap<>();
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        config.put(FUNCTION_NAME, "text_embedding");
+        Map<String, Object> model_config = new HashMap<>();
+        model_config.put("return_number", true);
+        config.put(MODEL_CONFIG, model_config);
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input = new HashMap<>();
+        input.put("text_docs", "chunks.*.chunk.text.*.context");
+        inputMap.add(input);
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output = new HashMap<>();
+        output.put("chunks.*.chunk.text.*.embedding", "$.inference_results.*.output[2].data");
+        outputMap.add(output);
+        config.put(INPUT_MAP, inputMap);
+        config.put(OUTPUT_MAP, outputMap);
+        config.put(MAX_PREDICTION_TASKS, 5);
+        String processorTag = randomAlphaOfLength(10);
+
+        try {
+            factory.create(registry, processorTag, null, config);
+            // fail if no exception is thrown, so the test cannot pass vacuously
+            fail("expected IllegalArgumentException when model input is missing for a local model");
+        } catch (IllegalArgumentException e) {
+            assertEquals("Please provide model input when using a local model in ML Inference Processor", e.getMessage());
+        }
+    }
}
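For completeness, the mirror case (a remote model with model_input omitted) should fall back to the default template rather than throw. A hedged sketch in the style of the tests above (not part of this PR; testRemoteModelDefaultsModelInput is an illustrative name):

    public void testRemoteModelDefaultsModelInput() throws Exception {
        Map<String, Processor.Factory> registry = new HashMap<>();
        Map<String, Object> config = new HashMap<>();
        config.put(MODEL_ID, "model1");
        // FUNCTION_NAME and MODEL_INPUT both omitted: the factory defaults to a
        // remote model and substitutes DEFAULT_MODEl_INPUT instead of throwing
        MLInferenceIngestProcessor processor = factory.create(registry, randomAlphaOfLength(10), null, config);
        assertNotNull(processor);
    }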