diff --git a/common/build.gradle b/common/build.gradle index 6d57630147..5c50796f40 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -14,6 +14,8 @@ dependencies { compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" compile group: 'org.reflections', name: 'reflections', version: '0.9.12' testCompile group: 'junit', name: 'junit', version: '4.12' + compile "org.opensearch.client:opensearch-rest-client:${opensearch_version}" + compile "org.opensearch:common-utils:${common_utils_version}" } jacocoTestReport { diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java b/common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java new file mode 100644 index 0000000000..3ca7457556 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java @@ -0,0 +1,142 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.parameter; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.ToXContentObject; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.commons.authuser.User; + +import java.io.IOException; +import java.util.Base64; + +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; + +@Getter +public class MLModel implements ToXContentObject { + public static final String ALGORITHM = "algorithm"; + public static final String MODEL_NAME = "name"; + public static final String MODEL_VERSION = "version"; + public static final String MODEL_CONTENT = "content"; + public static final String USER = "user"; + + private String name; + private FunctionName algorithm; + private Integer version; + private String content; + private User user; + + @Builder + public MLModel(String name, FunctionName algorithm, Integer version, String content, User user) { + this.name = name; + this.algorithm = algorithm; + this.version = version; + this.content = content; + this.user = user; + } + + public MLModel(FunctionName algorithm, Model model) { + this(model.getName(), algorithm, model.getVersion(), Base64.getEncoder().encodeToString(model.getContent()), null); + } + + public MLModel(StreamInput input) throws IOException{ + name = input.readOptionalString(); + algorithm = input.readEnum(FunctionName.class); + version = input.readInt(); + content = input.readOptionalString(); + if (input.readBoolean()) { + this.user = new User(input); + } else { + user = null; + } + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(name); + out.writeEnum(algorithm); + out.writeInt(version); + out.writeOptionalString(content); + if (user != null) { + out.writeBoolean(true); // user exists + user.writeTo(out); + } else { + out.writeBoolean(false); // user does not exist + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (name != null) { + builder.field(MODEL_NAME, name); + } + if (algorithm != null) { + builder.field(ALGORITHM, algorithm); + } + if (version != null) { + builder.field(MODEL_VERSION, version); + } + if (content != null) { + builder.field(MODEL_CONTENT, content); + } + if (user != null) { + builder.field(USER, user); + } + builder.endObject(); + return builder; + } + + public static MLModel parse(XContentParser parser, String taskId) throws IOException { + String name = null; + FunctionName algorithm = null; + Integer version = null; + String content = null; + User user = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_NAME: + name = parser.text(); + break; + case MODEL_CONTENT: + content = parser.text(); + break; + case MODEL_VERSION: + version = parser.intValue(); + break; + case USER: + user = User.parse(parser); + break; + case ALGORITHM: + algorithm = FunctionName.valueOf(parser.text()); + break; + default: + parser.skipChildren(); + break; + } + } + return MLModel.builder() + .name(name) + .algorithm(algorithm) + .version(version) + .content(content) + .user(user) + .build(); + } + + public static MLModel fromStream(StreamInput in) throws IOException { + MLModel mlModel = new MLModel(in); + return mlModel; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Model.java b/common/src/main/java/org/opensearch/ml/common/parameter/Model.java similarity index 83% rename from ml-algorithms/src/main/java/org/opensearch/ml/engine/Model.java rename to common/src/main/java/org/opensearch/ml/common/parameter/Model.java index 6328a2c597..e828fb099e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Model.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/Model.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.engine; +package org.opensearch.ml.common.parameter; import lombok.Data; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java new file mode 100644 index 0000000000..37e3831404 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.opensearch.action.ActionType; + +public class MLModelGetAction extends ActionType { + public static final MLModelGetAction INSTANCE = new MLModelGetAction(); + public static final String NAME = "cluster:admin/opensearch/ml/models/get"; + + private MLModelGetAction() { super(NAME, MLModelGetResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java new file mode 100644 index 0000000000..dd05f06b45 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.InputStreamStreamInput; +import org.opensearch.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLModelGetRequest extends ActionRequest { + + String modelId; + + @Builder + public MLModelGetRequest(String modelId) { + this.modelId = modelId; + } + + public MLModelGetRequest(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.modelId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.modelId == null) { + exception = addValidationError("ML model id can't be null", exception); + } + + return exception; + } + + public static MLModelGetRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLModelGetRequest) { + return (MLModelGetRequest)actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLModelGetRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLModelGetRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java new file mode 100644 index 0000000000..eda09e77df --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.Builder; +import org.opensearch.action.ActionResponse; +import org.opensearch.common.io.stream.InputStreamStreamInput; +import org.opensearch.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.ToXContentObject; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.MLModel; +import org.opensearch.ml.common.parameter.MLOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +public class MLModelGetResponse extends ActionResponse implements ToXContentObject { + + MLModel mlModel; + + @Builder + public MLModelGetResponse(MLModel mlModel) { + this.mlModel = mlModel; + } + + + public MLModelGetResponse(StreamInput in) throws IOException { + super(in); + mlModel = mlModel.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException{ + mlModel.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return mlModel.toXContent(xContentBuilder, params); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/TestHelper.java b/common/src/test/java/org/opensearch/ml/common/TestHelper.java index 19237e3793..3a843e7cdb 100644 --- a/common/src/test/java/org/opensearch/ml/common/TestHelper.java +++ b/common/src/test/java/org/opensearch/ml/common/TestHelper.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common; import org.opensearch.common.Strings; +import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.ToXContent; @@ -43,4 +44,8 @@ public static void testParseFromString(ToXContentObject obj, String jsonStr, T parsedObj = function.apply(parser); obj.equals(parsedObj); } + + public static String xContentBuilderToString(XContentBuilder builder) { + return BytesReference.bytes(builder).utf8ToString(); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/parameter/MLModelTests.java b/common/src/test/java/org/opensearch/ml/common/parameter/MLModelTests.java new file mode 100644 index 0000000000..d8141a743f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/parameter/MLModelTests.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.parameter; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; + +public class MLModelTests { + + MLModel mlModel; + @Before + public void setUp() { + FunctionName algorithm = FunctionName.KMEANS; + User user = new User(); + mlModel = MLModel.builder() + .name("some model") + .algorithm(algorithm) + .version(1) + .content("some content") + .user(user) + .build(); + } + + @Test + public void toXContent() throws IOException { + MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("model_name").version(1).content("test_content").build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlModel.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"version\":1,\"content\":\"test_content\"}", mlModelContent); + } + + @Test + public void toXContent_NullValue() throws IOException { + MLModel mlModel = MLModel.builder().build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlModel.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{}", mlModelContent); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(mlModel); + } + + public void readInputStream(MLModel mlModel) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlModel.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLModel parsedMLModel = new MLModel(streamInput); + assertEquals(mlModel.getName(), parsedMLModel.getName()); + assertEquals(mlModel.getAlgorithm(), parsedMLModel.getAlgorithm()); + assertEquals(mlModel.getVersion(), parsedMLModel.getVersion()); + assertEquals(mlModel.getContent(), parsedMLModel.getContent()); + assertEquals(mlModel.getUser(), parsedMLModel.getUser()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java new file mode 100644 index 0000000000..94def0df1d --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +public class MLModelGetRequestTest { + private String modelId; + + @Before + public void setUp() { + modelId = "test_id"; + } + + @Test + public void writeTo_Success() throws IOException { + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder() + .modelId(modelId).build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlModelGetRequest.writeTo(bytesStreamOutput); + MLModelGetRequest parsedModel = new MLModelGetRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedModel.getModelId(), modelId); + } + + @Test + public void validate_Exception_NullModelId() { + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().build(); + + ActionRequestValidationException exception = mlModelGetRequest.validate(); + assertEquals("Validation Failed: 1: ML model id can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder() + .modelId(modelId).build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlModelGetRequest.writeTo(out); + } + }; + MLModelGetRequest result = MLModelGetRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlModelGetRequest); + assertEquals(result.getModelId(), mlModelGetRequest.getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLModelGetRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java new file mode 100644 index 0000000000..171ed76337 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java @@ -0,0 +1,62 @@ +package org.opensearch.ml.common.transport.model; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.MLModel; + +import java.io.IOException; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; + +public class MLModelGetResponseTest { + + MLModel mlModel; + + @Before + public void setUp() { + mlModel = MLModel.builder() + .name("model") + .algorithm(FunctionName.KMEANS) + .version(1) + .content("content") + .user(new User()) + .build(); + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + MLModelGetResponse response = MLModelGetResponse.builder().mlModel(mlModel).build(); + response.writeTo(bytesStreamOutput); + MLModelGetResponse parsedResponse = new MLModelGetResponse(bytesStreamOutput.bytes().streamInput()); + assertNotEquals(response.mlModel, parsedResponse.mlModel); + assertEquals(response.mlModel.getName(), parsedResponse.mlModel.getName()); + assertEquals(response.mlModel.getAlgorithm(), parsedResponse.mlModel.getAlgorithm()); + assertEquals(response.mlModel.getVersion(), parsedResponse.mlModel.getVersion()); + assertEquals(response.mlModel.getContent(), parsedResponse.mlModel.getContent()); + assertEquals(response.mlModel.getUser(), parsedResponse.mlModel.getUser()); + } + + @Test + public void toXContentTest() throws IOException { + MLModelGetResponse mlModelGetResponse = MLModelGetResponse.builder().mlModel(mlModel).build(); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + mlModelGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = Strings.toString(builder); + assertEquals("{\"name\":\"model\"," + + "\"algorithm\":\"KMEANS\"," + + "\"version\":1," + + "\"content\":\"content\"," + + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}}", jsonStr); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index 6c56ff340f..257fb3e269 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -10,6 +10,7 @@ import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLOutput; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.common.parameter.Output; /** diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java index 764ab930a8..db1ac1eef0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java @@ -7,7 +7,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.parameter.MLOutput; - +import org.opensearch.ml.common.parameter.Model; /** * This is machine learning algorithms predict interface. diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Trainable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Trainable.java index 7cb994ae31..c8b8d49f84 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Trainable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Trainable.java @@ -6,7 +6,7 @@ package org.opensearch.ml.engine; import org.opensearch.ml.common.dataframe.DataFrame; - +import org.opensearch.ml.common.parameter.Model; /** * This is machine learning algorithms train interface. diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java index 58c7a37fab..e44747a686 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java @@ -12,7 +12,7 @@ import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.MLPredictionOutput; -import org.opensearch.ml.engine.Model; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.engine.TrainAndPredictable; import org.opensearch.ml.engine.annotation.Function; import org.opensearch.ml.engine.utils.ModelSerDeSer; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java index 14a8554efd..a5f0e3469f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java @@ -12,7 +12,7 @@ import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.MLPredictionOutput; -import org.opensearch.ml.engine.Model; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.engine.Predictable; import org.opensearch.ml.engine.Trainable; import org.opensearch.ml.engine.annotation.Function; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java index 8c6b2f4efa..c6799e8f0b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java @@ -11,7 +11,7 @@ import org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.SampleAlgoOutput; import org.opensearch.ml.common.parameter.SampleAlgoParams; -import org.opensearch.ml.engine.Model; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.engine.Predictable; import org.opensearch.ml.engine.Trainable; import org.opensearch.ml.engine.annotation.Function; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 27e20802ba..0757379115 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -20,6 +20,7 @@ import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.common.parameter.MLPredictionOutput; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java index 6026172dc3..47034d3dfc 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java @@ -9,6 +9,7 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.ml.common.parameter.KMeansParams; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.engine.algorithms.clustering.KMeans; import org.opensearch.ml.engine.exceptions.ModelSerDeSerException; import org.opensearch.ml.engine.utils.ModelSerDeSer; @@ -40,8 +41,8 @@ public void testModelSerDeSerKMeans() { KMeans kMeans = new KMeans(params); Model model = kMeans.train(constructKMeansDataFrame(100)); - KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.content); + KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent()); byte[] serializedModel = ModelSerDeSer.serialize(kMeansModel); - assertFalse(Arrays.equals(serializedModel, model.content)); + assertFalse(Arrays.equals(serializedModel, model.getContent())); } } \ No newline at end of file diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/clustering/KMeansTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/clustering/KMeansTest.java index cd277d00c9..91f98a5bf3 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/clustering/KMeansTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/clustering/KMeansTest.java @@ -12,7 +12,7 @@ import org.opensearch.ml.common.parameter.KMeansParams; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLPredictionOutput; -import org.opensearch.ml.engine.Model; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.engine.algorithms.clustering.KMeans; import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/regression/LinearRegressionTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/regression/LinearRegressionTest.java index 2186406c50..87727204a8 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/regression/LinearRegressionTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/regression/LinearRegressionTest.java @@ -14,7 +14,7 @@ import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.LinearRegressionParams; import org.opensearch.ml.common.parameter.MLPredictionOutput; -import org.opensearch.ml.engine.Model; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.engine.algorithms.regression.LinearRegression; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame; diff --git a/plugin/build.gradle b/plugin/build.gradle index e9d28133ed..3d8e88b99b 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -205,7 +205,8 @@ List jacocoExclusions = [ 'org.opensearch.ml.utils.RestActionUtils', 'org.opensearch.ml.task.MLTaskCache', 'org.opensearch.ml.task.MLTaskManager', - 'org.opensearch.ml.task.MLTrainAndPredictTaskRunner' + 'org.opensearch.ml.task.MLTrainAndPredictTaskRunner', + 'org.opensearch.ml.rest.RestMLGetModelAction' ] jacocoTestCoverageVerification { diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java new file mode 100644 index 0000000000..55c858f50b --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.parameter.MLModel; +import org.opensearch.ml.common.transport.model.MLModelGetAction; +import org.opensearch.ml.common.transport.model.MLModelGetRequest; +import org.opensearch.ml.common.transport.model.MLModelGetResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class GetModelTransportAction extends HandledTransportAction { + + TransportService transportService; + Client client; + NamedXContentRegistry xContentRegistry; + + @Inject + public GetModelTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(MLModelGetAction.NAME, transportService, actionFilters, MLModelGetRequest::new); + this.transportService = transportService; + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.fromActionRequest(request); + String modelId = mlModelGetRequest.getModelId(); + GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId); + + client.get(getRequest, ActionListener.wrap(r -> { + log.info("Completed Get Model Request, id:{}", modelId); + + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModel mlModel = MLModel.parse(parser, r.getId()); + actionListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); + } catch (Exception e) { + log.error("Failed to parse ml model" + r.getId(), e); + actionListener.onFailure(e); + } + } else { + actionListener.onFailure(new MLResourceNotFoundException("Fail to find model " + modelId)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + actionListener.onFailure(new MLResourceNotFoundException("Fail to find model " + modelId)); + } else { + log.error("Failed to get ML model " + modelId, e); + actionListener.onFailure(e); + } + })); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModel.java b/plugin/src/main/java/org/opensearch/ml/model/MLModel.java deleted file mode 100644 index dd8f649f83..0000000000 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModel.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.model; - -import java.io.IOException; -import java.util.Base64; - -import lombok.Builder; -import lombok.Getter; - -import org.opensearch.common.xcontent.ToXContentObject; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.commons.authuser.User; -import org.opensearch.ml.common.parameter.FunctionName; -import org.opensearch.ml.engine.Model; - -@Getter -public class MLModel implements ToXContentObject { - public static final String ALGORITHM = "algorithm"; - public static final String MODEL_NAME = "name"; - public static final String MODEL_VERSION = "version"; - public static final String MODEL_CONTENT = "content"; - public static final String USER = "user"; - - private String name; - private FunctionName algorithm; - private Integer version; - private String content; - private User user; - - @Builder - public MLModel(String name, FunctionName algorithm, Integer version, String content, User user) { - this.name = name; - this.algorithm = algorithm; - this.version = version; - this.content = content; - this.user = user; - } - - public MLModel(FunctionName algorithm, Model model) { - this(model.getName(), algorithm, model.getVersion(), Base64.getEncoder().encodeToString(model.getContent()), null); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (name != null) { - builder.field(MODEL_NAME, name); - } - if (algorithm != null) { - builder.field(ALGORITHM, algorithm); - } - if (version != null) { - builder.field(MODEL_VERSION, version); - } - if (content != null) { - builder.field(MODEL_CONTENT, content); - } - if (user != null) { - builder.field(USER, user); - } - builder.endObject(); - return builder; - } - -} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index cc6a712213..7d3fbacbde 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -28,6 +28,7 @@ import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; import org.opensearch.ml.action.execute.TransportExecuteTaskAction; +import org.opensearch.ml.action.models.GetModelTransportAction; import org.opensearch.ml.action.prediction.TransportPredictionTaskAction; import org.opensearch.ml.action.stats.MLStatsNodesAction; import org.opensearch.ml.action.stats.MLStatsNodesTransportAction; @@ -39,6 +40,7 @@ import org.opensearch.ml.common.parameter.LocalSampleCalculatorInput; import org.opensearch.ml.common.parameter.SampleAlgoParams; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; +import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; @@ -49,6 +51,7 @@ import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.indices.MLInputDatasetHandler; import org.opensearch.ml.rest.RestMLExecuteAction; +import org.opensearch.ml.rest.RestMLGetModelAction; import org.opensearch.ml.rest.RestMLPredictionAction; import org.opensearch.ml.rest.RestMLTrainAndPredictAction; import org.opensearch.ml.rest.RestMLTrainingAction; @@ -111,7 +114,8 @@ public Setting legacySetting() { new ActionHandler<>(MLExecuteTaskAction.INSTANCE, TransportExecuteTaskAction.class), new ActionHandler<>(MLPredictionTaskAction.INSTANCE, TransportPredictionTaskAction.class), new ActionHandler<>(MLTrainingTaskAction.INSTANCE, TransportTrainingTaskAction.class), - new ActionHandler<>(MLTrainAndPredictionTaskAction.INSTANCE, TransportTrainAndPredictionTaskAction.class) + new ActionHandler<>(MLTrainAndPredictionTaskAction.INSTANCE, TransportTrainAndPredictionTaskAction.class), + new ActionHandler<>(MLModelGetAction.INSTANCE, GetModelTransportAction.class) ); } @@ -218,8 +222,16 @@ public List getRestHandlers( RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction(); RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(); RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(); + RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(); return ImmutableList - .of(restStatsMLAction, restMLTrainingAction, restMLPredictionAction, restMLExecuteAction, restMLTrainAndPredictAction); + .of( + restStatsMLAction, + restMLTrainingAction, + restMLPredictionAction, + restMLExecuteAction, + restMLTrainAndPredictAction, + restMLGetModelAction + ); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java new file mode 100644 index 0000000000..dd6232833b --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.*; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.model.MLModelGetAction; +import org.opensearch.ml.common.transport.model.MLModelGetRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetModelAction extends BaseRestHandler { + private static final String ML_GET_MODEL_ACTION = "ml_get_model_action"; + + /** + * Constructor + */ + public RestMLGetModelAction() {} + + @Override + public String getName() { + return ML_GET_MODEL_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/models/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLModelGetRequest mlModelGetRequest = getRequest(request); + return channel -> client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLModelGetRequest from a RestRequest + * + * @param request RestRequest + * @return MLModelGetRequest + */ + @VisibleForTesting + MLModelGetRequest getRequest(RestRequest request) throws IOException { + String modelId = getModelId(request); + + return new MLModelGetRequest(modelId); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 23fabefe81..4afb6b68fd 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -31,15 +31,15 @@ import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLModel; import org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.engine.Model; import org.opensearch.ml.indices.MLInputDatasetHandler; -import org.opensearch.ml.model.MLModel; import org.opensearch.ml.model.MLTask; import org.opensearch.ml.model.MLTaskState; import org.opensearch.ml.model.MLTaskType; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index e7be47da26..e07fbb0f8c 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -28,15 +28,15 @@ import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLModel; import org.opensearch.ml.common.parameter.MLTrainingOutput; +import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.engine.Model; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.indices.MLInputDatasetHandler; -import org.opensearch.ml.model.MLModel; import org.opensearch.ml.model.MLTask; import org.opensearch.ml.model.MLTaskState; import org.opensearch.ml.model.MLTaskType; @@ -184,7 +184,7 @@ private void train(MLTask mlTask, MLInput mlInput, ActionListener { - log.info("mode data indexing done, result:{}, model id: {}", r.getResult(), r.getId()); + log.info("Model data indexing done, result:{}, model id: {}", r.getResult(), r.getId()); handleMLTaskComplete(mlTask); MLTrainingOutput output = new MLTrainingOutput(r.getId(), mlTask.getTaskId(), MLTaskState.COMPLETED.name()); listener.onResponse(MLTaskResponse.builder().output(output).build()); diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java index 0bef7d1fe7..54c7fc5b7b 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java @@ -5,9 +5,13 @@ package org.opensearch.ml.utils; +import java.io.IOException; + import lombok.experimental.UtilityClass; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.xcontent.*; import org.opensearch.ml.plugin.MachineLearningPlugin; @UtilityClass @@ -15,4 +19,9 @@ public class MLNodeUtils { public boolean isMLNode(DiscoveryNode node) { return node.getRoles().stream().anyMatch(role -> role.roleName().equalsIgnoreCase(MachineLearningPlugin.ML_ROLE.roleName())); } + + public static XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference) + throws IOException { + return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelTests.java deleted file mode 100644 index 966dc51ccf..0000000000 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelTests.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.model; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; - -import java.io.IOException; - -import org.junit.Test; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.ml.common.parameter.FunctionName; -import org.opensearch.ml.utils.TestHelper; - -public class MLModelTests { - - @Test - public void toXContent() throws IOException { - MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("model_name").version(1).content("test_content").build(); - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mlModel.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"version\":1,\"content\":\"test_content\"}", mlModelContent); - } - - @Test - public void toXContent_NullValue() throws IOException { - MLModel mlModel = MLModel.builder().build(); - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mlModel.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{}", mlModelContent); - } -}