add model input validation for local models in ml processor #2610

Merged (2 commits) · Jul 5, 2024
Changes from 1 commit
@@ -489,10 +489,19 @@ public MLInferenceIngestProcessor create(
   boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false);
   String functionName = ConfigurationUtils
       .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
-  String modelInput = ConfigurationUtils
-      .readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }");
-  boolean defaultValue = !functionName.equalsIgnoreCase("remote");
-  boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue);
+
+  String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT);
+
+  // if model input is not provided for remote models, use the default value
+  if (functionName.equalsIgnoreCase("remote")) {
+      modelInput = (modelInput != null) ? modelInput : "{ \"parameters\": ${ml_inference.parameters} }";

      Review (Collaborator): is it better to be a public static final String, e.g.
      DEFAULT_MODEL_INPUT_FIELD = "{ \"parameters\": ${ml_inference.parameters} }"?

      Reply (Author): makes sense


+  } else if (modelInput == null) {
+      // if model input is not provided for local models, throw an exception, since it is mandatory here
+      throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
+  }
+  boolean defaultFullResponsePath = !functionName.equalsIgnoreCase("remote");
+  boolean fullResponsePath = ConfigurationUtils
+      .readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultFullResponsePath);

      (A review conversation on the fullResponsePath lines was marked as resolved by mingshl.)

   boolean ignoreFailure = ConfigurationUtils
       .readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false);
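
Following the review thread above, the suggested refactor might look like the sketch below. The constant name DEFAULT_MODEL_INPUT_FIELD comes from the reviewer's comment; its exact name and placement in the follow-up commit (not shown in this diff) are assumptions:

    // Hypothetical sketch of the reviewer's suggestion; not necessarily
    // what the second commit of this PR actually implemented.
    public static final String DEFAULT_MODEL_INPUT_FIELD = "{ \"parameters\": ${ml_inference.parameters} }";

    // ... then, inside create(...):
    if (functionName.equalsIgnoreCase("remote")) {
        modelInput = (modelInput != null) ? modelInput : DEFAULT_MODEL_INPUT_FIELD;
    }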
@@ -174,4 +174,60 @@ public void testCreateOptionalFields() throws Exception {
        assertEquals(mLInferenceIngestProcessor.getTag(), processorTag);
        assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE);
    }

    public void testLocalModel() throws Exception {
        Map<String, Processor.Factory> registry = new HashMap<>();
        Map<String, Object> config = new HashMap<>();
        config.put(MODEL_ID, "model2");
        config.put(FUNCTION_NAME, "text_embedding");
        Map<String, Object> model_config = new HashMap<>();
        model_config.put("return_number", true);
        config.put(MODEL_CONFIG, model_config);
        config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }");
        List<Map<String, String>> inputMap = new ArrayList<>();
        Map<String, String> input = new HashMap<>();
        input.put("text_docs", "chunks.*.chunk.text.*.context");
        inputMap.add(input);
        List<Map<String, String>> outputMap = new ArrayList<>();
        Map<String, String> output = new HashMap<>();
        output.put("chunks.*.chunk.text.*.embedding", "$.inference_results.*.output[2].data");
        outputMap.add(output);
        config.put(INPUT_MAP, inputMap);
        config.put(OUTPUT_MAP, outputMap);
        config.put(MAX_PREDICTION_TASKS, 5);
        String processorTag = randomAlphaOfLength(10);

        MLInferenceIngestProcessor mLInferenceIngestProcessor = factory.create(registry, processorTag, null, config);
        assertNotNull(mLInferenceIngestProcessor);
        assertEquals(mLInferenceIngestProcessor.getTag(), processorTag);
        assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE);
    }
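
For reference, this test configuration roughly corresponds to an ingest pipeline definition like the sketch below. The processor type ml_inference and the field names mirror the constants used in the test, but the exact REST shape here is an assumption, not taken from this PR:

    PUT _ingest/pipeline/local_model_pipeline
    {
      "processors": [
        {
          "ml_inference": {
            "model_id": "model2",
            "function_name": "text_embedding",
            "model_config": { "return_number": true },
            "model_input": "{ \"text_docs\": ${ml_inference.text_docs} }",
            "input_map": [ { "text_docs": "chunks.*.chunk.text.*.context" } ],
            "output_map": [ { "chunks.*.chunk.text.*.embedding": "$.inference_results.*.output[2].data" } ],
            "max_prediction_tasks": 5
          }
        }
      ]
    }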

    public void testModelInputIsNullForLocalModels() throws Exception {
        Map<String, Processor.Factory> registry = new HashMap<>();
        Map<String, Object> config = new HashMap<>();
        config.put(MODEL_ID, "model2");
        config.put(FUNCTION_NAME, "text_embedding");
        Map<String, Object> model_config = new HashMap<>();
        model_config.put("return_number", true);
        config.put(MODEL_CONFIG, model_config);
        List<Map<String, String>> inputMap = new ArrayList<>();
        Map<String, String> input = new HashMap<>();
        input.put("text_docs", "chunks.*.chunk.text.*.context");
        inputMap.add(input);
        List<Map<String, String>> outputMap = new ArrayList<>();
        Map<String, String> output = new HashMap<>();
        output.put("chunks.*.chunk.text.*.embedding", "$.inference_results.*.output[2].data");
        outputMap.add(output);
        config.put(INPUT_MAP, inputMap);
        config.put(OUTPUT_MAP, outputMap);
        config.put(MAX_PREDICTION_TASKS, 5);
        String processorTag = randomAlphaOfLength(10);

        try {
            factory.create(registry, processorTag, null, config);
            fail("expected IllegalArgumentException because model input is mandatory for local models");
        } catch (IllegalArgumentException e) {
            assertEquals("Please provide model input when using a local model in ML Inference Processor", e.getMessage());
        }
    }
}
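
A more compact variant of the same assertion is sketched below, assuming the test class extends OpenSearchTestCase and therefore inherits the expectThrows helper; this is an alternative style, not what the PR uses:

    // Hypothetical variant using expectThrows, which fails the test
    // automatically when no exception is thrown (reusing the same
    // registry, processorTag, and config as in the test above).
    IllegalArgumentException e = expectThrows(
        IllegalArgumentException.class,
        () -> factory.create(registry, processorTag, null, config)
    );
    assertEquals("Please provide model input when using a local model in ML Inference Processor", e.getMessage());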