-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Xun Zhang <xunzh@amazon.com>
- Loading branch information
1 parent
b9efc53
commit 2cea375
Showing
29 changed files
with
696 additions
and
125 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
142 changes: 142 additions & 0 deletions
142
common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.parameter; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import org.opensearch.common.io.stream.StreamInput; | ||
import org.opensearch.common.io.stream.StreamOutput; | ||
import org.opensearch.common.xcontent.ToXContentObject; | ||
import org.opensearch.common.xcontent.XContentBuilder; | ||
import org.opensearch.common.xcontent.XContentParser; | ||
import org.opensearch.commons.authuser.User; | ||
|
||
import java.io.IOException; | ||
import java.util.Base64; | ||
|
||
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
|
||
@Getter | ||
public class MLModel implements ToXContentObject { | ||
public static final String ALGORITHM = "algorithm"; | ||
public static final String MODEL_NAME = "name"; | ||
public static final String MODEL_VERSION = "version"; | ||
public static final String MODEL_CONTENT = "content"; | ||
public static final String USER = "user"; | ||
|
||
private String name; | ||
private FunctionName algorithm; | ||
private Integer version; | ||
private String content; | ||
private User user; | ||
|
||
@Builder | ||
public MLModel(String name, FunctionName algorithm, Integer version, String content, User user) { | ||
this.name = name; | ||
this.algorithm = algorithm; | ||
this.version = version; | ||
this.content = content; | ||
this.user = user; | ||
} | ||
|
||
public MLModel(FunctionName algorithm, Model model) { | ||
this(model.getName(), algorithm, model.getVersion(), Base64.getEncoder().encodeToString(model.getContent()), null); | ||
} | ||
|
||
public MLModel(StreamInput input) throws IOException{ | ||
name = input.readOptionalString(); | ||
algorithm = input.readEnum(FunctionName.class); | ||
version = input.readInt(); | ||
content = input.readOptionalString(); | ||
if (input.readBoolean()) { | ||
this.user = new User(input); | ||
} else { | ||
user = null; | ||
} | ||
} | ||
|
||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeOptionalString(name); | ||
out.writeEnum(algorithm); | ||
out.writeInt(version); | ||
out.writeOptionalString(content); | ||
if (user != null) { | ||
out.writeBoolean(true); // user exists | ||
user.writeTo(out); | ||
} else { | ||
out.writeBoolean(false); // user does not exist | ||
} | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
if (name != null) { | ||
builder.field(MODEL_NAME, name); | ||
} | ||
if (algorithm != null) { | ||
builder.field(ALGORITHM, algorithm); | ||
} | ||
if (version != null) { | ||
builder.field(MODEL_VERSION, version); | ||
} | ||
if (content != null) { | ||
builder.field(MODEL_CONTENT, content); | ||
} | ||
if (user != null) { | ||
builder.field(USER, user); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
public static MLModel parse(XContentParser parser, String taskId) throws IOException { | ||
String name = null; | ||
FunctionName algorithm = null; | ||
Integer version = null; | ||
String content = null; | ||
User user = null; | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
|
||
switch (fieldName) { | ||
case MODEL_NAME: | ||
name = parser.text(); | ||
break; | ||
case MODEL_CONTENT: | ||
content = parser.text(); | ||
break; | ||
case MODEL_VERSION: | ||
version = parser.intValue(); | ||
break; | ||
case USER: | ||
user = User.parse(parser); | ||
break; | ||
case ALGORITHM: | ||
algorithm = FunctionName.valueOf(parser.text()); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
return MLModel.builder() | ||
.name(name) | ||
.algorithm(algorithm) | ||
.version(version) | ||
.content(content) | ||
.user(user) | ||
.build(); | ||
} | ||
|
||
public static MLModel fromStream(StreamInput in) throws IOException { | ||
MLModel mlModel = new MLModel(in); | ||
return mlModel; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
15 changes: 15 additions & 0 deletions
15
common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import org.opensearch.action.ActionType; | ||
|
||
public class MLModelGetAction extends ActionType<MLModelGetResponse> { | ||
public static final MLModelGetAction INSTANCE = new MLModelGetAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/models/get"; | ||
|
||
private MLModelGetAction() { super(NAME, MLModelGetResponse::new);} | ||
} |
76 changes: 76 additions & 0 deletions
76
common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.ToString; | ||
import lombok.experimental.FieldDefaults; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.common.io.stream.StreamInput; | ||
import org.opensearch.common.io.stream.StreamOutput; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
@Getter | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
@ToString | ||
public class MLModelGetRequest extends ActionRequest { | ||
|
||
String modelId; | ||
|
||
@Builder | ||
public MLModelGetRequest(String modelId) { | ||
this.modelId = modelId; | ||
} | ||
|
||
public MLModelGetRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.modelId = in.readString(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.modelId); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
|
||
if (this.modelId == null) { | ||
exception = addValidationError("ML model id can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
public static MLModelGetRequest fromActionRequest(ActionRequest actionRequest) { | ||
if (actionRequest instanceof MLModelGetRequest) { | ||
return (MLModelGetRequest)actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLModelGetRequest(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionRequest into MLModelGetRequest", e); | ||
} | ||
} | ||
} |
49 changes: 49 additions & 0 deletions
49
common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import lombok.Builder; | ||
import org.opensearch.action.ActionResponse; | ||
import org.opensearch.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.common.io.stream.StreamInput; | ||
import org.opensearch.common.io.stream.StreamOutput; | ||
import org.opensearch.common.xcontent.ToXContentObject; | ||
import org.opensearch.common.xcontent.XContentBuilder; | ||
import org.opensearch.ml.common.parameter.FunctionName; | ||
import org.opensearch.ml.common.parameter.MLModel; | ||
import org.opensearch.ml.common.parameter.MLOutput; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
public class MLModelGetResponse extends ActionResponse implements ToXContentObject { | ||
|
||
MLModel mlModel; | ||
|
||
@Builder | ||
public MLModelGetResponse(MLModel mlModel) { | ||
this.mlModel = mlModel; | ||
} | ||
|
||
|
||
public MLModelGetResponse(StreamInput in) throws IOException { | ||
super(in); | ||
mlModel = mlModel.fromStream(in); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException{ | ||
mlModel.writeTo(out); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { | ||
return mlModel.toXContent(xContentBuilder, params); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
73 changes: 73 additions & 0 deletions
73
common/src/test/java/org/opensearch/ml/common/parameter/MLModelTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.parameter; | ||
|
||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.opensearch.common.io.stream.BytesStreamOutput; | ||
import org.opensearch.common.io.stream.StreamInput; | ||
import org.opensearch.common.xcontent.XContentBuilder; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.commons.authuser.User; | ||
import org.opensearch.ml.common.TestHelper; | ||
|
||
import java.io.IOException; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; | ||
|
||
public class MLModelTests { | ||
|
||
MLModel mlModel; | ||
@Before | ||
public void setUp() { | ||
FunctionName algorithm = FunctionName.KMEANS; | ||
User user = new User(); | ||
mlModel = MLModel.builder() | ||
.name("some model") | ||
.algorithm(algorithm) | ||
.version(1) | ||
.content("some content") | ||
.user(user) | ||
.build(); | ||
} | ||
|
||
@Test | ||
public void toXContent() throws IOException { | ||
MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("model_name").version(1).content("test_content").build(); | ||
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); | ||
mlModel.toXContent(builder, EMPTY_PARAMS); | ||
String mlModelContent = TestHelper.xContentBuilderToString(builder); | ||
assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"version\":1,\"content\":\"test_content\"}", mlModelContent); | ||
} | ||
|
||
@Test | ||
public void toXContent_NullValue() throws IOException { | ||
MLModel mlModel = MLModel.builder().build(); | ||
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); | ||
mlModel.toXContent(builder, EMPTY_PARAMS); | ||
String mlModelContent = TestHelper.xContentBuilderToString(builder); | ||
assertEquals("{}", mlModelContent); | ||
} | ||
|
||
@Test | ||
public void readInputStream_Success() throws IOException { | ||
readInputStream(mlModel); | ||
} | ||
|
||
public void readInputStream(MLModel mlModel) throws IOException { | ||
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); | ||
mlModel.writeTo(bytesStreamOutput); | ||
|
||
StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); | ||
MLModel parsedMLModel = new MLModel(streamInput); | ||
assertEquals(mlModel.getName(), parsedMLModel.getName()); | ||
assertEquals(mlModel.getAlgorithm(), parsedMLModel.getAlgorithm()); | ||
assertEquals(mlModel.getVersion(), parsedMLModel.getVersion()); | ||
assertEquals(mlModel.getContent(), parsedMLModel.getContent()); | ||
assertEquals(mlModel.getUser(), parsedMLModel.getUser()); | ||
} | ||
} |
Oops, something went wrong.