Skip to content

Commit

Permalink
Overall improvements (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x19 authored May 19, 2024
1 parent fe21f3c commit 3be8068
Show file tree
Hide file tree
Showing 24 changed files with 625 additions and 80 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ go.work
solgo
playground/*
bin/*
.idea
6 changes: 5 additions & 1 deletion ast/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (b *ASTBuilder) ToJSON() ([]byte, error) {
return b.InterfaceToJSON(b.tree.GetRoot())
}

// ToPrettyJSON converts the provided data to a JSON byte array.
// InterfaceToJSON converts the provided data to a JSON byte array.
func (b *ASTBuilder) InterfaceToJSON(data interface{}) ([]byte, error) {
return json.Marshal(data)
}
Expand Down Expand Up @@ -129,6 +129,10 @@ func (b *ASTBuilder) ImportFromJSON(ctx context.Context, jsonBytes []byte) (*Roo
return toReturn, nil
}

func (b *ASTBuilder) GetImports() []Node[NodeType] {
return b.currentImports
}

// GarbageCollect cleans up the ASTBuilder after resolving references.
func (b *ASTBuilder) GarbageCollect() {
b.currentEnums = nil
Expand Down
1 change: 1 addition & 0 deletions ast/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ func (c *Contract) Parse(unitCtx *parser.SourceUnitContext, ctx *parser.Contract
)

contractId := c.GetNextID()

contractNode := &Contract{
Id: contractId,
Name: ctx.Identifier().GetText(),
Expand Down
1 change: 1 addition & 0 deletions ast/reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ func (r *Resolver) resolveEntrySourceUnit() {
for _, entry := range node.GetExportedSymbols() {
if len(r.sources.EntrySourceUnitName) > 0 &&
r.sources.EntrySourceUnitName == entry.GetName() {

r.tree.astRoot.SetEntrySourceUnit(entry.GetId())
return
}
Expand Down
20 changes: 20 additions & 0 deletions ast/source_unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type SourceUnit[T NodeType] struct {
AbsolutePath string `json:"absolutePath"` // AbsolutePath is the absolute path of the source unit.
Name string `json:"name"` // Name is the name of the source unit. This is going to be one of the following: contract, interface or library name. It's here for convenience.
NodeType ast_pb.NodeType `json:"nodeType"` // NodeType is the type of the AST node.
Kind ast_pb.NodeType `json:"kind"` // Kind is the type of the AST node (contract, library, interface).
Nodes []Node[NodeType] `json:"nodes"` // Nodes is the list of AST nodes.
Src SrcNode `json:"src"` // Src is the source code location.
}
Expand Down Expand Up @@ -107,6 +108,11 @@ func (s *SourceUnit[T]) GetType() ast_pb.NodeType {
return s.NodeType
}

// GetKind returns the type of the source unit.
func (s *SourceUnit[T]) GetKind() ast_pb.NodeType {
return s.Kind
}

// GetSrc returns the source code location of the source unit.
func (s *SourceUnit[T]) GetSrc() SrcNode {
return s.Src
Expand Down Expand Up @@ -302,6 +308,7 @@ func (b *ASTBuilder) EnterSourceUnit(ctx *parser.SourceUnitContext) {
if interfaceCtx, ok := child.(*parser.InterfaceDefinitionContext); ok {
license := getLicenseFromSources(b.sources, b.comments, interfaceCtx.Identifier().GetText())
sourceUnit := NewSourceUnit[Node[ast_pb.SourceUnit]](b, interfaceCtx.Identifier().GetText(), license)
sourceUnit.Kind = ast_pb.NodeType_KIND_INTERFACE
interfaceNode := NewInterfaceDefinition(b)
interfaceNode.Parse(ctx, interfaceCtx, rootNode, sourceUnit)
b.sourceUnits = append(b.sourceUnits, sourceUnit)
Expand All @@ -310,6 +317,7 @@ func (b *ASTBuilder) EnterSourceUnit(ctx *parser.SourceUnitContext) {
if libraryCtx, ok := child.(*parser.LibraryDefinitionContext); ok {
license := getLicenseFromSources(b.sources, b.comments, libraryCtx.Identifier().GetText())
sourceUnit := NewSourceUnit[Node[ast_pb.SourceUnit]](b, libraryCtx.Identifier().GetText(), license)
sourceUnit.Kind = ast_pb.NodeType_KIND_LIBRARY
libraryNode := NewLibraryDefinition(b)
libraryNode.Parse(ctx, libraryCtx, rootNode, sourceUnit)
b.sourceUnits = append(b.sourceUnits, sourceUnit)
Expand All @@ -318,11 +326,23 @@ func (b *ASTBuilder) EnterSourceUnit(ctx *parser.SourceUnitContext) {
if contractCtx, ok := child.(*parser.ContractDefinitionContext); ok {
license := getLicenseFromSources(b.sources, b.comments, contractCtx.Identifier().GetText())
sourceUnit := NewSourceUnit[Node[ast_pb.SourceUnit]](b, contractCtx.Identifier().GetText(), license)
sourceUnit.Kind = ast_pb.NodeType_KIND_CONTRACT
contractNode := NewContractDefinition(b)
contractNode.Parse(ctx, contractCtx, rootNode, sourceUnit)
b.sourceUnits = append(b.sourceUnits, sourceUnit)
}
}

// Idea here is to basically set the source unit entry name as soon as we have parsed all of the classes.
// Now this won't be possible always but nevertheless. (In rest of the cases, resolver will take care of it)
if b.sources.EntrySourceUnitName != "" {
for _, sourceUnit := range b.sourceUnits {
if b.sources.EntrySourceUnitName == sourceUnit.GetName() {
rootNode.SetEntrySourceUnit(sourceUnit.GetId())
return
}
}
}
}

// ExitSourceUnit is called when the ASTBuilder exits a source unit context.
Expand Down
12 changes: 6 additions & 6 deletions ast/src.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import (

// SrcNode represents a node in the source code.
type SrcNode struct {
Line int64 `json:"line"` // Line number of the source node in the source code.
Column int64 `json:"column"` // Column number of the source node in the source code.
Start int64 `json:"start"` // Start position of the source node in the source code.
End int64 `json:"end"` // End position of the source node in the source code.
Length int64 `json:"length"` // Length of the source node in the source code.
ParentIndex int64 `json:"parent_index,omitempty"` // Index of the parent node in the source code.
Line int64 `json:"line"` // Line number of the source node in the source code.
Column int64 `json:"column"` // Column number of the source node in the source code.
Start int64 `json:"start"` // Start position of the source node in the source code.
End int64 `json:"end"` // End position of the source node in the source code.
Length int64 `json:"length"` // Length of the source node in the source code.
ParentIndex int64 `json:"parentIndex,omitempty"` // Index of the parent node in the source code.
}

// GetLine returns the line number of the source node in the source code.
Expand Down
100 changes: 98 additions & 2 deletions ast/state_variable.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package ast

import (
"strings"

v3 "github.com/cncf/xds/go/xds/type/v3"
"github.com/goccy/go-json"
ast_pb "github.com/unpackdev/protos/dist/go/ast"
"github.com/unpackdev/solgo/parser"
"strings"
)

// StateVariableDeclaration represents a state variable declaration in the Solidity abstract syntax tree (AST).
Expand Down Expand Up @@ -145,6 +145,102 @@ func (v *StateVariableDeclaration) GetInitialValue() Node[NodeType] {
return v.InitialValue
}

// UnmarshalJSON customizes the JSON unmarshaling for StateVariableDeclaration.
func (v *StateVariableDeclaration) UnmarshalJSON(data []byte) error {
var tempMap map[string]json.RawMessage
if err := json.Unmarshal(data, &tempMap); err != nil {
return err
}

if id, ok := tempMap["id"]; ok {
if err := json.Unmarshal(id, &v.Id); err != nil {
return err
}
}

if name, ok := tempMap["name"]; ok {
if err := json.Unmarshal(name, &v.Name); err != nil {
return err
}
}

if isConstant, ok := tempMap["isConstant"]; ok {
if err := json.Unmarshal(isConstant, &v.Constant); err != nil {
return err
}
}

if isStateVariable, ok := tempMap["isStateVariable"]; ok {
if err := json.Unmarshal(isStateVariable, &v.StateVariable); err != nil {
return err
}
}

if nodeType, ok := tempMap["nodeType"]; ok {
if err := json.Unmarshal(nodeType, &v.NodeType); err != nil {
return err
}
}

if visibility, ok := tempMap["visibility"]; ok {
if err := json.Unmarshal(visibility, &v.Visibility); err != nil {
return err
}
}

if storageLocation, ok := tempMap["storageLocation"]; ok {
if err := json.Unmarshal(storageLocation, &v.StorageLocation); err != nil {
return err
}
}

if mutability, ok := tempMap["mutability"]; ok {
if err := json.Unmarshal(mutability, &v.StateMutability); err != nil {
return err
}
}

if src, ok := tempMap["src"]; ok {
if err := json.Unmarshal(src, &v.Src); err != nil {
return err
}
}

if scope, ok := tempMap["scope"]; ok {
if err := json.Unmarshal(scope, &v.Scope); err != nil {
return err
}
}

if expression, ok := tempMap["initialValue"]; ok {
if err := json.Unmarshal(expression, &v.InitialValue); err != nil {
var tempNodeMap map[string]json.RawMessage
if err := json.Unmarshal(expression, &tempNodeMap); err != nil {
return err
}

var tempNodeType ast_pb.NodeType
if err := json.Unmarshal(tempNodeMap["nodeType"], &tempNodeType); err != nil {
return err
}

node, err := unmarshalNode(expression, tempNodeType)
if err != nil {
return err
}
v.InitialValue = node
}
}

if typeDescription, ok := tempMap["typeDescription"]; ok {
if err := json.Unmarshal(typeDescription, &v.TypeDescription); err != nil {
return err
}
}

return nil
}

// ToProto returns the protobuf representation of the state variable declaration.
func (v *StateVariableDeclaration) ToProto() NodeType {
proto := ast_pb.StateVariable{
Expand Down
26 changes: 23 additions & 3 deletions bindings/otterscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package bindings
import (
"context"
"fmt"
"github.com/pkg/errors"

"github.com/ethereum/go-ethereum/common"
"github.com/unpackdev/solgo/utils"
Expand All @@ -19,9 +20,9 @@ type CreatorInformation struct {
// GetContractCreator queries the Ethereum blockchain to find the creator of a specified smart contract. This method
// utilizes the Ethereum JSON-RPC API to request creator information, which includes both the creator's address and
// the transaction hash of the contract's creation. It's a valuable tool for auditing and tracking the origins of
// contracts on the network. WORKS ONLY WITH ERIGON NODE - OR NODES THAT SUPPORT OTTERSCAN!
// contracts on the network. WORKS ONLY WITH ERIGON NODE OR QUICKNODE PROVIDER - OR NODES THAT SUPPORT OTTERSCAN!
func (m *Manager) GetContractCreator(ctx context.Context, network utils.Network, contract common.Address) (*CreatorInformation, error) {
client := m.clientPool.GetClientByGroup(string(network))
client := m.clientPool.GetClientByGroup(network.String())
if client == nil {
return nil, fmt.Errorf("client not found for network %s", network)
}
Expand All @@ -30,7 +31,26 @@ func (m *Manager) GetContractCreator(ctx context.Context, network utils.Network,
var result *CreatorInformation

if err := rpcClient.CallContext(ctx, &result, "ots_getContractCreator", contract.Hex()); err != nil {
return nil, fmt.Errorf("failed to fetch otterscan creator information: %v", err)
return nil, errors.Wrap(err, "failed to fetch otterscan creator information")
}

return result, nil
}

// GetTransactionBySenderAndNonce retrieves a transaction hash based on a specific sender's address and nonce.
// This function also utilizes the Ethereum JSON-RPC API and requires a node that supports specific transaction lookup
// by sender and nonce, which is particularly useful for tracking transaction sequences and debugging transaction flows.
func (m *Manager) GetTransactionBySenderAndNonce(ctx context.Context, network utils.Network, sender common.Address, nonce int64) (*common.Hash, error) {
client := m.clientPool.GetClientByGroup(network.String())
if client == nil {
return nil, fmt.Errorf("client not found for network %s", network)
}

rpcClient := client.GetRpcClient()
var result *common.Hash

if err := rpcClient.CallContext(ctx, &result, "ots_getTransactionBySenderAndNonce", sender.Hex(), nonce); err != nil {
return nil, errors.Wrap(err, "failed to fetch otterscan get transaction by sender and nonce information")
}

return result, nil
Expand Down
26 changes: 26 additions & 0 deletions bindings/trace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package bindings

import (
"context"
"fmt"
"github.com/pkg/errors"

"github.com/ethereum/go-ethereum/common"
"github.com/unpackdev/solgo/utils"
)

func (m *Manager) TraceCallMany(ctx context.Context, network utils.Network, sender common.Address, nonce int64) (*common.Hash, error) {
client := m.clientPool.GetClientByGroup(network.String())
if client == nil {
return nil, fmt.Errorf("client not found for network %s", network)
}

rpcClient := client.GetRpcClient()
var result *common.Hash

if err := rpcClient.CallContext(ctx, &result, "trace_callMany", sender.Hex(), nonce); err != nil {
return nil, errors.Wrap(err, "failed to execute trace_callMany")
}

return result, nil
}
12 changes: 9 additions & 3 deletions contracts/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ type Descriptor struct {

// SourcesRaw is the raw sources from Etherscan|BscScan|etc. Should not be used anywhere except in
// the contract discovery process.
SourcesRaw *etherscan.Contract `json:"-"`
Sources *solgo.Sources `json:"sources,omitempty"`
SourceProvider string `json:"source_provider,omitempty"`
SourcesRaw *etherscan.Contract `json:"-"`
Sources *solgo.Sources `json:"sources,omitempty"`
SourcesUnsorted *solgo.Sources `json:"-"`
SourceProvider string `json:"source_provider,omitempty"`

// Source detection related fields.
Detector *detector.Detector `json:"-"`
Expand Down Expand Up @@ -144,6 +145,11 @@ func (d *Descriptor) GetSources() *solgo.Sources {
return d.Sources
}

// GetUnsortedSources returns the parsed sources of the contract, providing a structured view of the contract's code.
func (d *Descriptor) GetUnsortedSources() *solgo.Sources {
return d.SourcesUnsorted
}

// GetSourcesRaw returns the raw contract source as obtained from external providers like Etherscan.
func (d *Descriptor) GetSourcesRaw() *etherscan.Contract {
return d.SourcesRaw
Expand Down
4 changes: 4 additions & 0 deletions contracts/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ func (c *Contract) Parse(ctx context.Context) error {
)
return err
}

// Sets the address for more understanding when we need to troubleshoot contract parsing
parser.GetIR().SetAddress(c.addr)

c.descriptor.Detector = parser
c.descriptor.SolgoVersion = utils.GetBuildVersionByModule("github.com/unpackdev/solgo")

Expand Down
12 changes: 12 additions & 0 deletions contracts/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,19 @@ func (c *Contract) DiscoverSourceCode(ctx context.Context) error {
return fmt.Errorf("failed to create new sources from etherscan response: %s", err)
}

unsortedSources, err := solgo.NewUnsortedSourcesFromEtherScan(response.Name, response.SourceCode)
if err != nil {
zap.L().Error(
"failed to create new unsorted sources from etherscan response",
zap.Error(err),
zap.String("network", c.network.String()),
zap.String("contract_address", c.addr.String()),
)
return fmt.Errorf("failed to create new unsorted sources from etherscan response: %s", err)
}

c.descriptor.Sources = sources
c.descriptor.SourcesUnsorted = unsortedSources

license := strings.ReplaceAll(c.descriptor.SourcesRaw.LicenseType, "\r", "")
license = strings.ReplaceAll(license, "\n", "")
Expand Down
6 changes: 3 additions & 3 deletions ir/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func (c *Contract) ToProto() *ir_pb.Contract {

// processContract processes the contract unit and returns the Contract.
func (b *Builder) processContract(unit *ast.SourceUnit[ast.Node[ast_pb.SourceUnit]]) *Contract {
contract := getContractByNodeType(unit.GetContract())
contract := GetContractByNodeType(unit.GetContract())
contractNode := &Contract{
Unit: unit,

Expand Down Expand Up @@ -360,8 +360,8 @@ func (b *Builder) processContract(unit *ast.SourceUnit[ast.Node[ast_pb.SourceUni
return contractNode
}

// getContractByNodeType returns the ContractNode based on the node type.
func getContractByNodeType(c ast.Node[ast.NodeType]) ContractNode {
// GetContractByNodeType returns the ContractNode based on the node type.
func GetContractByNodeType(c ast.Node[ast.NodeType]) ContractNode {
switch contract := c.(type) {
case *ast.Library:
return contract
Expand Down
Loading

0 comments on commit 3be8068

Please sign in to comment.