From 54bf4520a21d1adae03064a66ddaf6a9c2ed3d4d Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 23 Nov 2021 15:17:02 -0800 Subject: [PATCH] tune class name; tune how to add new doc Signed-off-by: Yaliang Wu --- .../ml/client/MachineLearningClient.java | 6 +- .../ml/client/MachineLearningClientTest.java | 15 ++- .../client/MachineLearningNodeClientTest.java | 10 +- .../ml/common/MLCommonsClassLoader.java | 8 +- .../ml/common/dataframe/DataFrameType.java | 2 +- .../{MLAlgoName.java => FunctionName.java} | 8 +- .../opensearch/ml/common/parameter/Input.java | 2 +- .../ml/common/parameter/KMeansParams.java | 2 +- .../parameter/LinearRegressionParams.java | 2 +- .../parameter/LocalSampleCalculatorInput.java | 6 +- .../ml/common/parameter/MLInput.java | 10 +- .../ml/common/parameter/SampleAlgoParams.java | 3 +- .../ml/common/MLCommonsClassLoaderTests.java | 7 +- .../MLPredictionTaskRequestTest.java | 6 +- .../training/MLTrainingTaskRequestTest.java | 6 +- ...lgorithm.md => how-to-add-new-function.md} | 97 +++++++++---------- .../org/opensearch/ml/engine/MLEngine.java | 6 +- .../engine/algorithms/clustering/KMeans.java | 6 +- .../regression/LinearRegression.java | 4 +- .../engine/algorithms/sample/SampleAlgo.java | 4 +- .../opensearch/ml/engine/MLEngineTest.java | 21 ++-- .../ml/engine/clustering/KMeansTest.java | 3 +- .../action/prediction/PredictionITTests.java | 8 +- .../ml/action/training/TrainingITTests.java | 6 +- .../opensearch/ml/params/MLInputTests.java | 5 +- .../ml/rest/BaseMLSearchActionTests.java | 9 +- .../opensearch/ml/utils/IntegTestUtils.java | 6 +- 27 files changed, 131 insertions(+), 137 deletions(-) rename common/src/main/java/org/opensearch/ml/common/parameter/{MLAlgoName.java => FunctionName.java} (72%) rename docs/{how-to-add-new-algorithm.md => how-to-add-new-function.md} (65%) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 
8fa4bdf092..a619c1ccb8 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -65,9 +65,9 @@ default ActionFuture train(MLInput mlInput) { void train(MLInput mlInput, ActionListener listener); /** - * Execute ML algorithm. + * Execute function and return ActionFuture. * @param input input data - * @return output + * @return ActionFuture of output */ default ActionFuture execute(Input input) { PlainActionFuture actionFuture = PlainActionFuture.newFuture(); @@ -76,7 +76,7 @@ default ActionFuture execute(Input input) { } /** - * Execute ML algorithm + * Execute function and return output in listener * @param input input data * @param listener action listener */ diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index ce5fb9c2fa..3b8d1b9b49 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -22,11 +22,10 @@ import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.parameter.Input; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLOutput; -import org.opensearch.ml.common.parameter.MLOutputType; import org.opensearch.ml.common.parameter.MLTrainingOutput; import org.opensearch.ml.common.parameter.Output; @@ -93,7 +92,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Test public void predict_WithAlgoAndInputData() { MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + 
.algorithm(FunctionName.KMEANS) .dataFrame(input) .build(); assertEquals(output, machineLearningClient.predict(null, mlInput).actionGet()); @@ -102,7 +101,7 @@ public void predict_WithAlgoAndInputData() { @Test public void predict_WithAlgoAndParametersAndInputData() { MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .parameters(mlParameters) .dataFrame(input) .build(); @@ -112,7 +111,7 @@ public void predict_WithAlgoAndParametersAndInputData() { @Test public void predict_WithAlgoAndParametersAndInputDataAndModelId() { MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .parameters(mlParameters) .dataFrame(input) .build(); @@ -122,7 +121,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() { @Test public void predict_WithAlgoAndInputDataAndListener() { MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .dataFrame(input) .build(); ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); @@ -134,7 +133,7 @@ public void predict_WithAlgoAndInputDataAndListener() { @Test public void predict_WithAlgoAndInputDataAndParametersAndListener() { MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .parameters(mlParameters) .dataFrame(input) .build(); @@ -147,7 +146,7 @@ public void predict_WithAlgoAndInputDataAndParametersAndListener() { @Test public void train() { MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .parameters(mlParameters) .dataFrame(input) .build(); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 25007d420f..c8c04151d4 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ 
b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -23,7 +23,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.MLInputDataset; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.MLPredictionOutput; @@ -88,7 +88,7 @@ public void predict() { ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .inputDataset(input) .build(); machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener); @@ -114,7 +114,7 @@ public void predict_Exception_WithNullDataSet() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("input data set can't be null"); MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .build(); machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener); } @@ -137,7 +137,7 @@ public void train() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLOutput.class); MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .inputDataset(input) .build(); machineLearningNodeClient.train(mlInput, trainingActionListener); @@ -164,7 +164,7 @@ public void train_Exception_WithNullDataSet() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("input data set can't be null"); MLInput mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .build(); machineLearningNodeClient.train(mlInput, trainingActionListener); } diff --git 
a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java index c07e1d8f2b..b10024b094 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java +++ b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java @@ -20,7 +20,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.parameter.Input; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLOutputType; @@ -57,15 +57,15 @@ public static void loadClassMapping(Class resource, String configFile) { if (currentToken == XContentParser.Token.FIELD_NAME) { String key = parser.currentName(); if ("ml_algo_param_class".equals(key)) { - parseMLAlgoParams(parser, parameterClassMap, k -> MLAlgoName.fromString(k)); + parseMLAlgoParams(parser, parameterClassMap, k -> FunctionName.fromString(k)); } else if ("ml_input_data_set_class".equals(key)) { parseMLAlgoParams(parser, parameterClassMap, k -> MLInputDataType.fromString(k)); } else if ("ml_output_class".equals(key)) { parseMLAlgoParams(parser, parameterClassMap, k -> MLOutputType.fromString(k)); } else if ("ml_algo_class".equals(key)) { - parseMLAlgoParams(parser, mlAlgoClassMap, k -> MLAlgoName.fromString(k)); + parseMLAlgoParams(parser, mlAlgoClassMap, k -> FunctionName.fromString(k)); } else if ("executable_function_class".equals(key)) { - parseMLAlgoParams(parser, mlAlgoClassMap, k -> MLAlgoName.fromString(k)); + parseMLAlgoParams(parser, mlAlgoClassMap, k -> FunctionName.fromString(k)); } } else { parser.nextToken(); diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameType.java b/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameType.java index 
73a1646639..6828aa8b47 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameType.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameType.java @@ -15,7 +15,7 @@ import lombok.Getter; public enum DataFrameType { - DEFAULT("DEFAULT"); + DEFAULT("default"); @Getter private final String name; diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/MLAlgoName.java b/common/src/main/java/org/opensearch/ml/common/parameter/FunctionName.java similarity index 72% rename from common/src/main/java/org/opensearch/ml/common/parameter/MLAlgoName.java rename to common/src/main/java/org/opensearch/ml/common/parameter/FunctionName.java index 72e1d2c528..6828bf443c 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/MLAlgoName.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/FunctionName.java @@ -2,7 +2,7 @@ import lombok.Getter; -public enum MLAlgoName { +public enum FunctionName { LINEAR_REGRESSION("linear_regression"), KMEANS("kmeans"), SAMPLE_ALGO("sample_algo"), @@ -11,7 +11,7 @@ public enum MLAlgoName { @Getter private final String name; - MLAlgoName(String name) { + FunctionName(String name) { this.name = name; } @@ -19,8 +19,8 @@ public String toString() { return name; } - public static MLAlgoName fromString(String name){ - for(MLAlgoName e : MLAlgoName.values()){ + public static FunctionName fromString(String name){ + for(FunctionName e : FunctionName.values()){ if(e.name.equals(name)) return e; } return null; diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/Input.java b/common/src/main/java/org/opensearch/ml/common/parameter/Input.java index 5d850ac8cb..c8f236106e 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/Input.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/Input.java @@ -16,5 +16,5 @@ public interface Input extends ToXContentObject, Writeable { - MLAlgoName getFunctionName(); + FunctionName getFunctionName(); } 
diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java b/common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java index caf41d23ca..bbd4b78ce0 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java @@ -17,7 +17,7 @@ @Data public class KMeansParams implements MLAlgoParams { - public static final String PARSE_FIELD_NAME = "kmeans"; + public static final String PARSE_FIELD_NAME = FunctionName.KMEANS.getName(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( MLAlgoParams.class, new ParseField(PARSE_FIELD_NAME), diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java b/common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java index 76ca529aea..76d31cb58e 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java @@ -17,7 +17,7 @@ @Data public class LinearRegressionParams implements MLAlgoParams { - public static final String PARSE_FIELD_NAME = "linear_regression"; + public static final String PARSE_FIELD_NAME = FunctionName.LINEAR_REGRESSION.getName(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( MLAlgoParams.class, new ParseField(PARSE_FIELD_NAME), diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java b/common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java index a925e18381..0420d5d0da 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java @@ -28,7 +28,7 @@ @Getter public class 
LocalSampleCalculatorInput implements Input { - public static final String PARSE_FIELD_NAME = "local_sample_calculator"; + public static final String PARSE_FIELD_NAME = FunctionName.LOCAL_SAMPLE_CALCULATOR.getName(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( Input.class, new ParseField(PARSE_FIELD_NAME), @@ -81,8 +81,8 @@ public LocalSampleCalculatorInput(String operation, List inputData) { } @Override - public MLAlgoName getFunctionName() { - return MLAlgoName.LOCAL_SAMPLE_CALCULATOR; + public FunctionName getFunctionName() { + return FunctionName.LOCAL_SAMPLE_CALCULATOR; } public LocalSampleCalculatorInput(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java b/common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java index 96ff9ea022..c4e92f3566 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java @@ -46,7 +46,7 @@ public class MLInput implements Input { public static final String INPUT_DATA_FIELD = "input_data"; // Algorithm name - private MLAlgoName algorithm; + private FunctionName algorithm; // ML algorithm parameters private MLAlgoParams parameters; // Input data to train model, run trained model to predict or run ML algorithms(no-model-based) directly. 
@@ -55,7 +55,7 @@ public class MLInput implements Input { private int version = 1; @Builder - public MLInput(MLAlgoName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder, List sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) { + public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder, List sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) { this.algorithm = algorithm; this.parameters = parameters; if (inputDataset != null) { @@ -66,7 +66,7 @@ public MLInput(MLAlgoName algorithm, MLAlgoParams parameters, SearchSourceBuilde } public MLInput(StreamInput in) throws IOException { - this.algorithm = in.readEnum(MLAlgoName.class); + this.algorithm = in.readEnum(FunctionName.class); if (in.readBoolean()) { this.parameters = MLCommonsClassLoader.initInstance(algorithm, in, StreamInput.class); } @@ -120,7 +120,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } public static MLInput parse(XContentParser parser, String algorithmName) throws IOException { - MLAlgoName algorithm = MLAlgoName.fromString(algorithmName); + FunctionName algorithm = FunctionName.fromString(algorithmName); MLAlgoParams mlParameters = null; SearchSourceBuilder searchSourceBuilder = null; List sourceIndices = new ArrayList<>(); @@ -166,7 +166,7 @@ private MLInputDataset createInputDataSet(SearchSourceBuilder searchSourceBuilde } @Override - public MLAlgoName getFunctionName() { + public FunctionName getFunctionName() { return this.algorithm; } } diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/SampleAlgoParams.java b/common/src/main/java/org/opensearch/ml/common/parameter/SampleAlgoParams.java index 9b88fd8f6d..66f0c253fc 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/SampleAlgoParams.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/SampleAlgoParams.java @@ -25,8 +25,7 @@ @Data public class 
SampleAlgoParams implements MLAlgoParams { - - public static final String PARSE_FIELD_NAME = "sample_algo"; + public static final String PARSE_FIELD_NAME = FunctionName.SAMPLE_ALGO.getName(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( MLAlgoParams.class, new ParseField(PARSE_FIELD_NAME), diff --git a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java index abb07e3d20..e8589e2316 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java @@ -17,8 +17,7 @@ import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.ml.common.MLCommonsClassLoader; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.SampleAlgoParams; @@ -48,13 +47,13 @@ public void setUp() throws IOException { @Test public void testClassLoader_SampleAlgoParams() { - SampleAlgoParams sampleAlgoParams = MLCommonsClassLoader.initInstance(MLAlgoName.SAMPLE_ALGO, streamInput, StreamInput.class); + SampleAlgoParams sampleAlgoParams = MLCommonsClassLoader.initInstance(FunctionName.SAMPLE_ALGO, streamInput, StreamInput.class); assertEquals(params.getSampleParam(), sampleAlgoParams.getSampleParam()); } @Test public void testClassLoader_Return_MLAlgoParams() { - MLAlgoParams mlAlgoParams = MLCommonsClassLoader.initInstance(MLAlgoName.SAMPLE_ALGO, streamInput, StreamInput.class); + MLAlgoParams mlAlgoParams = MLCommonsClassLoader.initInstance(FunctionName.SAMPLE_ALGO, streamInput, StreamInput.class); assertTrue(mlAlgoParams instanceof SampleAlgoParams); 
assertEquals(params.getSampleParam(), ((SampleAlgoParams)mlAlgoParams).getSampleParam()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java index 5745736516..523032b7e5 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java @@ -34,7 +34,7 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; import org.opensearch.ml.common.parameter.KMeansParams; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.search.builder.SearchSourceBuilder; @@ -50,7 +50,7 @@ public class MLPredictionTaskRequestTest { @Before public void setUp() { mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .parameters(KMeansParams.builder().centroids(1).build()) .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); @@ -68,7 +68,7 @@ public void writeTo_Success() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLPredictionTaskRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(MLAlgoName.KMEANS, request.getMlInput().getAlgorithm()); + assertEquals(FunctionName.KMEANS, request.getMlInput().getAlgorithm()); KMeansParams params = (KMeansParams)request.getMlInput().getParameters(); assertEquals(1, params.getCentroids().intValue()); MLInputDataset inputDataset = request.getMlInput().getInputDataset(); diff --git 
a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java index e8bfd2862e..858f8b3b3c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java @@ -21,7 +21,7 @@ import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.parameter.KMeansParams; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLInput; import java.io.IOException; @@ -41,7 +41,7 @@ public class MLTrainingTaskRequestTest { @Before public void setUp() { mlInput = MLInput.builder() - .algorithm(MLAlgoName.KMEANS) + .algorithm(FunctionName.KMEANS) .parameters(KMeansParams.builder().centroids(1).build()) .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); @@ -75,7 +75,7 @@ public void writeTo() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLTrainingTaskRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(MLAlgoName.KMEANS, request.getMlInput().getAlgorithm()); + assertEquals(FunctionName.KMEANS, request.getMlInput().getAlgorithm()); assertEquals(1, ((KMeansParams) request.getMlInput().getParameters()).getCentroids().intValue()); assertEquals(MLInputDataType.DATA_FRAME, request.getMlInput().getInputDataset().getInputDataType()); } diff --git a/docs/how-to-add-new-algorithm.md b/docs/how-to-add-new-function.md similarity index 65% rename from docs/how-to-add-new-algorithm.md rename to docs/how-to-add-new-function.md index 400bce9f69..67982a339b 100644 --- 
a/docs/how-to-add-new-algorithm.md +++ b/docs/how-to-add-new-function.md @@ -1,30 +1,27 @@ -# How to add new algorithm +# How to add new function -This doc explains how to add new algorithm to `ml-commons` with two examples. +This doc explains how to add new function to `ml-commons` with two examples. -## Example 1 - Sample algorithm with train/predict APIs +## Example 1 - Sample ML algorithm with train/predict APIs This sample algorithm will train a dummy model first. Then user can call predict API to calculate total sum of a data frame or selected data from indices. - -### Step 1: add parameter class for new algorithm -Add new enum in `MLAlgoName` +### Step 1: name the function +Add new function name in `org.opensearch.ml.common.parameter.FunctionName` ``` -public class MLAlgoName { - LINEAR_REGRESSION("linear_regression"), - KMEANS("kmeans"), - +public enum FunctionName { + ... //Add new sample algorithm name SAMPLE_ALGO("sample_algo"), ... } ``` -Create new class `SampleAlgoParams` in `common` package by implements `MLAlgoParams` interface. -Must define `NamedXContentRegistry.Entry` in `SampleAlgoParams`, sample code(check more details in ml-commons code) +### Step 2: add input class +Create new class `org.opensearch.ml.common.parameter.SampleAlgoParams` in `common` package by implementing `MLAlgoParams` interface. +Must define `NamedXContentRegistry.Entry` in `SampleAlgoParams`.
``` public class SampleAlgoParams implements MLAlgoParams { - - public static final String PARSE_FIELD_NAME = "sample_algo"; + public static final String PARSE_FIELD_NAME = FunctionName.SAMPLE_ALGO.getName(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( MLAlgoParams.class, new ParseField(PARSE_FIELD_NAME), @@ -46,23 +43,23 @@ Register `SampleAlgoParams.XCONTENT_REGISTRY` in `MachineLearningPlugin.getNamed ); } ``` -### Step 2: add output class for new algorithm -Add new enum in `MLOutput.MLOutputType`: +### Step 3: add output class +Add new enum in `org.opensearch.ml.common.parameter.MLOutputType`: ``` public enum MLOutputType { ... - SAMPLE_ALGO("SAMPLE_ALGO"); + SAMPLE_ALGO("sample_algo"), ... } ``` -Create new class `SampleAlgoOutput` in `common` package by extending abstract class `MLOutput`. +Create new class `org.opensearch.ml.common.parameter.SampleAlgoOutput` in `common` package by extending abstract class `MLOutput`. -### Step 3: add new algorithm in `ml-algorithms` package -Create new class `ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java` by implementing interface `MLAlgo`. +### Step 4: add implementation +Create new class `org.opensearch.ml.engine.algorithms.sample.SampleAlgo` in `ml-algorithms` package by implementing interface `MLAlgo`. Override `train`, `predict` methods. -### Step 4: configure new algorithm -Add new input/output in config file `common/src/main/resources/ml-commons-config.yml` +### Step 5: configure new algorithm +Add new input/output to config file `common/src/main/resources/ml-commons-config.yml` Key is input/output enum name, value is class name. ``` @@ -75,17 +72,16 @@ ml_output_class: ... ``` -Add new ML algorithm in config file `ml-algorithms/src/main/resources/ml-algorithm-config.yml` +Add new ML algorithm to config file `ml-algorithms/src/main/resources/ml-algorithm-config.yml` -Key is ML algorithm enum name, value is class name. 
+Key is ML algorithm enum name added to `FunctionName` in Step 1, value is class name. ``` ml_algo_class: sample_algo: org.opensearch.ml.engine.algorithms.sample.SampleAlgo ... - ``` -### Step 5: Run and test +### Step 6: Run and test Run `./gradlew run` and test sample algorithm. Train with sample data @@ -94,7 +90,7 @@ Train with sample data POST /_plugins/_ml/_train/sample_algo { "parameters": { - "sample_param": 22 + "sample_param": 10 }, "input_data": { "column_metas": [ @@ -146,17 +142,27 @@ POST _plugins/_ml/_predict/sample_algo/247c5947-35a1-41a7-a95b-703a1e9b2203 } ``` -## Example 2 - Sample algorithm(no model) with execute API -Some algorithm like anomaly localization has no model. We can add such algorithm by exposing execute API only. +## Example 2 - Sample calculator with execute API +Some functions, like anomaly localization, have no model. We can add such functions by exposing the execute API only. In this example, we will add a new sample calculator which runs on local node (don't dispatch task to other node). -The sample calculator supports calculating `sum` or `max` value of a data frame or selected data from indices. +The sample calculator supports calculating `sum`/`max`/`min` value from a double list. -### Step 1: add input class for new algorithm +### Step 1: name the function +Add new function name in `org.opensearch.ml.common.parameter.FunctionName` ``` +public enum FunctionName { + ... + // Add new enum + LOCAL_SAMPLE_CALCULATOR("local_sample_calculator"); +} +``` + +### Step 2: add input class Add new class `org.opensearch.ml.common.parameter.LocalSampleCalculatorInput` by implementing `Input`. Must define `NamedXContentRegistry.Entry` in `LocalSampleCalculatorInput`.
``` public class LocalSampleCalculatorInput implements Input { - public static final String PARSE_FIELD_NAME = "local_sample_calculator"; + public static final String PARSE_FIELD_NAME = FunctionName.LOCAL_SAMPLE_CALCULATOR.getName(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( Input.class, new ParseField(PARSE_FIELD_NAME), @@ -179,7 +185,7 @@ Register `SampleAlgoParams.XCONTENT_REGISTRY` in `MachineLearningPlugin.getNamed } ``` -### Step 2: add output class for new algorithm +### Step 3: add output class Add output class `org.opensearch.ml.common.parameter.LocalSampleCalculatorOutput` by implementing `Output`. ``` public class LocalSampleCalculatorOutput implements Output{ @@ -188,12 +194,12 @@ public class LocalSampleCalculatorOutput implements Output{ } ``` -### Step 3: add new algorithm in `ml-algorithms` package -Create new class `ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java` +### Step 4: add implementation +Create new class `org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator` in `ml-algorithms` package by implementing interface `Executable`. Override `execute` method. ``` -public class LocalSam**pleCalculator implements Executable { +public class LocalSampleCalculator implements Executable { @Override public Output execute(Input input) { .... @@ -202,26 +208,15 @@ public class LocalSam**pleCalculator implements Executable { } ``` -### Step 4: configure new algorithm -Add new algorithm/function name in enum `org.opensearch.ml.common.parameter.MLAlgoName`(may change the class name later) -``` -public enum MLAlgoName { - LINEAR_REGRESSION("linear_regression"), - KMEANS("kmeans"), - ...
- // Add new enum - LOCAL_SAMPLE_CALCULATOR("local_sample_calculator"); -} -``` - -Add new executable function in config file `ml-algorithms/src/main/resources/ml-algorithm-config.yml` +### Step 5: configure new function +Add new function to config file `ml-algorithms/src/main/resources/ml-algorithm-config.yml` -Key is the new enum name `local_sample_calculator`, value is class name. +Key is the new function name `local_sample_calculator` added to `FunctionName` in Step 1, value is class name. ``` executable_function_class: local_sample_calculator: org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator ``` -### Step 5: Run and test +### Step 6: Run and test Run `./gradlew run` and test this sample calculator. ``` diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index 841d8b79d9..ca3fe712f6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -14,7 +14,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.parameter.Input; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.Output; @@ -35,7 +35,7 @@ public class MLEngine { MLCommonsClassLoader.loadClassMapping(MLEngine.class, "/ml-algorithm-config.yml"); } - public static MLOutput predict(MLAlgoName algoName, MLAlgoParams parameters, DataFrame dataFrame, Model model) { + public static MLOutput predict(FunctionName algoName, MLAlgoParams parameters, DataFrame dataFrame, Model model) { if (algoName == null) { throw new IllegalArgumentException("Algo name should not be null"); } @@ -46,7 +46,7 @@ public static MLOutput predict(MLAlgoName algoName, MLAlgoParams 
parameters, Dat return mlAlgo.predict(dataFrame, model); } - public static Model train(MLAlgoName algoName, MLAlgoParams parameters, DataFrame dataFrame) { + public static Model train(FunctionName algoName, MLAlgoParams parameters, DataFrame dataFrame) { if (algoName == null) { throw new IllegalArgumentException("Algo name should not be null"); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java index 30820a72ca..6a88900b9c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java @@ -15,7 +15,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.parameter.KMeansParams; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.MLPredictionOutput; @@ -122,7 +122,7 @@ public Model train(DataFrame dataFrame) { KMeansTrainer trainer = new KMeansTrainer(centroids, iterations, distance, numThreads, seed); KMeansModel kMeansModel = trainer.train(trainDataset); Model model = new Model(); - model.setName("KMeans"); + model.setName(FunctionName.KMEANS.getName()); model.setVersion(1); model.setContent(ModelSerDeSer.serialize(kMeansModel)); @@ -131,7 +131,7 @@ public Model train(DataFrame dataFrame) { @Override public MLAlgoMetaData getMetaData() { - return MLAlgoMetaData.builder().name(MLAlgoName.KMEANS.getName()) + return MLAlgoMetaData.builder().name(FunctionName.KMEANS.getName()) .description("A clustering algorithm.") .version("1.0") .predictable(true) diff --git 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java index 1f9f4aeedc..03eb18d4ab 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java @@ -15,7 +15,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.parameter.LinearRegressionParams; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.MLPredictionOutput; @@ -219,7 +219,7 @@ public Model train(DataFrame dataFrame) { @Override public MLAlgoMetaData getMetaData() { - return MLAlgoMetaData.builder().name(MLAlgoName.LINEAR_REGRESSION.getName()) + return MLAlgoMetaData.builder().name(FunctionName.LINEAR_REGRESSION.getName()) .description("Linear regression algorithm.") .version("1.0") .predictable(true) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java index fa3af17e51..a757888839 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java @@ -1,7 +1,7 @@ package org.opensearch.ml.engine.algorithms.sample; import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; import 
org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.SampleAlgoOutput; @@ -49,7 +49,7 @@ public Model train(DataFrame dataFrame) { @Override public MLAlgoMetaData getMetaData() { - return MLAlgoMetaData.builder().name(MLAlgoName.SAMPLE_ALGO.getName()) + return MLAlgoMetaData.builder().name(FunctionName.SAMPLE_ALGO.getName()) .description("A sample algorithm.") .version("1.0") .predictable(true) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 5084b18a10..a5087e7569 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -21,11 +21,10 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.parameter.KMeansParams; import org.opensearch.ml.common.parameter.LinearRegressionParams; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLPredictionOutput; import java.util.HashSet; -import java.util.Locale; import java.util.Set; import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame; @@ -40,9 +39,9 @@ public class MLEngineTest { @Before public void setUp() { - algoNames.add(MLAlgoName.KMEANS.getName()); - algoNames.add(MLAlgoName.LINEAR_REGRESSION.getName()); - algoNames.add(MLAlgoName.SAMPLE_ALGO.getName()); + algoNames.add(FunctionName.KMEANS.getName()); + algoNames.add(FunctionName.LINEAR_REGRESSION.getName()); + algoNames.add(FunctionName.SAMPLE_ALGO.getName()); MLCommonsClassLoader.loadClassMapping(MLCommonsClassLoader.class, "/ml-commons-config.yml"); MLCommonsClassLoader.loadClassMapping(MLEngine.class, "/ml-algorithm-config.yml"); } @@ -51,7 +50,7 @@ public void setUp() { public void predictKMeans() { Model model = trainKMeansModel(); DataFrame 
predictionDataFrame = constructKMeansDataFrame(10); - MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(MLAlgoName.KMEANS, null, predictionDataFrame, model); + MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(FunctionName.KMEANS, null, predictionDataFrame, model); DataFrame predictions = output.getPredictionResult(); Assert.assertEquals(10, predictions.size()); predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1)); @@ -61,7 +60,7 @@ public void predictKMeans() { public void predictLinearRegression() { Model model = trainLinearRegressionModel(); DataFrame predictionDataFrame = constructLinearRegressionPredictionDataFrame(); - MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(MLAlgoName.LINEAR_REGRESSION, null, predictionDataFrame, model); + MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(FunctionName.LINEAR_REGRESSION, null, predictionDataFrame, model); DataFrame predictions = output.getPredictionResult(); Assert.assertEquals(2, predictions.size()); } @@ -69,7 +68,7 @@ public void predictLinearRegression() { @Test public void trainKMeans() { Model model = trainKMeansModel(); - Assert.assertEquals("KMeans", model.getName()); + Assert.assertEquals(FunctionName.KMEANS.getName(), model.getName()); Assert.assertEquals(1, model.getVersion()); Assert.assertNotNull(model.getContent()); } @@ -100,7 +99,7 @@ public void predictUnsupportedAlgorithm() { public void predictWithoutModel() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("No model found for linear regression prediction."); - MLEngine.predict(MLAlgoName.LINEAR_REGRESSION, null, null, null); + MLEngine.predict(FunctionName.LINEAR_REGRESSION, null, null, null); } @Test @@ -116,7 +115,7 @@ private Model trainKMeansModel() { .distanceType(KMeansParams.DistanceType.EUCLIDEAN) .build(); DataFrame trainDataFrame = constructKMeansDataFrame(100); - return 
MLEngine.train(MLAlgoName.KMEANS, parameters, trainDataFrame); + return MLEngine.train(FunctionName.KMEANS, parameters, trainDataFrame); } private Model trainLinearRegressionModel() { @@ -132,6 +131,6 @@ private Model trainLinearRegressionModel() { DataFrame trainDataFrame = constructLinearRegressionTrainDataFrame(); - return MLEngine.train(MLAlgoName.LINEAR_REGRESSION, parameters, trainDataFrame); + return MLEngine.train(FunctionName.LINEAR_REGRESSION, parameters, trainDataFrame); } } \ No newline at end of file diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/clustering/KMeansTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/clustering/KMeansTest.java index 5ed9c6bb60..4bf08601a5 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/clustering/KMeansTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/clustering/KMeansTest.java @@ -17,6 +17,7 @@ import org.junit.Test; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.parameter.KMeansParams; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.ml.engine.Model; import org.opensearch.ml.engine.algorithms.clustering.KMeans; @@ -57,7 +58,7 @@ public void predict() { @Test public void train() { Model model = kMeans.train(trainDataFrame); - Assert.assertEquals("KMeans", model.getName()); + Assert.assertEquals(FunctionName.KMEANS.getName(), model.getName()); Assert.assertEquals(1, model.getVersion()); Assert.assertNotNull(model.getContent()); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java index 7de3075d3c..b1850d3f1e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java @@ -23,7 +23,7 @@ 
import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; @@ -83,14 +83,14 @@ public void testPredictionWithoutAlgorithm() throws IOException { } public void testPredictionWithoutModelId() throws IOException { - MLInput mlInput = MLInput.builder().algorithm(MLAlgoName.KMEANS).inputDataset(DATA_FRAME_INPUT_DATASET).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(DATA_FRAME_INPUT_DATASET).build(); MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest("", mlInput); ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); expectThrows(ResourceNotFoundException.class, () -> predictionFuture.actionGet()); } public void testPredictionWithoutDataset() throws IOException { - MLInput mlInput = MLInput.builder().algorithm(MLAlgoName.KMEANS).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build(); MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(taskId, mlInput); ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); expectThrows(ActionRequestValidationException.class, () -> predictionFuture.actionGet()); @@ -98,7 +98,7 @@ public void testPredictionWithoutDataset() throws IOException { public void testPredictionWithEmptyDataset() throws IOException { MLInputDataset emptySearchInputDataset = generateEmptyDataset(); - MLInput mlInput = MLInput.builder().algorithm(MLAlgoName.KMEANS).inputDataset(emptySearchInputDataset).build(); + MLInput 
mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(emptySearchInputDataset).build(); MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(taskId, mlInput); ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); expectThrows(IllegalArgumentException.class, () -> predictionFuture.actionGet()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/training/TrainingITTests.java b/plugin/src/test/java/org/opensearch/ml/action/training/TrainingITTests.java index 5007b0db2a..655399e6ac 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/training/TrainingITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/training/TrainingITTests.java @@ -38,7 +38,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLTrainingOutput; import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; @@ -103,7 +103,7 @@ public void testTrainingWithoutAlgorithm() { // Train a model without dataset. 
public void testTrainingWithoutDataset() { - MLInput mlInput = MLInput.builder().algorithm(MLAlgoName.KMEANS).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build(); MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest(mlInput); expectThrows(ActionRequestValidationException.class, () -> { ActionFuture trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest); @@ -116,7 +116,7 @@ public void testTrainingWithEmptyDataset() throws InterruptedException { SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder(); searchSourceBuilder.query(QueryBuilders.matchQuery("noSuchName", "")); MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder); - MLInput mlInput = MLInput.builder().algorithm(MLAlgoName.KMEANS).inputDataset(inputDataset).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build(); MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest(mlInput); ActionFuture trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest); diff --git a/plugin/src/test/java/org/opensearch/ml/params/MLInputTests.java b/plugin/src/test/java/org/opensearch/ml/params/MLInputTests.java index 762a9b7705..ac9f92438f 100644 --- a/plugin/src/test/java/org/opensearch/ml/params/MLInputTests.java +++ b/plugin/src/test/java/org/opensearch/ml/params/MLInputTests.java @@ -9,6 +9,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.test.OpenSearchTestCase; @@ -18,7 +19,7 @@ public void testParseKmeansInputQuery() throws IOException { String query = 
"{\"input_query\":{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"k1\":1}}]}},\"size\":10},\"input_index\":[\"test_data\"]}"; XContentParser parser = parser(query); - MLInput mlInput = MLInput.parse(parser, "kmeans"); + MLInput mlInput = MLInput.parse(parser, FunctionName.KMEANS.getName()); String expectedQuery = "{\"size\":10,\"query\":{\"bool\":{\"filter\":[{\"term\":{\"k1\":{\"value\":1,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}}}"; SearchQueryInputDataset inputDataset = (SearchQueryInputDataset) mlInput.getInputDataset(); @@ -31,7 +32,7 @@ public void testParseKmeansInputDataFrame() throws IOException { + "{\"column_type\":\"BOOLEAN\",\"value\":false}]},{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":100}," + "{\"column_type\":\"BOOLEAN\",\"value\":true}]}]}}"; XContentParser parser = parser(query); - MLInput mlInput = MLInput.parse(parser, "kmeans"); + MLInput mlInput = MLInput.parse(parser, FunctionName.KMEANS.getName()); DataFrameInputDataset inputDataset = (DataFrameInputDataset) mlInput.getInputDataset(); DataFrame dataFrame = inputDataset.getDataFrame(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/BaseMLSearchActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/BaseMLSearchActionTests.java index 3a02e9096f..48ea605534 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/BaseMLSearchActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/BaseMLSearchActionTests.java @@ -27,6 +27,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.QueryStringQueryBuilder; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.plugin.MachineLearningPlugin; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -105,11 +106,11 @@ public void testGetAlgorithmWithEmptyInput() { @Test public void testGetAlgorithmWithValidInput() { - Map param = 
ImmutableMap.builder().put(PARAMETER_ALGORITHM, "kmeans").build(); + Map param = ImmutableMap.builder().put(PARAMETER_ALGORITHM, FunctionName.KMEANS.getName()).build(); FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build(); String algorithm = baseMLSearchAction.getAlgorithm(fakeRestRequest); assertFalse(Strings.isNullOrEmpty(algorithm)); - assertEquals(algorithm, "kmeans"); + assertEquals(algorithm, FunctionName.KMEANS.getName()); } @Test @@ -143,7 +144,7 @@ public void testGetModelIdWithValidInput() { public void testGetSearchQueryWithoutSearchInput() throws IOException { Map param = ImmutableMap .builder() - .put(PARAMETER_ALGORITHM, "kmeans") + .put(PARAMETER_ALGORITHM, FunctionName.KMEANS.getName()) .put("index", "index1,index2") .build(); FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build(); @@ -199,7 +200,7 @@ public void testGetSearchQueryWithoutIndices() throws IOException { public void testGetSearchQueryWithSearchParams() throws IOException { Map param = ImmutableMap .builder() - .put(PARAMETER_ALGORITHM, "kmeans") + .put(PARAMETER_ALGORITHM, FunctionName.KMEANS.getName()) .put("index", "index1,index2") .put("q", "user:dilbert") .build(); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/IntegTestUtils.java b/plugin/src/test/java/org/opensearch/ml/utils/IntegTestUtils.java index bed37c5833..89bb580087 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/IntegTestUtils.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/IntegTestUtils.java @@ -39,7 +39,7 @@ import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; -import org.opensearch.ml.common.parameter.MLAlgoName; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLInput; import 
org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.ml.common.parameter.MLTrainingOutput; @@ -121,7 +121,7 @@ public static SearchSourceBuilder generateSearchSourceBuilder() { // Train a model. public static String trainModel(MLInputDataset inputDataset) throws ExecutionException, InterruptedException { - MLInput mlInput = MLInput.builder().algorithm(MLAlgoName.KMEANS).inputDataset(inputDataset).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build(); MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest(mlInput); ActionFuture trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest); MLTrainingTaskResponse trainingResponse = trainingFuture.actionGet(); @@ -161,7 +161,7 @@ public static SearchResponse waitModelAvailable(String taskId) throws Interrupte // Predict with the model generated, and verify the prediction result. public static void predictAndVerifyResult(String taskId, MLInputDataset inputDataset) throws IOException { - MLInput mlInput = MLInput.builder().algorithm(MLAlgoName.KMEANS).inputDataset(inputDataset).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build(); MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(taskId, mlInput); ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); MLPredictionTaskResponse predictionResponse = predictionFuture.actionGet();