Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve autofix in AVOID_NULL_CHECKS rule #1201

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ open class DiktatJavaExecTaskBase @Inject constructor(
// validate configuration
require(inputs == null && excludes == null) {
"`inputs` and `excludes` arguments for diktat task are deprecated and now should be changed for `inputs {}` " +
"with configuration for PatternFilterable. Please check https://github.com/diktat-static-analysis/diktat/README.md for more info."
"with configuration for PatternFilterable. Please check https://github.com/analysis-dev/diktat/README.md for more info."
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.jetbrains.kotlin.com.intellij.psi.tree.IElementType
import org.jetbrains.kotlin.psi.KtBinaryExpression
import org.jetbrains.kotlin.psi.KtBlockExpression
import org.jetbrains.kotlin.psi.KtIfExpression
import org.jetbrains.kotlin.psi.psiUtil.blockExpressionsOrSingle

/**
* This rule check and fixes explicit null checks (explicit comparison with `null`)
Expand Down Expand Up @@ -120,56 +121,66 @@ class NullChecksRule(configRules: List<RulesConfig>) : DiktatRule(
isEqualToNull: Boolean
) {
val variableName = binaryExpression.left!!.text
val thenCodeLines = condition.extractLinesFromBlock(THEN)
val elseCodeLines = condition.extractLinesFromBlock(ELSE)
val text = if (isEqualToNull) {
when {
elseCodeLines.isNullOrEmpty() ->
if (condition.getBreakNodeFromIf(THEN)) {
"$variableName ?: break"
} else {
"$variableName ?: run {${thenCodeLines?.joinToString(prefix = "\n", postfix = "\n", separator = "\n")}}"
}
thenCodeLines!!.singleOrNull() == "null" -> """
|$variableName?.let {
|${elseCodeLines.joinToString(separator = "\n")}
|}
""".trimMargin()
thenCodeLines.singleOrNull() == "break" -> """
|$variableName?.let {
|${elseCodeLines.joinToString(separator = "\n")}
|} ?: break
""".trimMargin()
else -> """
|$variableName?.let {
|${elseCodeLines.joinToString(separator = "\n")}
|}
|?: run {
|${thenCodeLines.joinToString(separator = "\n")}
|}
""".trimMargin()
}
val thenFromExistingCode = condition.extractLinesFromBlock(THEN)
val elseFromExistingCode = condition.extractLinesFromBlock(ELSE)

// if (a == null) { foo() } else { bar() } -> if (a != null) { bar() } else { foo() }
val thenCodeLines = if (isEqualToNull) {
elseFromExistingCode
} else {
when {
elseCodeLines.isNullOrEmpty() || (elseCodeLines.singleOrNull() == "null") ->
"$variableName?.let {${thenCodeLines?.joinToString(prefix = "\n", postfix = "\n", separator = "\n")}}"
elseCodeLines.singleOrNull() == "break" ->
"$variableName?.let {${thenCodeLines?.joinToString(prefix = "\n", postfix = "\n", separator = "\n")}} ?: break"
else -> """
|$variableName?.let {
|${thenCodeLines?.joinToString(separator = "\n")}
|}
|?: run {
|${elseCodeLines.joinToString(separator = "\n")}
|}
""".trimMargin()
}
thenFromExistingCode
}
val elseCodeLines = if (isEqualToNull) {
thenFromExistingCode
} else {
elseFromExistingCode
}
val numberOfStatementsInElseBlock = if (isEqualToNull) {
(condition.treeParent.psi as KtIfExpression).then?.blockExpressionsOrSingle()?.count() ?: 0
} else {
(condition.treeParent.psi as KtIfExpression).`else`?.blockExpressionsOrSingle()?.count() ?: 0
}

val elseEditedCodeLines = getEditedElseCodeLines(elseCodeLines, numberOfStatementsInElseBlock)
val thenEditedCodeLines = getEditedThenCodeLines(variableName, thenCodeLines, elseEditedCodeLines)

val text = "$thenEditedCodeLines $elseEditedCodeLines"
val tree = KotlinParser().createNode(text)
condition.treeParent.treeParent.addChild(tree, condition.treeParent)
condition.treeParent.treeParent.removeChild(condition.treeParent)
}

private fun getEditedElseCodeLines(elseCodeLines: List<String>?, numberOfStatementsInElseBlock: Int): String = when {
// else { "null"/empty } -> ""
elseCodeLines == null || elseCodeLines.singleOrNull() == "null" -> ""
// else { bar() } -> ?: bar()
numberOfStatementsInElseBlock == 1 -> "?: ${elseCodeLines.joinToString(postfix = "\n", separator = "\n")}"
// else { ... } -> ?: run { ... }
else -> getDefaultCaseElseCodeLines(elseCodeLines)
}

@Suppress("UnsafeCallOnNullableType")
private fun getEditedThenCodeLines(
variableName: String,
thenCodeLines: List<String>?,
elseEditedCodeLines: String
): String = when {
// if (a != null) { } -> a ?: editedElse
(thenCodeLines.isNullOrEmpty() && elseEditedCodeLines.isNotEmpty()) ||
// if (a != null) { a } else { ... } -> a ?: editedElse
(thenCodeLines?.singleOrNull() == variableName && elseEditedCodeLines.isNotEmpty()) -> variableName
// if (a != null) { a.foo() } -> a?.foo()
thenCodeLines?.singleOrNull()?.startsWith("$variableName.") ?: false -> "$variableName?${thenCodeLines?.firstOrNull()!!.removePrefix(variableName)}"
// if (a != null) { break } -> a?.let { ... }
// if (a != null) { foo() } -> a?.let { ... }
else -> getDefaultCaseThenCodeLines(variableName, thenCodeLines)
}

private fun getDefaultCaseThenCodeLines(variableName: String, thenCodeLines: List<String>?): String =
"$variableName?.let {${thenCodeLines?.joinToString(prefix = "\n", postfix = "\n", separator = "\n")}}"

private fun getDefaultCaseElseCodeLines(elseCodeLines: List<String>): String = "?: run {${elseCodeLines.joinToString(prefix = "\n", postfix = "\n", separator = "\n")}}"

@Suppress("COMMENT_WHITE_SPACE", "UnsafeCallOnNullableType")
private fun nullCheckInOtherStatements(binaryExprNode: ASTNode) {
val condition = (binaryExprNode.psi as KtBinaryExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,10 @@ class NullChecksRuleFixTest : FixTestBase("test/paragraph4/null_checks", ::NullC
fun `should fix require function`() {
fixAndCompare("RequireFunctionExpected.kt", "RequireFunctionTest.kt")
}

@Test
@Tag(WarningNames.AVOID_NULL_CHECKS)
fun `should fix if conditions when assigned`() {
fixAndCompare("IfConditionAssignCheckExpected.kt", "IfConditionAssignCheckTest.kt")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package test.paragraph4.null_checks

fun foo() {
val x = a?.let {
f(a)
} ?: g(a)

val y = a ?: 0

x ?: println("NULL")

val z = x ?: run {
println("NULL")
0
}

x?.let {
f(x)
} ?: run {
println("NULL")
g(x)
}
}

fun bar() {
val x = a?.let {
f(a)
} ?: g(a)

val y = a ?: 0

x ?: println("NULL")

val z = x ?: run {
println("NULL")
0
}

x?.let {
f(x)
} ?: run {
println("NULL")
g(x)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package test.paragraph4.null_checks

fun foo() {
val x = if (a != null) f(a) else g(a)

val y = if (a != null) a else 0

if (x != null) {
x
} else {
println("NULL")
}

val z = if (x != null) {
x
} else {
println("NULL")
0
}

if (x != null) {
f(x)
} else {
println("NULL")
g(x)
}
}

fun bar() {
val x = if (a == null) g(a) else f(a)

val y = if (a == null) 0 else a

if (x == null) {
println("NULL")
} else {
x
}

val z = if (x == null) {
println("NULL")
0
} else {
x
}

if (x == null) {
println("NULL")
g(x)
} else {
f(x)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,23 @@ bar()

some?.let {
print("qwe")
}
?: run {
print("asd")
}
} ?: print("asd")

some?.let {
print("qweqwe")
}

some?.let {
print("qqq")
}
?: run {
print("www")
}
} ?: print("www")

some?.let {
print("ttt")
}

some?.let {
print("ttt")
}
?: run {
} ?: run {
null
value
}
Expand All @@ -52,11 +45,8 @@ fun foo() {
while (result != 0 ) {
result?.let {
goo()
}
?: run {
for(i in 1..10)
} ?: for(i in 1..10)
break
}
}
while (result != 0) {
result = goo()
Expand All @@ -69,3 +59,23 @@ break
}
}

fun checkSmartCases() {
val x = a?.toString() ?: "Null"
val y = a.b.c?.toString() ?: a.b.toString()
a?.let {
print()
}
a?.let {
foo()
} ?: boo()
}

fun reversedCheckSmartCases() {
val x = a?.toString() ?: "Null"
val y = a.b.c?.toString() ?: a.b.toString()
a ?: print()
a?.let {
foo()
} ?: boo()
}

Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,45 @@ fun foo() {
}
}

fun checkSmartCases() {
val x = if (a != null) {
a.toString()
} else {
"Null"
}
val y = if (a.b.c != null) {
a.b.c.toString()
} else {
a.b.toString()
}
if (a != null) {
print()
}
if (a != null) {
foo()
} else {
boo()
}
}

fun reversedCheckSmartCases() {
val x = if (a == null) {
"Null"
} else {
a.toString()
}
val y = if (a.b.c == null) {
a.b.toString()
} else {
a.b.c.toString()
}
if (a == null) {
print()
}
if (a == null) {
boo()
} else {
foo()
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ fun foo() {

prop?.let {
doAnotherSmth()
}
?: run {
doSmth()
}
} ?: doSmth()
}

fun fooo() {
Expand Down