Skip to content

Commit

Permalink
add register action request/response (#1769) (#1780)
Browse files Browse the repository at this point in the history
* add register action request/response

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* add ut for MLUndeployModelsResponse

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* add more ut

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
(cherry picked from commit 94c5d21)

Co-authored-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and ylwu-amzn authored Dec 18, 2023
1 parent 4b8a13d commit 581dec4
Show file tree
Hide file tree
Showing 15 changed files with 599 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -19,7 +20,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class LLMSpec implements ToXContentObject {
public static final String MODEL_ID_FIELD = "model_id";
Expand Down
46 changes: 27 additions & 19 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -26,7 +27,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class MLAgent implements ToXContentObject, Writeable {
public static final String AGENT_NAME_FIELD = "name";
Expand Down Expand Up @@ -64,9 +65,6 @@ public MLAgent(String name,
Instant createdTime,
Instant lastUpdateTime,
String appType) {
if (name == null) {
throw new IllegalArgumentException("agent name is null");
}
this.name = name;
this.type = type;
this.description = description;
Expand All @@ -77,6 +75,24 @@ public MLAgent(String name,
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
this.appType = appType;
validate();
}

private void validate() {
if (name == null) {
throw new IllegalArgumentException("agent name is null");
}
Set<String> toolNames = new HashSet<>();
if (tools != null) {
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Duplicate tool defined: " + toolName);
} else {
toolNames.add(toolName);
}
}
}
}

public MLAgent(StreamInput input) throws IOException{
Expand All @@ -99,18 +115,10 @@ public MLAgent(StreamInput input) throws IOException{
if (input.readBoolean()) {
memory = new MLMemorySpec(input);
}
createdTime = input.readInstant();
lastUpdateTime = input.readInstant();
appType = input.readString();
if (!"flow".equals(type)) {
Set<String> toolNames = new HashSet<>();
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName);
}
}
}
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
appType = input.readOptionalString();
validate();
}

public void writeTo(StreamOutput out) throws IOException {
Expand Down Expand Up @@ -144,9 +152,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeInstant(createdTime);
out.writeInstant(lastUpdateTime);
out.writeString(appType);
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalString(appType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -18,7 +19,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;


@EqualsAndHashCode
@Getter
public class MLMemorySpec implements ToXContentObject {
public static final String MEMORY_TYPE_FIELD = "type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -19,7 +20,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class MLToolSpec implements ToXContentObject {
public static final String TOOL_TYPE_FIELD = "type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.transport.agent;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
Expand All @@ -20,6 +21,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLAgentGetResponse extends ActionResponse implements ToXContentObject {
MLAgent mlAgent;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import org.opensearch.action.ActionType;

public class MLRegisterAgentAction extends ActionType<MLRegisterAgentResponse> {
public static MLRegisterAgentAction INSTANCE = new MLRegisterAgentAction();
public static final String NAME = "cluster:admin/opensearch/ml/agents/register";

private MLRegisterAgentAction() {
super(NAME, MLRegisterAgentResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.agent.MLAgent;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLRegisterAgentRequest extends ActionRequest {

MLAgent mlAgent;

@Builder
public MLRegisterAgentRequest(MLAgent mlAgent) {
this.mlAgent = mlAgent;
}

public MLRegisterAgentRequest(StreamInput in) throws IOException {
super(in);
this.mlAgent = new MLAgent(in);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (mlAgent == null) {
exception = addValidationError("ML agent can't be null", exception);
}

return exception;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.mlAgent.writeTo(out);
}

public static MLRegisterAgentRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLRegisterAgentRequest) {
return (MLRegisterAgentRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterAgentRequest", e);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject {
public static final String AGENT_ID_FIELD = "agent_id";

private String agentId;

public MLRegisterAgentResponse(StreamInput in) throws IOException {
super(in);
this.agentId = in.readString();
}

public MLRegisterAgentResponse(String agentId) {
this.agentId= agentId;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(agentId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(AGENT_ID_FIELD, agentId);
builder.endObject();
return builder;
}

public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLRegisterAgentResponse) {
return (MLRegisterAgentResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionResponse into MLRegisterAgentResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLUndeployModelsResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -49,4 +54,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
return builder;
}

public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLUndeployModelsResponse) {
return (MLUndeployModelsResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUndeployModelsResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionResponse into MLUndeployModelsResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ public void constructor_NullName() {
MLAgent agent = new MLAgent(null, "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test");
}

@Test
public void constructor_DuplicateTool() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Duplicate tool defined: test_tool_name");
MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false);
MLAgent agent = new MLAgent("test_name", "test_type", "test_description", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(mlToolSpec, mlToolSpec), null, null, Instant.EPOCH, Instant.EPOCH, "test");
}

@Test
public void writeTo() throws IOException {
MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test");
Expand Down
Loading

0 comments on commit 581dec4

Please sign in to comment.