diff --git a/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java index 6c0fda289a..561fe81d5f 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.agent; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -19,7 +20,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; - +@EqualsAndHashCode @Getter public class LLMSpec implements ToXContentObject { public static final String MODEL_ID_FIELD = "model_id"; diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index ba2f241375..9033b92afc 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.agent; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -26,7 +27,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; - +@EqualsAndHashCode @Getter public class MLAgent implements ToXContentObject, Writeable { public static final String AGENT_NAME_FIELD = "name"; @@ -64,9 +65,6 @@ public MLAgent(String name, Instant createdTime, Instant lastUpdateTime, String appType) { - if (name == null) { - throw new IllegalArgumentException("agent name is null"); - } this.name = name; this.type = type; this.description = description; @@ -77,6 +75,24 @@ public MLAgent(String name, this.createdTime = createdTime; this.lastUpdateTime = lastUpdateTime; this.appType = appType; + validate(); + } + + private void validate() { + if (name == null) { + throw new IllegalArgumentException("agent name is null"); + } + Set toolNames = new HashSet<>(); + if (tools != null) { + for (MLToolSpec toolSpec : tools) { + String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); + if (toolNames.contains(toolName)) { + throw new IllegalArgumentException("Duplicate tool defined: " + toolName); + } else { + toolNames.add(toolName); + } + } + } } public MLAgent(StreamInput input) throws IOException{ @@ -99,18 +115,10 @@ public MLAgent(StreamInput input) throws IOException{ if (input.readBoolean()) { memory = new MLMemorySpec(input); } - createdTime = input.readInstant(); - lastUpdateTime = input.readInstant(); - appType = input.readString(); - if (!"flow".equals(type)) { - Set toolNames = new HashSet<>(); - for (MLToolSpec toolSpec : tools) { - String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); - if (toolNames.contains(toolName)) { - throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName); - } - } - } + createdTime = input.readOptionalInstant(); + lastUpdateTime = input.readOptionalInstant(); + appType = input.readOptionalString(); + validate(); } public void writeTo(StreamOutput out) throws IOException { @@ -144,9 +152,9 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } - out.writeInstant(createdTime); - out.writeInstant(lastUpdateTime); - out.writeString(appType); + out.writeOptionalInstant(createdTime); + out.writeOptionalInstant(lastUpdateTime); + out.writeOptionalString(appType); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java index aa192a7ee2..5d13d5236c 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.agent; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import org.opensearch.core.common.io.stream.StreamInput; @@ -18,7 +19,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - +@EqualsAndHashCode @Getter public class MLMemorySpec implements ToXContentObject { public static final String MEMORY_TYPE_FIELD = "type"; diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index 055c59d449..7b9b640c8a 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.agent; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -19,7 +20,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; - +@EqualsAndHashCode @Getter public class MLToolSpec implements ToXContentObject { public static final String TOOL_TYPE_FIELD = "type"; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java index a437ef0ed8..593e314b31 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.transport.agent; import lombok.Builder; +import lombok.Getter; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -20,6 +21,7 @@ import java.io.IOException; import java.io.UncheckedIOException; +@Getter public class MLAgentGetResponse extends ActionResponse implements ToXContentObject { MLAgent mlAgent; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentAction.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentAction.java new file mode 100644 index 0000000000..c5d1a1232f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.opensearch.action.ActionType; + +public class MLRegisterAgentAction extends ActionType { + public static MLRegisterAgentAction INSTANCE = new MLRegisterAgentAction(); + public static final String NAME = "cluster:admin/opensearch/ml/agents/register"; + + private MLRegisterAgentAction() { + super(NAME, MLRegisterAgentResponse::new); + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java new file mode 100644 index 0000000000..4add7827d5 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.agent.MLAgent; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLRegisterAgentRequest extends ActionRequest { + + MLAgent mlAgent; + + @Builder + public MLRegisterAgentRequest(MLAgent mlAgent) { + this.mlAgent = mlAgent; + } + + public MLRegisterAgentRequest(StreamInput in) throws IOException { + super(in); + this.mlAgent = new MLAgent(in); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (mlAgent == null) { + exception = addValidationError("ML agent can't be null", exception); + } + + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + this.mlAgent.writeTo(out); + } + + public static MLRegisterAgentRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLRegisterAgentRequest) { + return (MLRegisterAgentRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterAgentRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterAgentRequest", e); + } + + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java new file mode 100644 index 0000000000..7f8b633cbe --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.Getter; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +@Getter +public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject { + public static final String AGENT_ID_FIELD = "agent_id"; + + private String agentId; + + public MLRegisterAgentResponse(StreamInput in) throws IOException { + super(in); + this.agentId = in.readString(); + } + + public MLRegisterAgentResponse(String agentId) { + this.agentId= agentId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(agentId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(AGENT_ID_FIELD, agentId); + builder.endObject(); + return builder; + } + + public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterAgentResponse) { + return (MLRegisterAgentResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterAgentResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLRegisterAgentResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java index 7534b52187..71fc7ef38b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLUndeployModelsResponse extends ActionResponse implements ToXContentObject { @@ -49,4 +54,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } return builder; } + + public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLUndeployModelsResponse) { + return (MLUndeployModelsResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLUndeployModelsResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLUndeployModelsResponse", e); + } + } } diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index bfaec959c4..e00a49aeb6 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -35,6 +35,14 @@ public void constructor_NullName() { MLAgent agent = new MLAgent(null, "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test"); } + @Test + public void constructor_DuplicateTool() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Duplicate tool defined: test_tool_name"); + MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false); + MLAgent agent = new MLAgent("test_name", "test_type", "test_description", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(mlToolSpec, mlToolSpec), null, null, Instant.EPOCH, Instant.EPOCH, "test"); + } + @Test public void writeTo() throws IOException { MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index 7d733a4308..b692ce34ac 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -4,9 +4,11 @@ */ package org.opensearch.ml.common.transport.agent; +import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.*; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -17,40 +19,41 @@ import java.io.*; import java.time.Instant; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; public class MLAgentGetResponseTest { MLAgent mlAgent; + @Before + public void setUp() { + mlAgent = MLAgent.builder() + .name("test_agent") + .appType("test_app") + .type("flow") + .tools(Arrays.asList(MLToolSpec.builder().type("CatIndexTool").build())) + .build(); + } + @Test public void Create_MLAgentResponse_With_StreamInput() throws IOException { // Create a BytesStreamOutput to simulate the StreamOutput - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - - //create a test agent using input - bytesStreamOutput.writeString("Test Agent"); - bytesStreamOutput.writeString("flow"); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z")); - bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z")); - bytesStreamOutput.writeString("test"); - - StreamInput testInputStream = bytesStreamOutput.bytes().streamInput(); - - MLAgentGetResponse mlAgentGetResponse = new MLAgentGetResponse(testInputStream); - MLAgent testMlAgent = mlAgentGetResponse.mlAgent; - assertEquals("flow",testMlAgent.getType()); - assertEquals("Test Agent",testMlAgent.getName()); - assertEquals("test",testMlAgent.getAppType()); + MLAgentGetResponse agentGetResponse = new MLAgentGetResponse(mlAgent); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + agentGetResponse.writeTo(out); + } + }; + MLAgentGetResponse parsedResponse = MLAgentGetResponse.fromActionResponse(actionResponse); + assertNotSame(agentGetResponse, parsedResponse); + assertEquals(agentGetResponse.getMlAgent(), parsedResponse.getMlAgent()); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java new file mode 100644 index 0000000000..aa790d0ccd --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class MLRegisterAgentActionTest { + + @Test + public void actionInstance() { + assertNotNull(MLRegisterAgentAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/agents/register", MLRegisterAgentAction.NAME); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java new file mode 100644 index 0000000000..ee446db82f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLToolSpec; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; + +import static org.junit.Assert.*; + +public class MLRegisterAgentRequestTest { + + MLAgent mlAgent; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + mlAgent = MLAgent.builder() + .name("test_agent") + .appType("test_app") + .type("flow") + .tools(Arrays.asList(MLToolSpec.builder().type("CatIndexTool").build())) + .build(); + } + + @Test + public void constructor_Agent() { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + assertEquals(mlAgent, registerAgentRequest.getMlAgent()); + + ActionRequestValidationException validationException = registerAgentRequest.validate(); + assertNull(validationException); + } + + @Test + public void constructor_NullAgent() { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest((MLAgent) null); + assertNull(registerAgentRequest.getMlAgent()); + + ActionRequestValidationException validationException = registerAgentRequest.validate(); + assertNotNull(validationException); + assertTrue(validationException.toString().contains("ML agent can't be null")); + } + + @Test + public void writeTo_Success() throws IOException { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + registerAgentRequest.writeTo(bytesStreamOutput); + MLRegisterAgentRequest parsedRequest = new MLRegisterAgentRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(mlAgent, parsedRequest.getMlAgent()); + } + + @Test + public void fromActionRequest_Success() { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + registerAgentRequest.writeTo(out); + } + }; + MLRegisterAgentRequest parsedRequest = MLRegisterAgentRequest.fromActionRequest(actionRequest); + assertNotSame(registerAgentRequest, parsedRequest); + assertEquals(registerAgentRequest.getMlAgent(), parsedRequest.getMlAgent()); + } + + @Test + public void fromActionRequest_Success_MLRegisterAgentRequest() { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + MLRegisterAgentRequest parsedRequest = MLRegisterAgentRequest.fromActionRequest(registerAgentRequest); + assertSame(registerAgentRequest, parsedRequest); + } + + @Test + public void fromActionRequest_Exception() { + exceptionRule.expect(UncheckedIOException.class); + exceptionRule.expectMessage("Failed to parse ActionRequest into MLRegisterAgentRequest"); + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLRegisterAgentRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java new file mode 100644 index 0000000000..9997eb0ad6 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.*; + +public class MLRegisterAgentResponseTest { + String agentId; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + agentId = "test_agent_id"; + } + + @Test + public void constructor_AgentId() { + MLRegisterAgentResponse response = new MLRegisterAgentResponse(agentId); + assertEquals(agentId, response.getAgentId()); + } + + @Test + public void writeTo_Success() throws IOException { + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + registerAgentResponse.writeTo(bytesStreamOutput); + MLRegisterAgentResponse parsedResponse = new MLRegisterAgentResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(agentId, parsedResponse.getAgentId()); + } + + @Test + public void toXContent() throws IOException { + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + registerAgentResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals("{\"agent_id\":\"test_agent_id\"}", jsonStr); + } + + @Test + public void fromActionResponse_Success() { + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + ActionResponse actionResponse = new ActionResponse() { + + @Override + public void writeTo(StreamOutput out) throws IOException { + registerAgentResponse.writeTo(out); + } + }; + MLRegisterAgentResponse parsedResponse = MLRegisterAgentResponse.fromActionResponse(actionResponse); + assertNotSame(registerAgentResponse, parsedResponse); + assertEquals(registerAgentResponse.getAgentId(), parsedResponse.getAgentId()); + } + + @Test + public void fromActionResponse_Success_MLRegisterAgentResponse() { + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + MLRegisterAgentResponse parsedResponse = MLRegisterAgentResponse.fromActionResponse(registerAgentResponse); + assertSame(registerAgentResponse, parsedResponse); + } + + @Test + public void fromActionResponse_Exception() { + exceptionRule.expect(UncheckedIOException.class); + exceptionRule.expectMessage("Failed to parse ActionResponse into MLRegisterAgentResponse"); + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + ActionResponse actionResponse = new ActionResponse() { + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLRegisterAgentResponse.fromActionResponse(actionResponse); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java new file mode 100644 index 0000000000..69f12099e9 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.undeploy; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.InetAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.*; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +public class MLUndeployModelsResponseTest { + + MLUndeployModelNodesResponse undeployModelNodesResponse; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + ClusterName clusterName = new ClusterName("clusterName"); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "response"); + DiscoveryNode localNode = new DiscoveryNode( + "test_node_name", + "test_node_id", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[]{"node"}); + MLUndeployModelNodeResponse nodeResponse = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + List nodes = Arrays.asList(nodeResponse); + + List failures = Arrays.asList(); + undeployModelNodesResponse = new MLUndeployModelNodesResponse(clusterName, nodes, failures); + + } + + @Test + public void writeTo_Success() throws IOException { + MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + undeployModelsResponse.writeTo(bytesStreamOutput); + MLUndeployModelsResponse parsedResponse = new MLUndeployModelsResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(1, parsedResponse.getResponse().getNodes().size()); + assertEquals("test_node_id", parsedResponse.getResponse().getNodes().get(0).getNode().getId()); + } + + @Test + public void toXContent() throws IOException { + MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + undeployModelsResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals("{\"test_node_id\":{\"stats\":{\"modelId1\":\"response\"}}}", jsonStr); + } + + @Test + public void fromActionResponse_Success() { + MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + undeployModelsResponse.writeTo(out); + } + }; + MLUndeployModelsResponse parsedResponse = MLUndeployModelsResponse.fromActionResponse(actionResponse); + assertNotSame(undeployModelsResponse, parsedResponse); + assertEquals(1, parsedResponse.getResponse().getNodes().size()); + assertEquals("test_node_id", parsedResponse.getResponse().getNodes().get(0).getNode().getId()); + } + + @Test + public void fromActionResponse_Success_MLUndeployModelsResponse() { + MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse); + MLUndeployModelsResponse parsedResponse = MLUndeployModelsResponse.fromActionResponse(undeployModelsResponse); + assertSame(undeployModelsResponse, parsedResponse); + } + + @Test + public void fromActionResponse_Exception() { + exceptionRule.expect(UncheckedIOException.class); + exceptionRule.expectMessage("Failed to parse ActionResponse into MLUndeployModelsResponse"); + MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLUndeployModelsResponse.fromActionResponse(actionResponse); + } +}