Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AST TypeName uint[] Bugfix #130

Merged
merged 2 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions abi/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestBuilderFromSources(t *testing.T) {
expectedProto string
unresolvedReferences int64
isEmpty bool
disabled bool
}{
{
name: "Empty Contract Test",
Expand All @@ -53,6 +54,7 @@ func TestBuilderFromSources(t *testing.T) {
expectedProto: tests.ReadJsonBytesForTest(t, "abi/Empty.abi.proto").Content,
unresolvedReferences: 0,
isEmpty: true,
disabled: false,
},
{
name: "Simple Storage Contract Test",
Expand All @@ -77,6 +79,7 @@ func TestBuilderFromSources(t *testing.T) {
expectedAbi: tests.ReadJsonBytesForTest(t, "abi/SimpleStorage.abi").Content,
expectedProto: tests.ReadJsonBytesForTest(t, "abi/SimpleStorage.abi.proto").Content,
unresolvedReferences: 0,
disabled: false,
},
{
name: "OpenZeppelin ERC20 Test",
Expand Down Expand Up @@ -116,6 +119,7 @@ func TestBuilderFromSources(t *testing.T) {
expectedAbi: tests.ReadJsonBytesForTest(t, "abi/ERC20.abi").Content,
expectedProto: tests.ReadJsonBytesForTest(t, "abi/ERC20.abi.proto").Content,
unresolvedReferences: 0,
disabled: false,
},
{
name: "Token Sale ERC20 Test",
Expand Down Expand Up @@ -144,6 +148,7 @@ func TestBuilderFromSources(t *testing.T) {
expectedAbi: tests.ReadJsonBytesForTest(t, "abi/TokenSale.abi").Content,
expectedProto: tests.ReadJsonBytesForTest(t, "abi/TokenSale.abi.proto").Content,
unresolvedReferences: 0,
disabled: false,
},
{
name: "Lottery Test",
Expand All @@ -162,6 +167,7 @@ func TestBuilderFromSources(t *testing.T) {
expectedAbi: tests.ReadJsonBytesForTest(t, "abi/Lottery.abi").Content,
expectedProto: tests.ReadJsonBytesForTest(t, "abi/Lottery.abi.proto").Content,
unresolvedReferences: 0,
disabled: false,
},
{
name: "Cheelee Test", // Took this one as I could discover ipfs metadata :joy:
Expand Down Expand Up @@ -240,11 +246,17 @@ func TestBuilderFromSources(t *testing.T) {
expectedAbi: tests.ReadJsonBytesForTest(t, "abi/TransparentUpgradeableProxy.abi").Content,
expectedProto: tests.ReadJsonBytesForTest(t, "abi/TransparentUpgradeableProxy.abi.proto").Content,
unresolvedReferences: 0,
disabled: false,
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {

if testCase.disabled {
return
}

builder, err := NewBuilderFromSources(context.TODO(), testCase.sources)
assert.NoError(t, err)
assert.NotNil(t, builder)
Expand Down
6 changes: 6 additions & 0 deletions ast/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func normalizeTypeName(typeName string) string {
func normalizeTypeDescription(typeName string) (string, string) {
isArray := strings.Contains(typeName, "[") && strings.Contains(typeName, "]")
isSlice := strings.HasSuffix(typeName, "[]")
isPrefixSlice := strings.HasPrefix(typeName, "[]")

switch {
case isArray:
Expand All @@ -86,6 +87,11 @@ func normalizeTypeDescription(typeName string) (string, string) {
normalizedTypePart := normalizeTypeName(typePart)
return normalizedTypePart + "[]", fmt.Sprintf("t_%s_slice", normalizedTypePart)

case isPrefixSlice:
typePart := typeName[2:]
normalizedTypePart := normalizeTypeName(typePart)
return "[]" + normalizedTypePart, fmt.Sprintf("t_%s_slice", normalizedTypePart)

case strings.HasPrefix(typeName, "uint"):
if typeName == "uint" {
return "uint256", "t_uint256"
Expand Down
127 changes: 104 additions & 23 deletions ast/type_name.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ func (t *TypeName) parseTypeName(unit *SourceUnit[Node[ast_pb.SourceUnit]], pare
ParentIndex: parentNodeId,
}

if t.GetType() == ast_pb.NodeType_NT_DEFAULT {
t.NodeType = ast_pb.NodeType_IDENTIFIER
}

if ctx.ElementaryTypeName() != nil {
normalizedTypeName, normalizedTypeIdentifier := normalizeTypeDescription(
ctx.ElementaryTypeName().GetText(),
Expand All @@ -325,28 +329,51 @@ func (t *TypeName) parseTypeName(unit *SourceUnit[Node[ast_pb.SourceUnit]], pare
zap.String("function_type_name", ctx.FunctionTypeName().GetText()),
zap.String("type", fmt.Sprintf("%T", ctx.FunctionTypeName())),
)
} else {
} else if ctx.Expression() != nil {
zap.L().Warn(
"Expression type name is not supported yet @ TypeName.parseTypeName",
zap.String("function_type_name", ctx.FunctionTypeName().GetText()),
zap.String("type", fmt.Sprintf("%T", ctx.FunctionTypeName())),
)
} else if ctx.IdentifierPath() != nil {
pathCtx := ctx.IdentifierPath()

// It seems to be a user-defined type but that does not exist as a type in the parser...
t.NodeType = ast_pb.NodeType_USER_DEFINED_PATH_NAME

pathCtx := ctx.IdentifierPath()
if pathCtx != nil {
t.PathNode = &PathNode{
Id: t.GetNextID(),
Name: pathCtx.GetText(),
Src: SrcNode{
Id: t.GetNextID(),
Line: int64(pathCtx.GetStart().GetLine()),
Column: int64(pathCtx.GetStart().GetColumn()),
Start: int64(pathCtx.GetStart().GetStart()),
End: int64(pathCtx.GetStop().GetStop()),
Length: int64(pathCtx.GetStop().GetStop() - pathCtx.GetStart().GetStart() + 1),
ParentIndex: t.GetId(),
},

NodeType: ast_pb.NodeType_IDENTIFIER_PATH,
}
t.PathNode = &PathNode{
Id: t.GetNextID(),
Name: pathCtx.GetText(),
Src: SrcNode{
Id: t.GetNextID(),
Line: int64(pathCtx.GetStart().GetLine()),
Column: int64(pathCtx.GetStart().GetColumn()),
Start: int64(pathCtx.GetStart().GetStart()),
End: int64(pathCtx.GetStop().GetStop()),
Length: int64(pathCtx.GetStop().GetStop() - pathCtx.GetStart().GetStart() + 1),
ParentIndex: t.GetId(),
},

NodeType: ast_pb.NodeType_IDENTIFIER_PATH,
}

normalizedTypeName, normalizedTypeIdentifier := normalizeTypeDescription(
pathCtx.GetText(),
)

switch normalizedTypeIdentifier {
case "t_address":
t.StateMutability = ast_pb.Mutability_NONPAYABLE
case "t_address_payable":
t.StateMutability = ast_pb.Mutability_PAYABLE
}

if len(normalizedTypeName) > 0 {
t.TypeDescription = &TypeDescription{
TypeIdentifier: normalizedTypeIdentifier,
TypeString: normalizedTypeName,
}
} else {
if refId, refTypeDescription := t.GetResolver().ResolveByNode(t, pathCtx.GetText()); refTypeDescription != nil {
if t.PathNode != nil {
t.PathNode.ReferencedDeclaration = refId
Expand All @@ -355,10 +382,34 @@ func (t *TypeName) parseTypeName(unit *SourceUnit[Node[ast_pb.SourceUnit]], pare
t.TypeDescription = refTypeDescription
}
}
}
} else if ctx.TypeName() != nil {
t.generateTypeName(unit, ctx.TypeName(), t, t)
} else {
normalizedTypeName, normalizedTypeIdentifier := normalizeTypeDescription(
t.Name,
)

if t.GetType() == ast_pb.NodeType_NT_DEFAULT {
t.NodeType = ast_pb.NodeType_IDENTIFIER
switch normalizedTypeIdentifier {
case "t_address":
t.StateMutability = ast_pb.Mutability_NONPAYABLE
case "t_address_payable":
t.StateMutability = ast_pb.Mutability_PAYABLE
}

if len(normalizedTypeName) > 0 {
t.TypeDescription = &TypeDescription{
TypeIdentifier: normalizedTypeIdentifier,
TypeString: normalizedTypeName,
}
} else {
if refId, refTypeDescription := t.GetResolver().ResolveByNode(t, t.Name); refTypeDescription != nil {
if t.PathNode != nil {
t.PathNode.ReferencedDeclaration = refId
}
t.ReferencedDeclaration = refId
t.TypeDescription = refTypeDescription
}
}
}
}

Expand Down Expand Up @@ -485,7 +536,7 @@ func (t *TypeName) parseMappingTypeName(unit *SourceUnit[Node[ast_pb.SourceUnit]
}

// generateTypeName generates the TypeName based on the given context.
func (t *TypeName) generateTypeName(sourceUnit *SourceUnit[Node[ast_pb.SourceUnit]], ctx interface{}, parentNode *TypeName, typeNameNode *TypeName) *TypeName {
func (t *TypeName) generateTypeName(sourceUnit *SourceUnit[Node[ast_pb.SourceUnit]], ctx any, parentNode *TypeName, typeNameNode *TypeName) *TypeName {
typeName := &TypeName{
ASTBuilder: t.ASTBuilder,
Id: t.GetNextID(),
Expand Down Expand Up @@ -555,6 +606,7 @@ func (t *TypeName) generateTypeName(sourceUnit *SourceUnit[Node[ast_pb.SourceUni
}
parentNode.TypeDescription = t.TypeDescription
typeName = typeNameNode

case parser.ITypeNameContext:
typeName.Name = specificCtx.GetText()
typeName.Src = SrcNode{
Expand Down Expand Up @@ -595,6 +647,7 @@ func (t *TypeName) generateTypeName(sourceUnit *SourceUnit[Node[ast_pb.SourceUni
TypeIdentifier: normalizedTypeIdentifier,
}
}

}

// We're still not able to discover reference, so what we're going to do now is look for the references...
Expand All @@ -608,7 +661,7 @@ func (t *TypeName) generateTypeName(sourceUnit *SourceUnit[Node[ast_pb.SourceUni
return typeName
}

// parseElementaryTypeName parses the ElementaryTypeName from the given ElementaryTypeNameContext.
// parseFunctionTypeName parses the ElementaryTypeName from the given ElementaryTypeNameContext.
func (t *TypeName) parseFunctionTypeName(unit *SourceUnit[Node[ast_pb.SourceUnit]], parentNodeId int64, ctx *parser.FunctionTypeNameContext) {
t.Name = "function"
t.NodeType = ast_pb.NodeType_FUNCTION_TYPE_NAME
Expand All @@ -617,6 +670,14 @@ func (t *TypeName) parseFunctionTypeName(unit *SourceUnit[Node[ast_pb.SourceUnit
t.TypeDescription = t.Expression.GetTypeDescription()
}

func (t *TypeName) parsePrimaryExpression(unit *SourceUnit[Node[ast_pb.SourceUnit]], fnNode Node[NodeType], parentNodeId int64, ctx *parser.PrimaryExpressionContext) {
t.Name = "function"
t.NodeType = ast_pb.NodeType_IDENTIFIER
statement := NewPrimaryExpression(t.ASTBuilder)
t.Expression = statement.Parse(unit, nil, fnNode, nil, nil, &PathNode{Id: parentNodeId}, ctx)
t.TypeDescription = t.Expression.GetTypeDescription()
}

// Parse parses the TypeName from the given TypeNameContext.
func (t *TypeName) Parse(unit *SourceUnit[Node[ast_pb.SourceUnit]], fnNode Node[NodeType], parentNodeId int64, ctx parser.ITypeNameContext) {
t.Id = t.GetNextID()
Expand All @@ -643,6 +704,15 @@ func (t *TypeName) Parse(unit *SourceUnit[Node[ast_pb.SourceUnit]], fnNode Node[
t.parseTypeName(unit, parentNodeId, childCtx)
case *parser.FunctionTypeNameContext:
t.parseFunctionTypeName(unit, parentNodeId, childCtx)
case *parser.PrimaryExpressionContext:
t.parsePrimaryExpression(unit, fnNode, parentNodeId, childCtx)
case *antlr.TerminalNodeImpl:
continue
default:
zap.L().Warn(
"TypeName child not recognized",
zap.String("type", fmt.Sprintf("%T", childCtx)),
)
}
}

Expand All @@ -651,6 +721,17 @@ func (t *TypeName) Parse(unit *SourceUnit[Node[ast_pb.SourceUnit]], fnNode Node[
t.Expression = expression.Parse(unit, nil, fnNode, nil, nil, nil, ctx.Expression())
t.TypeDescription = t.Expression.GetTypeDescription()
}

if t.GetTypeDescription() == nil {
normalizedTypeName, normalizedTypeIdentifier := normalizeTypeDescription(
t.Name,
)

t.TypeDescription = &TypeDescription{
TypeString: normalizedTypeName,
TypeIdentifier: normalizedTypeIdentifier,
}
}
}

// ParseMul parses the TypeName from the given TermalNode.
Expand Down
2 changes: 1 addition & 1 deletion data/solc/releases/releases.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion data/tests/audits/ERC20.slither.raw.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion data/tests/audits/Lottery.slither.raw.json

Large diffs are not rendered by default.

Loading