Skip to content

Commit

Permalink
add ML model get API
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt committed Feb 2, 2022
1 parent b9efc53 commit 7c45ab5
Show file tree
Hide file tree
Showing 29 changed files with 700 additions and 124 deletions.
2 changes: 2 additions & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dependencies {
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
compile group: 'org.reflections', name: 'reflections', version: '0.9.12'
testCompile group: 'junit', name: 'junit', version: '4.12'
compile "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
compile "org.opensearch:common-utils:${common_utils_version}"
}

jacocoTestReport {
Expand Down
141 changes: 141 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.parameter;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.commons.authuser.User;

import java.io.IOException;
import java.util.Base64;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

@Getter
public class MLModel implements ToXContentObject {
public static final String ALGORITHM = "algorithm";
public static final String MODEL_NAME = "name";
public static final String MODEL_VERSION = "version";
public static final String MODEL_CONTENT = "content";
public static final String USER = "user";

private String name;
private FunctionName algorithm;
private Integer version;
private String content;
private User user;

@Builder
public MLModel(String name, FunctionName algorithm, Integer version, String content, User user) {
this.name = name;
this.algorithm = algorithm;
this.version = version;
this.content = content;
this.user = user;
}

public MLModel(FunctionName algorithm, Model model) {
this(model.getName(), algorithm, model.getVersion(), Base64.getEncoder().encodeToString(model.getContent()), null);
}

public MLModel(StreamInput input) throws IOException{
name = input.readOptionalString();
algorithm = input.readEnum(FunctionName.class);
version = input.readInt();
content = input.readOptionalString();
if (input.readBoolean()) {
this.user = new User(input);
} else {
user = null;
}
}

public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(name);
out.writeEnum(algorithm);
out.writeInt(version);
out.writeOptionalString(content);
if (user != null) {
out.writeBoolean(true); // user exists
user.writeTo(out);
} else {
out.writeBoolean(false); // user does not exist
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (name != null) {
builder.field(MODEL_NAME, name);
}
if (algorithm != null) {
builder.field(ALGORITHM, algorithm);
}
if (version != null) {
builder.field(MODEL_VERSION, version);
}
if (content != null) {
builder.field(MODEL_CONTENT, content);
}
if (user != null) {
builder.field(USER, user);
}
builder.endObject();
return builder;
}

public static MLModel parse(XContentParser parser, String taskId) throws IOException {
String name = null;
FunctionName algorithm = null;
Integer version = null;
String content = null;
User user = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case MODEL_NAME:
name = parser.text();
break;
case MODEL_CONTENT:
content = parser.text();
break;
case MODEL_VERSION:
version = parser.intValue();
break;
case USER:
user = User.parse(parser);
break;
case ALGORITHM:
algorithm = FunctionName.valueOf(parser.text());
break;
default:
break;
}
}
return MLModel.builder()
.name(name)
.algorithm(algorithm)
.version(version)
.content(content)
.user(user)
.build();
}

public static MLModel fromStream(StreamInput in) throws IOException {
MLModel mlModel = new MLModel(in);
return mlModel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine;
package org.opensearch.ml.common.parameter;

import lombok.Data;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import org.opensearch.action.ActionType;

public class MLModelGetAction extends ActionType<MLModelGetResponse> {
public static final MLModelGetAction INSTANCE = new MLModelGetAction();
public static final String NAME = "cluster:admin/opensearch/ml/models/get";

private MLModelGetAction() { super(NAME, MLModelGetResponse::new);}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.InputStreamStreamInput;
import org.opensearch.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLModelGetRequest extends ActionRequest {

String modelId;

@Builder
public MLModelGetRequest(String modelId) {
this.modelId = modelId;
}

public MLModelGetRequest(StreamInput in) throws IOException {
super(in);
this.modelId = in.readOptionalString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(this.modelId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.modelId == null) {
exception = addValidationError("ML model id can't be null", exception);
}

return exception;
}

public static MLModelGetRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLModelGetRequest) {
return (MLModelGetRequest)actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLModelGetRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLModelGetRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.Builder;
import org.opensearch.action.ActionResponse;
import org.opensearch.common.io.stream.InputStreamStreamInput;
import org.opensearch.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLModel;
import org.opensearch.ml.common.parameter.MLOutput;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

public class MLModelGetResponse extends ActionResponse implements ToXContentObject {

MLModel mlModel;

@Builder
public MLModelGetResponse(MLModel mlModel) {
this.mlModel = mlModel;
}


public MLModelGetResponse(StreamInput in) throws IOException {
super(in);
mlModel = mlModel.fromStream(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException{
mlModel.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
return mlModel.toXContent(xContentBuilder, params);
}
}
5 changes: 5 additions & 0 deletions common/src/test/java/org/opensearch/ml/common/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common;

import org.opensearch.common.Strings;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.ToXContent;
Expand Down Expand Up @@ -43,4 +44,8 @@ public static <T> void testParseFromString(ToXContentObject obj, String jsonStr,
T parsedObj = function.apply(parser);
obj.equals(parsedObj);
}

public static String xContentBuilderToString(XContentBuilder builder) {
return BytesReference.bytes(builder).utf8ToString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.parameter;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;

import static org.junit.Assert.assertEquals;
import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;

public class MLModelTests {

MLModel mlModel;
@Before
public void setUp() {
FunctionName algorithm = FunctionName.KMEANS;
User user = new User();
mlModel = MLModel.builder()
.name("some model")
.algorithm(algorithm)
.version(1)
.content("some content")
.user(user)
.build();
}

@Test
public void toXContent() throws IOException {
MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("model_name").version(1).content("test_content").build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"version\":1,\"content\":\"test_content\"}", mlModelContent);
}

@Test
public void toXContent_NullValue() throws IOException {
MLModel mlModel = MLModel.builder().build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{}", mlModelContent);
}

@Test
public void readInputStream_Success() throws IOException {
readInputStream(mlModel);
}

public void readInputStream(MLModel mlModel) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
mlModel.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
MLModel parsedMLModel = new MLModel(streamInput);
assertEquals(mlModel.getName(), parsedMLModel.getName());
assertEquals(mlModel.getAlgorithm(), parsedMLModel.getAlgorithm());
assertEquals(mlModel.getVersion(), parsedMLModel.getVersion());
assertEquals(mlModel.getContent(), parsedMLModel.getContent());
assertEquals(mlModel.getUser(), parsedMLModel.getUser());
}
}
Loading

0 comments on commit 7c45ab5

Please sign in to comment.