diff --git a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java index ea36318741..828aa970a0 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java +++ b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java @@ -11,6 +11,7 @@ import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.annotation.MLAlgoOutput; import org.opensearch.ml.common.annotation.MLAlgoParameter; +import org.opensearch.ml.common.annotation.MLInput; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.output.MLOutputType; @@ -30,6 +31,7 @@ public class MLCommonsClassLoader { private static Map, Class> parameterClassMap = new HashMap<>(); private static Map, Class> executeInputClassMap = new HashMap<>(); private static Map, Class> executeOutputClassMap = new HashMap<>(); + private static Map, Class> mlInputClassMap = new HashMap<>(); static { try { @@ -51,6 +53,7 @@ public static void loadClassMapping() { loadMLInputDataSetClassMapping(); loadExecuteInputClassMapping(); loadExecuteOutputClassMapping(); + loadMLInputClassMapping(); } finally { Thread.currentThread().setContextClassLoader(originalClassLoader); } @@ -160,6 +163,22 @@ private static void loadExecuteOutputClassMapping() { } } + private static void loadMLInputClassMapping() { + Reflections reflections = new Reflections("org.opensearch.ml.common.input"); + Set> classes = reflections.getTypesAnnotatedWith(MLInput.class); + for (Class clazz : classes) { + MLInput mlInput = clazz.getAnnotation(MLInput.class); + if (mlInput != null) { + FunctionName[] algorithms = mlInput.functionNames(); + if (algorithms != null && algorithms.length > 0) { + for(FunctionName name : algorithms){ + mlInputClassMap.put(name, clazz); + } + } + } + } + } + @SuppressWarnings("unchecked") public static , 
S, I extends Object> S initMLInstance(T type, I in, Class constructorParamClass) { return init(parameterClassMap, type, in, constructorParamClass); @@ -195,4 +214,33 @@ private static , S, I extends Object> S init(Map, Clas } } + public static boolean canInitMLInput(FunctionName functionName) { + return mlInputClassMap.containsKey(functionName); + } + + @SuppressWarnings("unchecked") + public static , S> S initMLInput(T type, Object[] initArgs, + Class... constructorParameterTypes) { + return init(mlInputClassMap, type, initArgs, constructorParameterTypes); + } + + private static , S> S init(Map, Class> map, T type, + Object[] initArgs, Class... constructorParameterTypes) { + Class clazz = map.get(type); + if (clazz == null) { + throw new IllegalArgumentException("Can't find class for type " + type); + } + try { + Constructor constructor = clazz.getConstructor(constructorParameterTypes); + return (S) constructor.newInstance(initArgs); + } catch (Exception e) { + Throwable cause = e.getCause(); + if (cause instanceof MLException) { + throw (MLException)cause; + } else { + log.error("Failed to init instance for type " + type, e); + return null; + } + } + } } diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java b/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java new file mode 100644 index 0000000000..b8100473b0 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.annotation; + +import org.opensearch.ml.common.FunctionName; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface MLInput { + // supported algorithms + FunctionName[] functionNames(); +} 
diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index 6373d09422..b1278b6177 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -7,6 +7,7 @@ import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.XContentBuilder; @@ -32,9 +33,10 @@ import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; /** - * ML input data: algirithm name, parameters and input data set. + * ML input data: algorithm name, parameters and input data set. */ @Data +@NoArgsConstructor public class MLInput implements Input { public static final String ALGORITHM_FIELD = "algorithm"; @@ -56,11 +58,11 @@ public class MLInput implements Input { public static final String TEXT_DOCS_FIELD = "text_docs"; // Algorithm name - private FunctionName algorithm; + protected FunctionName algorithm; // ML algorithm parameters - private MLAlgoParams parameters; + protected MLAlgoParams parameters; // Input data to train model, run trained model to predict or run ML algorithms(no-model-based) directly. 
- private MLInputDataset inputDataset; + protected MLInputDataset inputDataset; private int version = 1; @@ -169,6 +171,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException { String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT); FunctionName algorithm = FunctionName.from(algorithmName); + + if (MLCommonsClassLoader.canInitMLInput(algorithm)) { + MLInput mlInput = MLCommonsClassLoader.initMLInput(algorithm, new Object[]{parser, algorithm}, XContentParser.class, FunctionName.class); + mlInput.setAlgorithm(algorithm); + return mlInput; + } + MLAlgoParams mlParameters = null; SearchSourceBuilder searchSourceBuilder = null; List sourceIndices = new ArrayList<>(); diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java new file mode 100644 index 0000000000..9eeebc7534 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java @@ -0,0 +1,133 @@ +package org.opensearch.ml.common.input.nlp; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelResultFilter; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * ML input class which supports a list of text docs. + * This class can be used for TEXT_EMBEDDING model. 
+ */ +@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING}) +public class TextDocsMLInput extends MLInput { + public static final String TEXT_DOCS_FIELD = "text_docs"; + public static final String RESULT_FILTER_FIELD = "result_filter"; + + public TextDocsMLInput(FunctionName algorithm, MLInputDataset inputDataset) { + super(algorithm, null, inputDataset); + } + + public TextDocsMLInput(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ALGORITHM_FIELD, algorithm.name()); + if (parameters != null) { + builder.field(ML_PARAMETERS_FIELD, parameters); + } + if (inputDataset != null) { + TextDocsInputDataSet textInputDataSet = (TextDocsInputDataSet) this.inputDataset; + List docs = textInputDataSet.getDocs(); + ModelResultFilter resultFilter = textInputDataSet.getResultFilter(); + if (docs != null && docs.size() > 0) { + builder.field(TEXT_DOCS_FIELD, docs.toArray(new String[0])); + } + if (resultFilter != null) { + builder.startObject(RESULT_FILTER_FIELD); + builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes()); + builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber()); + List targetResponse = resultFilter.getTargetResponse(); + if (targetResponse != null && targetResponse.size() > 0) { + builder.field(TARGET_RESPONSE_FIELD, targetResponse.toArray(new String[0])); + } + List targetPositions = resultFilter.getTargetResponsePositions(); + if (targetPositions != null && targetPositions.size() > 0) { + builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0])); + } + builder.endObject(); + } + } + builder.endObject(); + return builder; + } + + public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws IOException { 
+ super(); + this.algorithm = functionName; + List docs = new ArrayList<>(); + ModelResultFilter resultFilter = null; + + boolean returnBytes = false; + boolean returnNumber = true; + List targetResponse = new ArrayList<>(); + List targetResponsePositions = new ArrayList<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case RETURN_BYTES_FIELD: + returnBytes = parser.booleanValue(); + break; + case RETURN_NUMBER_FIELD: + returnNumber = parser.booleanValue(); + break; + case TARGET_RESPONSE_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + targetResponse.add(parser.text()); + } + break; + case TARGET_RESPONSE_POSITIONS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + targetResponsePositions.add(parser.intValue()); + } + break; + case TEXT_DOCS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + docs.add(parser.text()); + } + break; + case RESULT_FILTER_FIELD: + resultFilter = ModelResultFilter.parse(parser); + break; + default: + parser.skipChildren(); + break; + } + } + ModelResultFilter filter = resultFilter != null ? 
resultFilter : ModelResultFilter.builder().returnBytes(returnBytes) + .returnNumber(returnNumber).targetResponse(targetResponse).targetResponsePositions(targetResponsePositions) + .build(); + + if (docs.size() == 0) { + throw new IllegalArgumentException("Empty text docs"); + } + inputDataset = new TextDocsInputDataSet(docs, filter); + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelResultFilter.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelResultFilter.java index dd42cf04ed..1bdc0a834a 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelResultFilter.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelResultFilter.java @@ -7,11 +7,14 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.XContentParser; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; + /** * This class is to filter model results. */ @@ -19,6 +22,14 @@ @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class ModelResultFilter implements Writeable { + public static final String RETURN_BYTES_FIELD = "return_bytes"; + // Return number in model output. This can be used together with return_bytes. + public static final String RETURN_NUMBER_FIELD = "return_number"; + // Filter target response with name in model output + public static final String TARGET_RESPONSE_FIELD = "target_response"; + // Filter target response with position in model output + public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions"; + // Return model output as bytes. This could be useful if client side prefer 
protected boolean returnBytes; @@ -77,4 +88,41 @@ public void writeTo(StreamOutput streamOutput) throws IOException { streamOutput.writeBoolean(false); } } + + public static ModelResultFilter parse(XContentParser parser) throws IOException { + boolean returnBytes = false; + boolean returnNumber = true; + List targetResponse = new ArrayList<>(); + List targetResponsePositions = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case RETURN_BYTES_FIELD: + returnBytes = parser.booleanValue(); + break; + case RETURN_NUMBER_FIELD: + returnNumber = parser.booleanValue(); + break; + case TARGET_RESPONSE_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + targetResponse.add(parser.text()); + } + break; + case TARGET_RESPONSE_POSITIONS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + targetResponsePositions.add(parser.intValue()); + } + break; + default: + parser.skipChildren(); + break; + } + } + return new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java index 250634d9ed..341bc32e75 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java @@ -11,17 +11,26 @@ import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; +import 
org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; +import org.opensearch.ml.common.input.nlp.TextDocsMLInput; import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; +import org.opensearch.search.SearchModule; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; public class MLCommonsClassLoaderTests { @@ -90,6 +99,21 @@ public void testClassLoader_ExecuteOutput() { assertEquals(this.output, calculatorOutput); } + @Test + public void testClassLoader_MLInput() throws IOException { + assertTrue(MLCommonsClassLoader.canInitMLInput(FunctionName.TEXT_EMBEDDING)); + + String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + + TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(FunctionName.TEXT_EMBEDDING, new Object[]{parser, FunctionName.TEXT_EMBEDDING}, XContentParser.class, FunctionName.class); + assertNotNull(mlInput); + assertEquals(FunctionName.TEXT_EMBEDDING, 
mlInput.getFunctionName()); + assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size()); + } + public enum TestEnum { TEST } diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java new file mode 100644 index 0000000000..8fe02b4427 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java @@ -0,0 +1,85 @@ +package org.opensearch.ml.common.input.nlp; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelResultFilter; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public class TextDocsMLInputTest { + + MLInput input; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private final FunctionName algorithm = FunctionName.TEXT_EMBEDDING; + + @Before + public void setUp() throws Exception { + ModelResultFilter resultFilter = 
ModelResultFilter.builder().returnBytes(true).returnNumber(true).targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build(); + MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2")).resultFilter(resultFilter).build(); + input = new TextDocsMLInput(algorithm, inputDataset); + } + + @Test + public void parseTextDocsMLInput() throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = Strings.toString(builder); + System.out.println(jsonStr); + parseMLInput(jsonStr); + } + + @Test + public void parseTextDocsMLInput_OldWay() throws IOException { + String jsonStr = "{\"text_docs\": [ \"doc1\", \"doc2\" ],\"return_number\": true, \"return_bytes\": true,\"target_response\": [ \"field1\" ], \"target_response_positions\": [2]}"; + parseMLInput(jsonStr); + } + + @Test + public void parseTextDocsMLInput_NewWay() throws IOException { + String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; + parseMLInput(jsonStr); + } + + private void parseMLInput(String jsonStr) throws IOException { + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + + MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); + assertTrue(parsedInput instanceof TextDocsMLInput); + assertEquals(input.getFunctionName(), parsedInput.getFunctionName()); + assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType()); + TextDocsInputDataSet inputDataset = (TextDocsInputDataSet) parsedInput.getInputDataset(); + assertEquals(2, inputDataset.getDocs().size()); + 
assertEquals("doc1", inputDataset.getDocs().get(0)); + assertEquals("doc2", inputDataset.getDocs().get(1)); + assertNotNull(inputDataset.getResultFilter()); + assertTrue(inputDataset.getResultFilter().isReturnBytes()); + assertTrue(inputDataset.getResultFilter().isReturnNumber()); + } + +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java index 53f0a7c0e9..ef54b42cce 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.engine.algorithms; import ai.djl.Application;