Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make exceptions that prevent transitions available #970

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2015 the original author or authors.
* Copyright 2015-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,5 +64,4 @@ public interface StateMachine<S, E> extends Region<S, E> {
* @return true, if error has been set
*/
boolean hasStateMachineError();

}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/*
* Copyright 2019-2020 the original author or authors.
* Copyright 2019-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -15,6 +15,7 @@
*/
package org.springframework.statemachine;

import java.util.Optional;
import org.springframework.messaging.Message;
import org.springframework.statemachine.region.Region;

Expand Down Expand Up @@ -59,6 +60,14 @@ public interface StateMachineEventResult<S, E> {
*/
Mono<Void> complete();

/**
* If there was an exception that caused the transition to be denied - return that
* @return Optional Throwable that caused the transition to be denied
*/
default Optional<Throwable> getDenialCause() {
return Optional.empty();
};

/**
* Enumeration of a result type indicating whether a region accepted, denied or
* deferred an event.
Expand All @@ -82,7 +91,7 @@ public enum ResultType {
*/
public static <S, E> StateMachineEventResult<S, E> from(Region<S, E> region, Message<E> message,
ResultType resultType) {
return new DefaultStateMachineEventResult<>(region, message, resultType, null);
return new DefaultStateMachineEventResult<>(region, message, resultType, null, null);
}


Expand All @@ -100,7 +109,24 @@ public static <S, E> StateMachineEventResult<S, E> from(Region<S, E> region, Mes
*/
public static <S, E> StateMachineEventResult<S, E> from(Region<S, E> region, Message<E> message,
ResultType resultType, Mono<Void> complete) {
return new DefaultStateMachineEventResult<>(region, message, resultType, complete);
return new DefaultStateMachineEventResult<>(region, message, resultType, complete, null);
}

/**
* Create a {@link StateMachineEventResult} from a {@link Region},
* {@link Message} and a {@link ResultType}.
*
* @param <S> the type of state
* @param <E> the type of event
* @param region the region
* @param message the message
* @param resultType the result type
* @param denialCause the throwable (that most likely caused transition denial)
* @return the state machine event result
*/
public static <S, E> StateMachineEventResult<S, E> from(Region<S, E> region, Message<E> message,
ResultType resultType, Throwable denialCause) {
return new DefaultStateMachineEventResult<>(region, message, resultType, null, denialCause);
}

static class DefaultStateMachineEventResult<S, E> implements StateMachineEventResult<S, E> {
Expand All @@ -109,13 +135,15 @@ static class DefaultStateMachineEventResult<S, E> implements StateMachineEventRe
private final Message<E> message;
private final ResultType resultType;
private Mono<Void> complete;
private Throwable denialCause;

DefaultStateMachineEventResult(Region<S, E> region, Message<E> message, ResultType resultType,
Mono<Void> complete) {
Mono<Void> complete, Throwable denialCause) {
this.region = region;
this.message = message;
this.resultType = resultType;
this.complete = complete != null ? complete : Mono.empty();
this.denialCause = denialCause;
}

@Override
Expand All @@ -138,6 +166,11 @@ public Mono<Void> complete() {
return complete;
}

@Override
public Optional<Throwable> getDenialCause() {
return Optional.ofNullable(denialCause);
}

@Override
public String toString() {
return "DefaultStateMachineEventResult [region=" + region + ", message=" + message + ", resultType="
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,12 @@ public void setTransitionConflightPolicy(TransitionConflictPolicy transitionConf

private Flux<StateMachineEventResult<S, E>> handleEvent(Message<E> message) {
if (hasStateMachineError()) {
return Flux.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED));
return Flux.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED, currentError.getCause()));
}
return Mono.just(message)
.map(m -> getStateMachineInterceptors().preEvent(m, this))
.flatMapMany(m -> acceptEvent(m))
.onErrorResume(error -> Flux.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED)))
.onErrorResume(error -> Flux.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED, error.getCause())))
.doOnNext(notifyOnDenied());
}

Expand Down Expand Up @@ -668,7 +668,7 @@ private Flux<StateMachineEventResult<S, E>> acceptEvent(Message<E> message) {
}))
.onErrorResume(t -> {
return Mono.defer(() -> {
return Mono.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED));
return Mono.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED, t.getCause()));
});
});
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.springframework.beans.factory.BeanFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.statemachine.StateMachineEventResult.ResultType;
import org.springframework.statemachine.action.Action;
import org.springframework.statemachine.config.StateMachineFactory;
Expand Down Expand Up @@ -125,6 +126,15 @@ public static <S, E> void doSendEventAndConsumeResultAsDenied(StateMachine<S, E>
.verifyComplete();
}

public static <S, E> void doSendEventAndConsumeResultAsDeniedWithAccessDeniedException(StateMachine<S, E> stateMachine, E event) {
StepVerifier.create(stateMachine.sendEvent(eventAsMono(event)))
.consumeNextWith(result -> {
assertThat(result.getResultType()).isEqualTo(ResultType.DENIED);
assertThat(result.getDenialCause().map(t -> t instanceof AccessDeniedException).orElse(false)).isTrue();
})
.verifyComplete();
}

public static <S, E> void doSendEventAndConsumeResultAsDenied(StateMachine<S, E> stateMachine, Message<E> event) {
StepVerifier.create(stateMachine.sendEvent(eventAsMono(event)))
.consumeNextWith(result -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2015-2020 the original author or authors.
* Copyright 2015-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,10 +18,12 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.statemachine.TestUtils.doSendEventAndConsumeAll;
import static org.springframework.statemachine.TestUtils.doSendEventAndConsumeResultAsDenied;
import static org.springframework.statemachine.TestUtils.doSendEventAndConsumeResultAsDeniedWithAccessDeniedException;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.springframework.security.access.AccessDeniedException;
import org.springframework.statemachine.AbstractStateMachineTests;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.config.StateMachineBuilder;
Expand Down Expand Up @@ -53,7 +55,7 @@ protected static void assertTransitionDenied(StateMachine<States, Events> machin
assertThat(machine.getState().getIds()).containsOnly(States.S0);

listener.reset(1);
doSendEventAndConsumeAll(machine, Events.A);
doSendEventAndConsumeResultAsDeniedWithAccessDeniedException(machine, Events.A);
assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS)).isFalse();
assertThat(listener.stateChangedCount).isZero();
assertThat(machine.getState().getIds()).containsOnly(States.S0);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2015-2020 the original author or authors.
* Copyright 2015-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,7 +35,7 @@ public class EventSecurityTests extends AbstractSecurityTests {
public void testNoSecurityContext() throws Exception {
TestListener listener = new TestListener();
StateMachine<States, Events> machine = buildMachine(listener, "ROLE_ANONYMOUS", ComparisonType.ANY, null);
assertTransitionDeniedResultAsDenied(machine, listener);
assertTransitionDenied(machine, listener);
}

@Test
Expand Down