Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to type propagation of arithmetic expressions #1449

Merged
merged 9 commits into from
Mar 8, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ abstract class Language<T : LanguageFrontend<*, *>> : Node() {
} else {
rhs
}
lhs is BooleanType && rhs is BooleanType -> lhs
else -> unknownType()
}
}
Expand All @@ -147,14 +148,13 @@ abstract class Language<T : LanguageFrontend<*, *>> : Node() {
* programming languages.
*/
open fun propagateTypeOfBinaryOperation(operation: BinaryOperator): Type {
if (operation.operatorCode == "==" || operation.operatorCode == "===") {
// A comparison, so we return the type "boolean"
return this.builtInTypes.values.firstOrNull { it is BooleanType }
?: this.builtInTypes.values.firstOrNull { it.name.localName.startsWith("bool") }
?: unknownType()
}

return when (operation.operatorCode) {
"==",
"===" ->
// A comparison, so we return the type "boolean"
this.builtInTypes.values.firstOrNull { it is BooleanType }
?: this.builtInTypes.values.firstOrNull { it.name.localName.startsWith("bool") }
?: unknownType()
"+" ->
if (operation.lhs.type is StringType) {
// string + anything => string
Expand All @@ -167,12 +167,16 @@ abstract class Language<T : LanguageFrontend<*, *>> : Node() {
}
"-",
"*",
"/" -> arithmeticOpTypePropagation(operation.lhs.type, operation.rhs.type)
"/",
"%",
"&",
"&&",
"|",
"^",
"||",
"^" -> arithmeticOpTypePropagation(operation.lhs.type, operation.rhs.type)
"<<",
">>" ->
">>",
">>>" ->
if (operation.lhs.type.isPrimitive && operation.rhs.type.isPrimitive) {
// primitive type 1 OP primitive type 2 => primitive type 1
operation.lhs.type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ import com.github.javaparser.resolution.declarations.ResolvedMethodDeclaration
import de.fraunhofer.aisec.cpg.frontends.Handler
import de.fraunhofer.aisec.cpg.frontends.HandlerInterface
import de.fraunhofer.aisec.cpg.graph.*
import de.fraunhofer.aisec.cpg.graph.declarations.MethodDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.RecordDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.*
import de.fraunhofer.aisec.cpg.graph.statements.*
import de.fraunhofer.aisec.cpg.graph.statements.expressions.*
import de.fraunhofer.aisec.cpg.graph.statements.expressions.AssignExpression
Expand Down Expand Up @@ -462,7 +461,8 @@ class ExpressionHandler(lang: JavaLanguageFrontend) :
is DoubleLiteralExpr ->
newLiteral(
literalExpr.asDoubleLiteralExpr().asDouble(),
this.primitiveType("double"),
if (literalExpr.value.endsWith("f", true)) this.primitiveType("float")
else this.primitiveType("double"),
rawNode = expr
)
is LongLiteralExpr ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,10 +619,7 @@ internal class JavaLanguageFrontendTest : BaseTest() {
it.registerLanguage(JavaLanguage())
}
val tu =
findByUniqueName(
result.components.flatMap { it.translationUnits },
"src/test/resources/fix-328/Cat.java"
)
findByUniqueName(result.components.flatMap { it.translationUnits }, file1.toString())
val namespace = tu.getDeclarationAs(0, NamespaceDeclaration::class.java)
assertNotNull(namespace)

Expand All @@ -649,7 +646,7 @@ internal class JavaLanguageFrontendTest : BaseTest() {
this.declarationHandler =
object : DeclarationHandler(this@MyJavaLanguageFrontend) {
override fun handleClassOrInterfaceDeclaration(
classInterDecl: ClassOrInterfaceDeclaration
classInterDecl: ClassOrInterfaceDeclaration,
): RecordDeclaration {
// take the original class and replace the name
val declaration =
Expand Down Expand Up @@ -800,4 +797,52 @@ internal class JavaLanguageFrontendTest : BaseTest() {
assertNotNull(jArg)
assertContains(jArg.prevDFG, loopVariable)
}

@Test
fun testArithmeticOperators() {
val file = File("src/test/resources/Issue1444.java")

val result =
TestUtils.analyze(listOf(file), file.parentFile.toPath(), true) {
it.registerLanguage(JavaLanguage())
}
val record = result.records["Operators"]
assertNotNull(record)
assertFalse { record.methods.isEmpty() }

val mainMethod = record.methods["main"]

val expressionLists = mainMethod.mcalls
assertEquals(6, expressionLists.size)

assertNotNull(mainMethod)

with(mainMethod) {
val intOperationsList = expressionLists[0]
assertEquals(14, intOperationsList.arguments.size)
assertTrue { intOperationsList.arguments.all { it.type == primitiveType("int") } }

val longOperationsList = expressionLists[1]
assertEquals(14, longOperationsList.arguments.size)
assertTrue { longOperationsList.arguments.all { it.type == primitiveType("long") } }

val floatOperationsList = expressionLists[2]
assertEquals(7, floatOperationsList.arguments.size)
assertTrue { floatOperationsList.arguments.all { it.type == primitiveType("float") } }

val doubleOperationsList = expressionLists[3]
assertEquals(7, doubleOperationsList.arguments.size)
assertTrue { doubleOperationsList.arguments.all { it.type == primitiveType("double") } }

val booleanOperationsList = expressionLists[4]
assertEquals(6, booleanOperationsList.arguments.size)
assertTrue {
booleanOperationsList.arguments.all { it.type == primitiveType("boolean") }
}

val stringOperationsList = expressionLists[5]
assertEquals(6, stringOperationsList.arguments.size)
assertTrue { stringOperationsList.arguments.all { it.type == primitiveType("String") } }
}
}
}
83 changes: 83 additions & 0 deletions cpg-language-java/src/test/resources/Issue1444.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
public class Operators {

public static void main(String[] args) {
// results should be type IntegerType("int")
List.of(
1 + 2,
3 - 4,
5 * 6,
7 / 8,
9 % 10,
11 << 12,
13 >> 14,
14 >>> 14,
15 ^ 16,
17 & 18,
19 | 20,
+21,
-22,
~23
);

// results should be type IntegerType("long")
List.of(
1L + 2,
3 - 4L,
5L * 6,
7 / 8L,
9L % 10,
11L << 12,
13L >> 14,
14L >>> 14,
15 ^ 16L,
17L & 18,
19 | 20L,
+21L,
-22L,
~23L
);

// results should be type FloatingPointType("float")
List.of(
1.f + 2,
3 - 4.f,
5.f * 6,
7 / 8.f,
9.f % 10,
+21.f,
-22.f
);

// results should be type FloatingPointType("long")
List.of(
1.f + 2.d,
3 - 4.d,
5.d * 6.f,
7.d / 8.f,
9.f % 10.d,
+21.d,
-22.d
);

// results should be type BooleanType
List.of(
true && false,
true & true,
false || true,
true | true,
false ^ true,
!false
);

// result should be type StringType
List.of(
"1" + 2,
3 + "4" ,
"5" + true,
'7' + "8",
"9" + null,
new ArrayList<Object>() + "12"
);
}

}
Loading