Skip to content

Commit

Permalink
Send original update request back in accept/reject response (#2074)
Browse files Browse the repository at this point in the history
  • Loading branch information
Quinn-With-Two-Ns authored May 21, 2024
1 parent 82d5a88 commit 5ccb859
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ enum State {
private String requestMsgId;
private long requestSeqID;
private Meta meta;
private Optional<Request> originalRequest = Optional.empty();
private String messageId;

public static final StateMachineDefinition<State, ExplicitEvent, UpdateProtocolStateMachine>
Expand Down Expand Up @@ -175,7 +176,8 @@ void triggerUpdate() {
requestMsgId = this.currentMessage.getId();
requestSeqID = this.currentMessage.getEventId();
try {
meta = this.currentMessage.getBody().unpack(Request.class).getMeta();
originalRequest = Optional.of(this.currentMessage.getBody().unpack(Request.class));
meta = originalRequest.get().getMeta();
} catch (InvalidProtocolBufferException e) {
throw new IllegalArgumentException("Current message not an update:" + this.currentMessage);
}
Expand All @@ -199,8 +201,10 @@ public void accept() {
Acceptance.newBuilder()
.setAcceptedRequestMessageId(requestMsgId)
.setAcceptedRequestSequencingEventId(requestSeqID)
.setAcceptedRequest(originalRequest.get())
.build();

// Clear the original request to allow GC to reclaim the memory.
originalRequest = Optional.empty();
messageId = requestMsgId + "/accept";
sendHandle.apply(
Message.newBuilder()
Expand All @@ -217,6 +221,7 @@ public void reject(Failure failure) {
.setRejectedRequestMessageId(requestMsgId)
.setRejectedRequestSequencingEventId(requestSeqID)
.setFailure(failure)
.setRejectedRequest(originalRequest.get())
.build();

String messageId = requestMsgId + "/reject";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import static io.temporal.internal.statemachines.MutableSideEffectStateMachine.*;
import static io.temporal.internal.statemachines.SideEffectStateMachine.SIDE_EFFECT_MARKER_NAME;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
Expand All @@ -34,10 +33,7 @@
import io.temporal.api.enums.v1.EventType;
import io.temporal.api.history.v1.*;
import io.temporal.api.protocol.v1.Message;
import io.temporal.api.update.v1.Input;
import io.temporal.api.update.v1.Meta;
import io.temporal.api.update.v1.Outcome;
import io.temporal.api.update.v1.Request;
import io.temporal.api.update.v1.*;
import io.temporal.common.converter.DataConverter;
import io.temporal.common.converter.DefaultDataConverter;
import io.temporal.internal.common.ProtobufTimeUtils;
Expand Down Expand Up @@ -84,7 +80,7 @@ public static void generateCoverage() {
}

@Test
public void testUpdateAccept() {
public void testUpdateAccept() throws InvalidProtocolBufferException {
class TestUpdateListener extends TestEntityManagerListenerBase {

@Override
Expand Down Expand Up @@ -173,8 +169,31 @@ protected void update(UpdateMessage message, AsyncWorkflowBuilder<Void> builder)
{
TestEntityManagerListenerBase listener = new TestUpdateListener();
stateMachines = newStateMachines(listener);
List<Command> commands = h.handleWorkflowTaskTakeCommands(stateMachines, 0);
assertEquals(0, commands.size());
Request request =
Request.newBuilder()
.setInput(
Input.newBuilder()
.setName("updateName")
.setArgs(converter.toPayloads("arg").get()))
.build();
stateMachines.setMessages(
Collections.unmodifiableList(
Arrays.asList(
new Message[] {
Message.newBuilder()
.setProtocolInstanceId("protocol_id")
.setId("id")
.setEventId(0)
.setBody(Any.pack(request))
.build(),
})));
List<Command> commands = h.handleWorkflowTaskTakeCommands(stateMachines, 1);
assertEquals(2, commands.size());
List<Message> messages = stateMachines.takeMessages();
assertEquals(1, messages.size());
Acceptance acceptance = messages.get(0).getBody().unpack(Acceptance.class);
assertNotNull(acceptance);
assertEquals(request, acceptance.getAcceptedRequest());
}
{
TestEntityManagerListenerBase listener = new TestUpdateListener();
Expand Down Expand Up @@ -369,7 +388,7 @@ protected void update(UpdateMessage message, AsyncWorkflowBuilder<Void> builder)
}

@Test
public void testUpdateRejected() {
public void testUpdateRejected() throws InvalidProtocolBufferException {
class TestUpdateListener extends TestEntityManagerListenerBase {

@Override
Expand Down Expand Up @@ -404,14 +423,13 @@ protected void update(UpdateMessage message, AsyncWorkflowBuilder<Void> builder)
// Full replay
TestEntityManagerListenerBase listener = new TestUpdateListener();
stateMachines = newStateMachines(listener);
Any messageBody =
Any.pack(
Request.newBuilder()
.setInput(
Input.newBuilder()
.setName("updateName")
.setArgs(converter.toPayloads("arg").get()))
.build());
Request request =
Request.newBuilder()
.setInput(
Input.newBuilder()
.setName("updateName")
.setArgs(converter.toPayloads("arg").get()))
.build();
stateMachines.setMessages(
Collections.unmodifiableList(
Arrays.asList(
Expand All @@ -420,11 +438,16 @@ protected void update(UpdateMessage message, AsyncWorkflowBuilder<Void> builder)
.setProtocolInstanceId("protocol_id")
.setId("id")
.setEventId(0)
.setBody(messageBody)
.setBody(Any.pack(request))
.build(),
})));
List<Command> commands = h.handleWorkflowTaskTakeCommands(stateMachines, 1);
assertEquals(0, commands.size());
List<Message> messages = stateMachines.takeMessages();
assertEquals(1, messages.size());
Rejection rejection = messages.get(0).getBody().unpack(Rejection.class);
assertNotNull(rejection);
assertEquals(request, rejection.getRejectedRequest());
}
}

Expand Down

0 comments on commit 5ccb859

Please sign in to comment.