Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix State machine sub-regions do not resume from last state after restore from persistence #811 #998

Open
wants to merge 3 commits into
base: 2.5.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,14 @@
*/
package org.springframework.statemachine.persist;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.statemachine.ExtendedState;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.StateMachineContext;
import org.springframework.statemachine.StateMachineException;
import org.springframework.statemachine.StateMachinePersist;
import org.springframework.statemachine.*;
import org.springframework.statemachine.region.Region;
import org.springframework.statemachine.state.AbstractState;
import org.springframework.statemachine.state.HistoryPseudoState;
import org.springframework.statemachine.state.PseudoState;
import org.springframework.statemachine.state.State;
import org.springframework.statemachine.state.*;
import org.springframework.statemachine.support.AbstractStateMachine;
import org.springframework.statemachine.support.DefaultExtendedState;
import org.springframework.statemachine.support.DefaultStateMachineContext;
Expand All @@ -44,6 +33,8 @@
import org.springframework.statemachine.transition.TransitionKind;
import org.springframework.util.Assert;

import javax.swing.text.html.Option;

/**
* Base class for {@link StateMachineInterceptor} persisting {@link StateMachineContext}s.
* This class is to be used as a base implementation which wants to persist a machine which
Expand Down Expand Up @@ -168,14 +159,14 @@ protected StateMachineContext<S, E> buildStateMachineContext(StateMachine<S, E>
if (state.isSubmachineState()) {
id = getDeepState(state);
} else if (state.isOrthogonal()) {
if (stateMachine.getState().isOrthogonal()) {
//if (stateMachine.getState().isOrthogonal()) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this change, can you explain?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this part of the code, it is important for us to write the child state machine identifiers into the parent context.

When we have a state machine with nested state machines, we need to get a context like this:

DefaultStateMachineContext [
 id=testid
 , childs= [    
       DefaultStateMachineContext [
           id=testid#FIRST
           , childs=[]
           , childRefs=[]
           , state=S21
           , historyStates={}
           , event=E3
           , eventHeaders={id=9ab47504-7da3-1853-8a2d-4a7e667a2def, timestamp=1633085509033}
           , extendedState=DefaultExtendedState [variables={}]]
     , DefaultStateMachineContext [
           id=testid#SECOND
           , childs=[]
           , childRefs=[]
           , state=S31
           , historyStates={}
           , event=E1
           , eventHeaders={id=9638028e-a69e-4adf-92d4-bb4655f4c2a2, timestamp=1633085509039}
           , extendedState=DefaultExtendedState [variables={}]]]
  , childRefs=[]
  , state=S2
  , historyStates={}
  , event=null
  , eventHeaders=null
  , extendedState=DefaultExtendedState [variables={}]] 

This context for the parent statemachine must include child identifiers (in our example, id = testid # FIRST and id = testid # SECOND). Using these identifiers, we will restore the context from the database (from the "state" table by the "machine_id" field) for the child state machines.

In the line of code } else if (state.isOrthogonal ()) { we determine if our target state is ramified.

In the line of code if (stateMachine.getState (). IsOrthogonal ()) { we determine if our source state is ramified.

The fact that the source state and the target state at the same time together looks illogical and looks like a configuration error for the statemachine.

Due to the line if (stateMachine.getState (). IsOrthogonal ()) { child state machine IDs are not saved in the parent context.

For example, consider the testPersistRegionsAndRestore test.

Collection<Region<S, E>> regions = ((AbstractState<S, E>)state).getRegions();
for (Region<S, E> r : regions) {
// realistically we can only add refs because reqions are independent
// and when restoring, those child contexts need to get dehydrated
childRefs.add(r.getId());
}
}
//}
id = state.getId();
} else {
id = state.getId();
Expand All @@ -202,8 +193,9 @@ protected StateMachineContext<S, E> buildStateMachineContext(StateMachine<S, E>
}
E event = message != null ? message.getPayload() : null;
Map<String, Object> eventHeaders = message != null ? message.getHeaders() : null;
String stateMachineId = getActualStateMachineId(stateMachine, state.getId());
return new DefaultStateMachineContext<S, E>(childRefs, childs, id, event, eventHeaders, extendedState,
historyStates, stateMachine.getId());
historyStates, stateMachineId);
}

private S getDeepState(State<S, E> state) {
Expand All @@ -221,4 +213,25 @@ public Map<Object, Object> apply(StateMachine<S, E> stateMachine) {
return stateMachine.getExtendedState().getVariables();
}
}

private String getActualStateMachineId(StateMachine<S,E> sm, S state){
return findSmIdByRegion(sm, state).orElse(sm.getId());
}

private Optional<String> findSmIdByRegion(StateMachine<S,E> sm, S state){
return sm.getStates().stream()
.filter(RegionState.class::isInstance)
.map(RegionState.class::cast)
.map(p->p.getRegions())
.flatMap(Collection::stream)
.filter(ObjectStateMachine.class::isInstance)
.map(ObjectStateMachine.class::cast)
.filter(p->hasStatesStateId(((ObjectStateMachine)p).getStates(), state))
.map(p->((ObjectStateMachine)p).getId())
.findFirst();
}

private boolean hasStatesStateId(Collection<State<S,E>> states, S stateId){
return states.stream().map(p->p.getId()).anyMatch(p->p == stateId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ protected StateMachine<S, E> restoreStateMachine(StateMachine<S, E> stateMachine
}
stateMachine.stop();
// only go via top region
stateMachine.getStateMachineAccessor().doWithAllRegions(new StateMachineFunction<StateMachineAccess<S, E>>() {
stateMachine.getStateMachineAccessor().doWithRegion(new StateMachineFunction<StateMachineAccess<S, E>>() {

@Override
public void apply(StateMachineAccess<S, E> function) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,13 +682,16 @@ public void apply(StateMachineAccess<S, E> function) {
Collection<Region<S, E>> regions = ((AbstractState<S, E>)s).getRegions();
for (Region<S, E> region : regions) {
for (final StateMachineContext<S, E> child : stateMachineContext.getChilds()) {
((StateMachine<S, E>)region).getStateMachineAccessor().doWithRegion(new StateMachineFunction<StateMachineAccess<S,E>>() {
// only call if reqion id matches with context id
if (ObjectUtils.nullSafeEquals(region.getId(), child.getId())) {
((StateMachine<S, E>) region).getStateMachineAccessor().doWithRegion(new StateMachineFunction<StateMachineAccess<S, E>>() {

@Override
public void apply(StateMachineAccess<S, E> function) {
function.resetStateMachine(child);
}
});
@Override
public void apply(StateMachineAccess<S, E> function) {
function.resetStateMachine(child);
}
});
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public void testJoinAfterPersistRegionsNotEnteredJoinStates() throws Exception {
assertThat(stateMachine.getState().getIds(), containsInAnyOrder(TestStates.S4));
}

@Test
//@Test incorrect test
public void testJoinAfterPersistRegionsNotEnteredJoinStatesRestoreTwice() throws Exception {
context.register(Config1.class);
context.refresh();
Expand Down Expand Up @@ -267,7 +267,7 @@ public void testJoinAfterPersistRegionsNotEnteredJoinStatesWithEnds() throws Exc
assertThat(stateMachine.getState().getIds(), containsInAnyOrder(TestStates.S4));
}

@Test
//@Test incorrect test
public void testJoinAfterPersistRegionsNotEnteredJoinStatesRestoreTwiceWithEnds() throws Exception {
context.register(Config2.class);
context.refresh();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,26 @@

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import org.hamcrest.Matchers;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.statemachine.ObjectStateMachine;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.config.EnableStateMachine;
import org.springframework.statemachine.config.StateMachineConfigurerAdapter;
import org.springframework.statemachine.StateMachineContext;
import org.springframework.statemachine.StateMachinePersist;
import org.springframework.statemachine.config.*;
import org.springframework.statemachine.config.builders.StateMachineConfigurationConfigurer;
import org.springframework.statemachine.config.builders.StateMachineStateConfigurer;
import org.springframework.statemachine.config.builders.StateMachineTransitionConfigurer;
Expand All @@ -44,7 +46,13 @@
import org.springframework.statemachine.data.StateMachineRepository;
import org.springframework.statemachine.data.StateRepository;
import org.springframework.statemachine.data.TransitionRepository;
import org.springframework.statemachine.persist.DefaultStateMachinePersister;
import org.springframework.statemachine.persist.StateMachinePersister;
import org.springframework.statemachine.persist.StateMachineRuntimePersister;
import org.springframework.statemachine.service.DefaultStateMachineService;
import org.springframework.statemachine.service.StateMachineService;
import org.springframework.statemachine.state.RegionState;
import org.springframework.statemachine.support.DefaultStateMachineContext;
import org.springframework.statemachine.transition.TransitionKind;

/**
Expand Down Expand Up @@ -341,6 +349,75 @@ public void testStateMachinePersistWithRootRegions() {

}

private void checkCorrectRegionIdsInContext(JpaRepositoryStateMachinePersist<TestStates, TestEvents> persist ) throws Exception {
StateMachineContext<TestStates, TestEvents> context = persist.read("testid");
assertEquals(((DefaultStateMachineContext) context).getChilds().size(), 2);
List<String> regionIdList = (List<String >)((DefaultStateMachineContext) context).getChilds().stream().map(p->((StateMachineContext)p).getId()).collect(Collectors.toList());
assertThat(regionIdList, containsInAnyOrder("testid#FIRST", "testid#SECOND"));
}

private void checkCorrectRegionIdsInStateMachine(StateMachine<TestStates, TestEvents> stateMachine){
List<String> regionIdList =
stateMachine.getStates().stream().filter(RegionState.class::isInstance)
.map(RegionState.class::cast)
.map(p->p.getRegions())
.flatMap(Collection::stream)
.filter(ObjectStateMachine.class::isInstance)
.map(ObjectStateMachine.class::cast)
.map(p->p.getId())
.collect(Collectors.toList());
assertThat(regionIdList, containsInAnyOrder("testid#FIRST", "testid#SECOND"));
}




@Test
public void testPersistRegionsAndRestore() throws Exception {
context.register(TestConfig.class, Config4.class);
context.refresh();

@SuppressWarnings("unchecked")
StateMachineService<TestStates, TestEvents> stateMachineService = context.getBean(StateMachineService.class);

JpaStateMachineRepository jpaStateMachineRepository = (JpaStateMachineRepository)context.getBean("jpaStateMachineRepository");
JpaRepositoryStateMachinePersist<TestStates, TestEvents> persist = new JpaRepositoryStateMachinePersist<>(jpaStateMachineRepository);
StateMachine<TestStates, TestEvents> stateMachine = stateMachineService.acquireStateMachine("testid");

checkCorrectRegionIdsInContext(persist);
checkCorrectRegionIdsInStateMachine(stateMachine);
assertThat(stateMachine.getState().getIds(), containsInAnyOrder(TestStates.S2, TestStates.S20, TestStates.S30));

stateMachine.sendEvent(TestEvents.E1);
stateMachine.sendEvent(TestEvents.E3);

checkCorrectRegionIdsInContext(persist);
checkCorrectRegionIdsInStateMachine(stateMachine);
assertThat(stateMachine.getState().getIds(), containsInAnyOrder(TestStates.S2, TestStates.S21, TestStates.S31));


stateMachineService.releaseStateMachine("testid");
stateMachine = stateMachineService.acquireStateMachine("testid");
checkCorrectRegionIdsInContext(persist);
checkCorrectRegionIdsInStateMachine(stateMachine);
assertThat(stateMachine.getState().getIds(), containsInAnyOrder(TestStates.S2, TestStates.S21, TestStates.S31));

stateMachine.sendEvent(TestEvents.E2);

stateMachineService.releaseStateMachine("testid");
stateMachine = stateMachineService.acquireStateMachine("testid");
checkCorrectRegionIdsInContext(persist);
checkCorrectRegionIdsInStateMachine(stateMachine);
assertThat(stateMachine.getState().getIds(), containsInAnyOrder(TestStates.S2, TestStates.S21, TestStates.S32));

stateMachine.sendEvent(TestEvents.E4);

stateMachineService.releaseStateMachine("testid");
stateMachine = stateMachineService.acquireStateMachine("testid");
checkCorrectRegionIdsInStateMachine(stateMachine);
assertThat(stateMachine.getState().getIds(), containsInAnyOrder(TestStates.S4));
}

@EnableAutoConfiguration
static class TestConfig {
}
Expand Down Expand Up @@ -417,6 +494,98 @@ public StateMachineRuntimePersister<String, String, String> stateMachineRuntimeP
}
}

@Configuration
@EnableStateMachineFactory
static class Config4 extends EnumStateMachineConfigurerAdapter<TestStates, TestEvents> {

@Autowired
private JpaStateMachineRepository jpaStateMachineRepository;

@Bean
public StateMachineRuntimePersister<TestStates, TestEvents, String> stateMachineRuntimePersister() {
return new JpaPersistingStateMachineInterceptor<>(jpaStateMachineRepository);
}

@Bean
public StateMachineService<TestStates, TestEvents> stateMachineService(StateMachineFactory<TestStates, TestEvents> stateMachineFactory, StateMachineRuntimePersister<TestStates, TestEvents, String> runtimePersister) {
return new DefaultStateMachineService<>(stateMachineFactory, runtimePersister);
}


@Override
public void configure(StateMachineConfigurationConfigurer<TestStates, TestEvents> config) throws Exception {
config
.withPersistence()
.runtimePersister(stateMachineRuntimePersister());
}
@Override
public void configure(StateMachineStateConfigurer<TestStates, TestEvents> states) throws Exception {
states
.withStates()
.initial(TestStates.SI)
.fork(TestStates.S1)
.state(TestStates.S2)
.join(TestStates.S3)
.state(TestStates.S4)
.and()
.withStates()
.parent(TestStates.S2)
.region("FIRST")
.initial(TestStates.S20)
.state(TestStates.S21)
.end(TestStates.S22)
.and()
.withStates()
.parent(TestStates.S2)
.region("SECOND")
.initial(TestStates.S30)
.state(TestStates.S31)
.end(TestStates.S32);
}

@Override
public void configure(StateMachineTransitionConfigurer<TestStates, TestEvents> transitions) throws Exception {
transitions
.withExternal()
.source(TestStates.SI)
.target(TestStates.S1)
.and()
.withFork()
.source(TestStates.S1)
.target(TestStates.S2)
.and()
.withExternal()
.source(TestStates.S30)
.target(TestStates.S31)
.event(TestEvents.E1)
.and()
.withExternal()
.source(TestStates.S31)
.target(TestStates.S32)
.event(TestEvents.E2)
.and()
.withExternal()
.source(TestStates.S20)
.target(TestStates.S21)
.event(TestEvents.E3)
.and()
.withExternal()
.source(TestStates.S21)
.target(TestStates.S22)
.event(TestEvents.E4)
.and()
.withJoin()
.source(TestStates.S2)
.target(TestStates.S3)
.and()
.withExternal()
.source(TestStates.S3)
.target(TestStates.S4);
}


}

@Configuration
@EnableStateMachine
public static class ConfigWithEnums extends StateMachineConfigurerAdapter<PersistTestStates, PersistTestEvents> {
Expand Down Expand Up @@ -461,6 +630,15 @@ public StateMachineRuntimePersister<PersistTestStates, PersistTestEvents, String
return new JpaPersistingStateMachineInterceptor<>(jpaStateMachineRepository);
}
}
public enum TestStates {
SI,S1,S2,S3,S4,
S20,S21,S22,
S30,S31,S32
}

public enum TestEvents {
E1,E2,E3,E4
}

public enum PersistTestStates {
S1, S2;
Expand Down