Skip to content

Commit

Permalink
tune class name; tune how to add new doc
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn committed Nov 23, 2021
1 parent db67aa2 commit 54bf452
Show file tree
Hide file tree
Showing 27 changed files with 131 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ default ActionFuture<MLOutput> train(MLInput mlInput) {
void train(MLInput mlInput, ActionListener<MLOutput> listener);

/**
* Execute ML algorithm.
* Execute function and return ActionFuture.
* @param input input data
* @return output
* @return ActionFuture of output
*/
default ActionFuture<Output> execute(Input input) {
PlainActionFuture<Output> actionFuture = PlainActionFuture.newFuture();
Expand All @@ -76,7 +76,7 @@ default ActionFuture<Output> execute(Input input) {
}

/**
* Execute ML algorithm
* Execute function and return output in listener
* @param input input data
* @param listener action listener
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.parameter.MLAlgoName;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLAlgoParams;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.parameter.MLOutput;
import org.opensearch.ml.common.parameter.MLOutputType;
import org.opensearch.ml.common.parameter.MLTrainingOutput;
import org.opensearch.ml.common.parameter.Output;

Expand Down Expand Up @@ -93,7 +92,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
@Test
public void predict_WithAlgoAndInputData() {
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.dataFrame(input)
.build();
assertEquals(output, machineLearningClient.predict(null, mlInput).actionGet());
Expand All @@ -102,7 +101,7 @@ public void predict_WithAlgoAndInputData() {
@Test
public void predict_WithAlgoAndParametersAndInputData() {
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.parameters(mlParameters)
.dataFrame(input)
.build();
Expand All @@ -112,7 +111,7 @@ public void predict_WithAlgoAndParametersAndInputData() {
@Test
public void predict_WithAlgoAndParametersAndInputDataAndModelId() {
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.parameters(mlParameters)
.dataFrame(input)
.build();
Expand All @@ -122,7 +121,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() {
@Test
public void predict_WithAlgoAndInputDataAndListener() {
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.dataFrame(input)
.build();
ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
Expand All @@ -134,7 +133,7 @@ public void predict_WithAlgoAndInputDataAndListener() {
@Test
public void predict_WithAlgoAndInputDataAndParametersAndListener() {
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.parameters(mlParameters)
.dataFrame(input)
.build();
Expand All @@ -147,7 +146,7 @@ public void predict_WithAlgoAndInputDataAndParametersAndListener() {
@Test
public void train() {
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.parameters(mlParameters)
.dataFrame(input)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.parameter.MLAlgoName;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.parameter.MLOutput;
import org.opensearch.ml.common.parameter.MLPredictionOutput;
Expand Down Expand Up @@ -88,7 +88,7 @@ public void predict() {

ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.inputDataset(input)
.build();
machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener);
Expand All @@ -114,7 +114,7 @@ public void predict_Exception_WithNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data set can't be null");
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.build();
machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener);
}
Expand All @@ -137,7 +137,7 @@ public void train() {

ArgumentCaptor<MLOutput> argumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.inputDataset(input)
.build();
machineLearningNodeClient.train(mlInput, trainingActionListener);
Expand All @@ -164,7 +164,7 @@ public void train_Exception_WithNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data set can't be null");
MLInput mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.build();
machineLearningNodeClient.train(mlInput, trainingActionListener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.parameter.MLAlgoName;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLAlgoParams;
import org.opensearch.ml.common.parameter.MLOutputType;

Expand Down Expand Up @@ -57,15 +57,15 @@ public static void loadClassMapping(Class<?> resource, String configFile) {
if (currentToken == XContentParser.Token.FIELD_NAME) {
String key = parser.currentName();
if ("ml_algo_param_class".equals(key)) {
parseMLAlgoParams(parser, parameterClassMap, k -> MLAlgoName.fromString(k));
parseMLAlgoParams(parser, parameterClassMap, k -> FunctionName.fromString(k));
} else if ("ml_input_data_set_class".equals(key)) {
parseMLAlgoParams(parser, parameterClassMap, k -> MLInputDataType.fromString(k));
} else if ("ml_output_class".equals(key)) {
parseMLAlgoParams(parser, parameterClassMap, k -> MLOutputType.fromString(k));
} else if ("ml_algo_class".equals(key)) {
parseMLAlgoParams(parser, mlAlgoClassMap, k -> MLAlgoName.fromString(k));
parseMLAlgoParams(parser, mlAlgoClassMap, k -> FunctionName.fromString(k));
} else if ("executable_function_class".equals(key)) {
parseMLAlgoParams(parser, mlAlgoClassMap, k -> MLAlgoName.fromString(k));
parseMLAlgoParams(parser, mlAlgoClassMap, k -> FunctionName.fromString(k));
}
} else {
parser.nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import lombok.Getter;

public enum DataFrameType {
DEFAULT("DEFAULT");
DEFAULT("default");

@Getter
private final String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import lombok.Getter;

public enum MLAlgoName {
public enum FunctionName {
LINEAR_REGRESSION("linear_regression"),
KMEANS("kmeans"),
SAMPLE_ALGO("sample_algo"),
Expand All @@ -11,16 +11,16 @@ public enum MLAlgoName {
@Getter
private final String name;

MLAlgoName(String name) {
FunctionName(String name) {
this.name = name;
}

public String toString() {
return name;
}

public static MLAlgoName fromString(String name){
for(MLAlgoName e : MLAlgoName.values()){
public static FunctionName fromString(String name){
for(FunctionName e : FunctionName.values()){
if(e.name.equals(name)) return e;
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@

public interface Input extends ToXContentObject, Writeable {

MLAlgoName getFunctionName();
FunctionName getFunctionName();
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@Data
public class KMeansParams implements MLAlgoParams {

public static final String PARSE_FIELD_NAME = "kmeans";
public static final String PARSE_FIELD_NAME = FunctionName.KMEANS.getName();
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
MLAlgoParams.class,
new ParseField(PARSE_FIELD_NAME),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@Data
public class LinearRegressionParams implements MLAlgoParams {

public static final String PARSE_FIELD_NAME = "linear_regression";
public static final String PARSE_FIELD_NAME = FunctionName.LINEAR_REGRESSION.getName();
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
MLAlgoParams.class,
new ParseField(PARSE_FIELD_NAME),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

@Getter
public class LocalSampleCalculatorInput implements Input {
public static final String PARSE_FIELD_NAME = "local_sample_calculator";
public static final String PARSE_FIELD_NAME = FunctionName.LOCAL_SAMPLE_CALCULATOR.getName();
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
Input.class,
new ParseField(PARSE_FIELD_NAME),
Expand Down Expand Up @@ -81,8 +81,8 @@ public LocalSampleCalculatorInput(String operation, List<Double> inputData) {
}

@Override
public MLAlgoName getFunctionName() {
return MLAlgoName.LOCAL_SAMPLE_CALCULATOR;
public FunctionName getFunctionName() {
return FunctionName.LOCAL_SAMPLE_CALCULATOR;
}

public LocalSampleCalculatorInput(StreamInput in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public class MLInput implements Input {
public static final String INPUT_DATA_FIELD = "input_data";

// Algorithm name
private MLAlgoName algorithm;
private FunctionName algorithm;
// ML algorithm parameters
private MLAlgoParams parameters;
// Input data to train model, run trained model to predict or run ML algorithms(no-model-based) directly.
Expand All @@ -55,7 +55,7 @@ public class MLInput implements Input {
private int version = 1;

@Builder
public MLInput(MLAlgoName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder, List<String> sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) {
public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder, List<String> sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) {
this.algorithm = algorithm;
this.parameters = parameters;
if (inputDataset != null) {
Expand All @@ -66,7 +66,7 @@ public MLInput(MLAlgoName algorithm, MLAlgoParams parameters, SearchSourceBuilde
}

public MLInput(StreamInput in) throws IOException {
this.algorithm = in.readEnum(MLAlgoName.class);
this.algorithm = in.readEnum(FunctionName.class);
if (in.readBoolean()) {
this.parameters = MLCommonsClassLoader.initInstance(algorithm, in, StreamInput.class);
}
Expand Down Expand Up @@ -120,7 +120,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

public static MLInput parse(XContentParser parser, String algorithmName) throws IOException {
MLAlgoName algorithm = MLAlgoName.fromString(algorithmName);
FunctionName algorithm = FunctionName.fromString(algorithmName);
MLAlgoParams mlParameters = null;
SearchSourceBuilder searchSourceBuilder = null;
List<String> sourceIndices = new ArrayList<>();
Expand Down Expand Up @@ -166,7 +166,7 @@ private MLInputDataset createInputDataSet(SearchSourceBuilder searchSourceBuilde
}

@Override
public MLAlgoName getFunctionName() {
public FunctionName getFunctionName() {
return this.algorithm;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@

@Data
public class SampleAlgoParams implements MLAlgoParams {

public static final String PARSE_FIELD_NAME = "sample_algo";
public static final String PARSE_FIELD_NAME = FunctionName.SAMPLE_ALGO.getName();
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
MLAlgoParams.class,
new ParseField(PARSE_FIELD_NAME),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.parameter.MLAlgoName;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLAlgoParams;
import org.opensearch.ml.common.parameter.SampleAlgoParams;

Expand Down Expand Up @@ -48,13 +47,13 @@ public void setUp() throws IOException {

@Test
public void testClassLoader_SampleAlgoParams() {
SampleAlgoParams sampleAlgoParams = MLCommonsClassLoader.initInstance(MLAlgoName.SAMPLE_ALGO, streamInput, StreamInput.class);
SampleAlgoParams sampleAlgoParams = MLCommonsClassLoader.initInstance(FunctionName.SAMPLE_ALGO, streamInput, StreamInput.class);
assertEquals(params.getSampleParam(), sampleAlgoParams.getSampleParam());
}

@Test
public void testClassLoader_Return_MLAlgoParams() {
MLAlgoParams mlAlgoParams = MLCommonsClassLoader.initInstance(MLAlgoName.SAMPLE_ALGO, streamInput, StreamInput.class);
MLAlgoParams mlAlgoParams = MLCommonsClassLoader.initInstance(FunctionName.SAMPLE_ALGO, streamInput, StreamInput.class);
assertTrue(mlAlgoParams instanceof SampleAlgoParams);
assertEquals(params.getSampleParam(), ((SampleAlgoParams)mlAlgoParams).getSampleParam());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.parameter.KMeansParams;
import org.opensearch.ml.common.parameter.MLAlgoName;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.search.builder.SearchSourceBuilder;

Expand All @@ -50,7 +50,7 @@ public class MLPredictionTaskRequestTest {
@Before
public void setUp() {
mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.parameters(KMeansParams.builder().centroids(1).build())
.dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap<String, Object>() {{
put("key1", 2.0D);
Expand All @@ -68,7 +68,7 @@ public void writeTo_Success() throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
request.writeTo(bytesStreamOutput);
request = new MLPredictionTaskRequest(bytesStreamOutput.bytes().streamInput());
assertEquals(MLAlgoName.KMEANS, request.getMlInput().getAlgorithm());
assertEquals(FunctionName.KMEANS, request.getMlInput().getAlgorithm());
KMeansParams params = (KMeansParams)request.getMlInput().getParameters();
assertEquals(1, params.getCentroids().intValue());
MLInputDataset inputDataset = request.getMlInput().getInputDataset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.parameter.KMeansParams;
import org.opensearch.ml.common.parameter.MLAlgoName;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLInput;

import java.io.IOException;
Expand All @@ -41,7 +41,7 @@ public class MLTrainingTaskRequestTest {
@Before
public void setUp() {
mlInput = MLInput.builder()
.algorithm(MLAlgoName.KMEANS)
.algorithm(FunctionName.KMEANS)
.parameters(KMeansParams.builder().centroids(1).build())
.dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap<String, Object>() {{
put("key1", 2.0D);
Expand Down Expand Up @@ -75,7 +75,7 @@ public void writeTo() throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
request.writeTo(bytesStreamOutput);
request = new MLTrainingTaskRequest(bytesStreamOutput.bytes().streamInput());
assertEquals(MLAlgoName.KMEANS, request.getMlInput().getAlgorithm());
assertEquals(FunctionName.KMEANS, request.getMlInput().getAlgorithm());
assertEquals(1, ((KMeansParams) request.getMlInput().getParameters()).getCentroids().intValue());
assertEquals(MLInputDataType.DATA_FRAME, request.getMlInput().getInputDataset().getInputDataType());
}
Expand Down
Loading

0 comments on commit 54bf452

Please sign in to comment.