diff --git a/resilience4j-spring/src/main/java/io/github/resilience4j/fallback/CompletionStageFallbackDecorator.java b/resilience4j-spring/src/main/java/io/github/resilience4j/fallback/CompletionStageFallbackDecorator.java index cc3ef69a5f..976c13ebda 100644 --- a/resilience4j-spring/src/main/java/io/github/resilience4j/fallback/CompletionStageFallbackDecorator.java +++ b/resilience4j-spring/src/main/java/io/github/resilience4j/fallback/CompletionStageFallbackDecorator.java @@ -18,7 +18,9 @@ import io.vavr.CheckedFunction0; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; /** * fallbackMethod decorator for {@link CompletionStage} @@ -35,23 +37,14 @@ public boolean supports(Class target) { public CheckedFunction0 decorate(FallbackMethod fallbackMethod, CheckedFunction0 supplier) { return supplier.andThen(request -> { - CompletionStage completionStage = (CompletionStage) request; - + CompletionStage completionStage = (CompletionStage) request; CompletableFuture promise = new CompletableFuture(); - completionStage.whenComplete((result, throwable) -> { - if (throwable != null) { - try { - ((CompletionStage) fallbackMethod.fallback((Throwable) throwable)) - .whenComplete((fallbackResult, fallbackThrowable) -> { - if (fallbackThrowable != null) { - promise.completeExceptionally((Throwable) fallbackThrowable); - } else { - promise.complete(fallbackResult); - } - }); - } catch (Throwable fallbackThrowable) { - promise.completeExceptionally(fallbackThrowable); + if (throwable != null){ + if (throwable instanceof CompletionException || throwable instanceof ExecutionException) { + tryRecover(fallbackMethod, promise, throwable.getCause()); + }else{ + tryRecover(fallbackMethod, promise, throwable); } } else { promise.complete(result); @@ -61,4 +54,21 @@ public CheckedFunction0 decorate(FallbackMethod fallbackMethod, return promise; }); } + + @SuppressWarnings("unchecked") + private void tryRecover(FallbackMethod fallbackMethod, CompletableFuture promise, + Throwable throwable) { + try { + CompletionStage completionStage = (CompletionStage) fallbackMethod.fallback(throwable); + completionStage.whenComplete((fallbackResult, fallbackThrowable) -> { + if (fallbackThrowable != null) { + promise.completeExceptionally(fallbackThrowable); + } else { + promise.complete(fallbackResult); + } + }); + } catch (Throwable fallbackThrowable) { + promise.completeExceptionally(fallbackThrowable); + } + } } diff --git a/resilience4j-spring/src/test/java/io/github/resilience4j/fallback/FallbackMethodTest.java b/resilience4j-spring/src/test/java/io/github/resilience4j/fallback/FallbackMethodTest.java index b82a8ff1ed..ae6963ee29 100644 --- a/resilience4j-spring/src/test/java/io/github/resilience4j/fallback/FallbackMethodTest.java +++ b/resilience4j-spring/src/test/java/io/github/resilience4j/fallback/FallbackMethodTest.java @@ -18,6 +18,7 @@ import org.junit.Test; import java.lang.reflect.Method; +import java.util.concurrent.CompletableFuture; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -35,6 +36,16 @@ public void fallbackRuntimeExceptionTest() throws Throwable { .isEqualTo("recovered-RuntimeException"); } + @Test + public void fallbackFuture() throws Throwable { + FallbackMethodTest target = new FallbackMethodTest(); + Method testMethod = target.getClass().getMethod("testFutureMethod", String.class); + FallbackMethod fallbackMethod = FallbackMethod + .create("futureFallbackMethod", testMethod, new Object[]{"test"}, target); + CompletableFuture future = (CompletableFuture) fallbackMethod.fallback(new IllegalStateException("err")); + assertThat(future.get()).isEqualTo("recovered-IllegalStateException"); + } + @Test public void fallbackGlobalExceptionWithSameMethodReturnType() throws Throwable { FallbackMethodTest target = new FallbackMethodTest(); @@ -134,6 +145,10 @@ public String testMethod(String parameter) { return "test"; } + public CompletableFuture testFutureMethod(String parameter) { + return CompletableFuture.completedFuture("test"); + } + public String fallbackMethod(String parameter, RuntimeException exception) { return "recovered-RuntimeException"; } @@ -142,6 +157,10 @@ public String fallbackMethod(IllegalStateException exception) { return "recovered-IllegalStateException"; } + public CompletableFuture futureFallbackMethod(String parameter, IllegalStateException exception) { + return CompletableFuture.completedFuture("recovered-IllegalStateException"); + } + public String fallbackMethod(String parameter, IllegalArgumentException exception) { return "recovered-IllegalArgumentException"; }