Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna committed Jul 5, 2024
1 parent e749757 commit 7e03b17
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
// it can be overwritten using max_prediction_tasks when creating processor
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
private final NamedXContentRegistry xContentRegistry;

private Configuration suppressExceptionConfiguration = Configuration
Expand Down Expand Up @@ -494,12 +495,12 @@ public MLInferenceIngestProcessor create(

// if model input is not provided for remote models, use default value
if (functionName.equalsIgnoreCase("remote")) {
modelInput = (modelInput != null) ? modelInput : "{ \"parameters\": ${ml_inference.parameters} }";
modelInput = (modelInput != null) ? modelInput : DEFAULT_MODEl_INPUT;
} else if (modelInput == null) {
// if model input is not provided for local models, throw exception since it is mandatory here
throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
}
boolean defaultFullResponsePath = !functionName.equalsIgnoreCase("remote");
boolean defaultFullResponsePath = !functionName.equalsIgnoreCase(FunctionName.REMOTE.name());
boolean fullResponsePath = ConfigurationUtils
.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultFullResponsePath);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ public void testModelInputIsNullForLocalModels() throws Exception {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(MODEL_ID, "model2");
config.put(FUNCTION_NAME, "text_embedding");
config.put(FUNCTION_NAME, "remote");
Map<String, Object> model_config = new HashMap<>();
model_config.put("return_number", true);
config.put(MODEL_CONFIG, model_config);
Expand Down

0 comments on commit 7e03b17

Please sign in to comment.