diff --git a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java index ca514103f..ba431a43c 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java @@ -30,10 +30,7 @@ import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.TypeUtils; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Set; +import java.util.*; @AllArgsConstructor @NoArgsConstructor @@ -163,19 +160,16 @@ private String getStringTemplateAndAppendArguments(J.MethodInvocation assertThat return "assertThat(#{any()}).%s(#{any()})"; } - } - private static Expression extractEitherArgument(boolean assertThatArgumentIsEmpty, Expression assertThatArgument, Expression methodToReplaceArgument) { - if (assertThatArgumentIsEmpty) { - return methodToReplaceArgument; - } - // Only on the assertThat argument do we possibly replace the argument with the select; such as list.size() -> list - if (assertThatArgument instanceof J.MethodInvocation) { - Expression select = ((J.MethodInvocation) assertThatArgument).getSelect(); - if (select != null) { - return select; + private Expression extractEitherArgument(boolean assertThatArgumentIsEmpty, Expression assertThatArgument, Expression methodToReplaceArgument) { + if (assertThatArgumentIsEmpty) { + return methodToReplaceArgument; + } + // Only on the assertThat argument do we possibly replace the argument with the select; such as list.size() -> list + if (CHAINED_ASSERT_MATCHER.matches(assertThatArgument)) { + return Objects.requireNonNull(((J.MethodInvocation) assertThatArgument).getSelect()); } + return assertThatArgument; } - return assertThatArgument; } } diff --git a/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java b/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java index 8a2b7f42e..b88b7d315 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java @@ -42,7 +42,7 @@ void stringIsEmpty() { java( """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat("hello world".isEmpty()).isTrue(); @@ -51,7 +51,7 @@ void testMethod() { """, """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat("hello world").isEmpty(); @@ -71,7 +71,7 @@ void stringIsEmptyDescribedAs() { java( """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod(String actual) { assertThat(actual.isEmpty()).as("Reason").isTrue(); @@ -80,7 +80,7 @@ void testMethod(String actual) { """, """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod(String actual) { assertThat(actual).as("Reason").isEmpty(); @@ -102,12 +102,12 @@ void chainedRecipes() { java( """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat(getString().isEmpty()).isTrue(); } - + String getString() { return "hello world"; } @@ -115,12 +115,12 @@ String getString() { """, """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat(getString()).isEmpty(); } - + String getString() { return "hello world"; } @@ -141,14 +141,14 @@ void chainedRecipesOfDifferingTypes() { java( """ import java.nio.file.Path; - + import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void string(String actual) { assertThat(actual.startsWith("prefix")).isTrue(); } - + void path(Path actual) { assertThat(actual.startsWith("prefix")).isTrue(); } @@ -156,14 +156,14 @@ void path(Path actual) { """, """ import java.nio.file.Path; - + import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void string(String actual) { assertThat(actual).startsWith("prefix"); } - + void path(Path actual) { assertThat(actual).startsWithRaw(Path.of("prefix")); } @@ -181,13 +181,13 @@ void assertThatArgHasArgument() { java( """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { String expected = "hello world"; assertThat(getString().equalsIgnoreCase(expected)).isTrue(); } - + String getString() { return "hello world"; } @@ -195,13 +195,13 @@ String getString() { """, """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { String expected = "hello world"; assertThat(getString()).isEqualToIgnoringCase(expected); } - + String getString() { return "hello world"; } @@ -219,13 +219,13 @@ void replacementHasArgument() { java( """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { int length = 5; assertThat(getString().length()).isEqualTo(length); } - + String getString() { return "hello world"; } @@ -233,13 +233,13 @@ String getString() { """, """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { int length = 5; assertThat(getString()).hasSize(length); } - + String getString() { return "hello world"; } @@ -258,12 +258,12 @@ void normalCase() { java( """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat(getString().trim()).isEmpty(); } - + String getString() { return "hello world"; } @@ -271,12 +271,12 @@ String getString() { """, """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat(getString()).isBlank(); } - + String getString() { return "hello world"; } @@ -297,7 +297,7 @@ void stringContains() { java( """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat("hello world".contains("lo wo")).isTrue(); @@ -307,7 +307,7 @@ void testMethod() { """, """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat("hello world").contains("lo wo"); @@ -319,6 +319,49 @@ void testMethod() { ); } + @Test + void stringContainsObjectMethod() { + rewriteRun( + spec -> spec.recipes( + new SimplifyChainedAssertJAssertion("contains", "isTrue", "contains", "java.lang.String")), + //language=java + java( + """ + import static org.assertj.core.api.Assertions.assertThat; + + class Pojo { + public String getString() { + return "lo wo"; + } + } + + class MyTest { + void testMethod() { + var pojo = new Pojo(); + assertThat("hello world".contains(pojo.getString())).isTrue(); + } + } + """, + """ + import static org.assertj.core.api.Assertions.assertThat; + + class Pojo { + public String getString() { + return "lo wo"; + } + } + + class MyTest { + void testMethod() { + var pojo = new Pojo(); + assertThat("hello world").contains(pojo.getString()); + } + } + """ + ) + ); + } + @Test void mapMethodDealsWithTwoArguments() { rewriteRun( @@ -328,16 +371,16 @@ void mapMethodDealsWithTwoArguments() { """ import java.util.Collections; import java.util.Map; - + import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { String key = "key"; String value = "value"; assertThat(getMap().get(key)).isEqualTo(value); } - + Map getMap() { return Collections.emptyMap(); } @@ -346,16 +389,16 @@ Map getMap() { """ import java.util.Collections; import java.util.Map; - + import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { String key = "key"; String value = "value"; assertThat(getMap()).containsEntry(key, value); } - + Map getMap() { return Collections.emptyMap(); } @@ -373,9 +416,9 @@ void keySetContainsWithMultipleArguments() { java( """ import java.util.Map; - + import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod(Map map) { // we don't yet support `containsKeys` @@ -395,12 +438,12 @@ void isNotEmptyTest() { java( """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat(getString().isEmpty()).isFalse(); } - + String getString() { return "hello world"; } @@ -408,12 +451,12 @@ String getString() { """, """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat(getString()).isNotEmpty(); } - + String getString() { return "hello world"; } @@ -431,12 +474,12 @@ void doesNoRunOnWrongCombination() { java( """ import static org.assertj.core.api.Assertions.assertThat; - + class MyTest { void testMethod() { assertThat(getString().isBlank()).isFalse(); } - + String getString() { return "hello world"; } @@ -463,7 +506,7 @@ void simplifyPresenceAssertion() { """ import static org.assertj.core.api.Assertions.assertThat; import java.util.Optional; - + class Test { void simpleTest(Optional o) { assertThat(o.isPresent()).isTrue(); @@ -476,7 +519,7 @@ void simpleTest(Optional o) { """ import static org.assertj.core.api.Assertions.assertThat; import java.util.Optional; - + class Test { void simpleTest(Optional o) { assertThat(o).isPresent(); @@ -502,7 +545,7 @@ void simplifiyEqualityAssertion() { """ import static org.assertj.core.api.Assertions.assertThat; import java.util.Optional; - + class Test { void simpleTest(Optional o) { assertThat(o.get()).isEqualTo("foo"); @@ -513,7 +556,7 @@ void simpleTest(Optional o) { """ import static org.assertj.core.api.Assertions.assertThat; import java.util.Optional; - + class Test { void simpleTest(Optional o) { assertThat(o).contains("foo");