Skip to content

Commit

Permalink
Merge pull request #2574 from crytic/dev-echidna-values
Browse files Browse the repository at this point in the history
Echidna printer Improve values extraction
  • Loading branch information
montyly authored Oct 17, 2024
2 parents 79619f6 + f6b2509 commit 83e5fca
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 40 deletions.
99 changes: 65 additions & 34 deletions slither/printers/guidance/echidna.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from slither.core.expressions import NewContract
from slither.core.slither_core import SlitherCore
from slither.core.solidity_types import TypeAlias
from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.variable import Variable
from slither.printers.abstract_printer import AbstractPrinter
Expand Down Expand Up @@ -179,29 +180,74 @@ class ConstantValue(NamedTuple): # pylint: disable=inherit-non-class,too-few-pu
type: str


def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-nested-blocks
def _extract_constant_from_read(
ir: Operation,
r: SourceMapping,
all_cst_used: List[ConstantValue],
all_cst_used_in_binary: Dict[str, List[ConstantValue]],
context_explored: Set[Node],
) -> None:
var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r
# Do not report struct_name in a.struct_name
if isinstance(ir, Member):
return
if isinstance(var_read, Variable) and var_read.is_constant:
# In case of type conversion we use the destination type
if isinstance(ir, TypeConversion):
if isinstance(ir.type, TypeAlias):
value_type = ir.type.type
else:
value_type = ir.type
else:
value_type = var_read.type
try:
value = ConstantFolding(var_read.expression, value_type).result()
all_cst_used.append(ConstantValue(str(value), str(value_type)))
except NotConstant:
pass
if isinstance(var_read, Constant):
all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type)))
if isinstance(var_read, StateVariable):
if var_read.node_initialization:
if var_read.node_initialization.irs:
if var_read.node_initialization in context_explored:
return
context_explored.add(var_read.node_initialization)
_extract_constants_from_irs(
var_read.node_initialization.irs,
all_cst_used,
all_cst_used_in_binary,
context_explored,
)


def _extract_constant_from_binary(
ir: Binary,
all_cst_used: List[ConstantValue],
all_cst_used_in_binary: Dict[str, List[ConstantValue]],
):
for r in ir.read:
if isinstance(r, Constant):
all_cst_used_in_binary[str(ir.type)].append(ConstantValue(str(r.value), str(r.type)))
if isinstance(ir.variable_left, Constant) or isinstance(ir.variable_right, Constant):
if ir.lvalue:
try:
type_ = ir.lvalue.type
cst = ConstantFolding(ir.expression, type_).result()
all_cst_used.append(ConstantValue(str(cst.value), str(type_)))
except NotConstant:
pass


def _extract_constants_from_irs(
irs: List[Operation],
all_cst_used: List[ConstantValue],
all_cst_used_in_binary: Dict[str, List[ConstantValue]],
context_explored: Set[Node],
) -> None:
for ir in irs:
if isinstance(ir, Binary):
for r in ir.read:
if isinstance(r, Constant):
all_cst_used_in_binary[str(ir.type)].append(
ConstantValue(str(r.value), str(r.type))
)
if isinstance(ir.variable_left, Constant) or isinstance(
ir.variable_right, Constant
):
if ir.lvalue:
try:
type_ = ir.lvalue.type
cst = ConstantFolding(ir.expression, type_).result()
all_cst_used.append(ConstantValue(str(cst.value), str(type_)))
except NotConstant:
pass
_extract_constant_from_binary(ir, all_cst_used, all_cst_used_in_binary)
if isinstance(ir, TypeConversion):
if isinstance(ir.variable, Constant):
if isinstance(ir.type, TypeAlias):
Expand All @@ -222,24 +268,9 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n
except ValueError: # index could fail; should never happen in working solidity code
pass
for r in ir.read:
var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r
# Do not report struct_name in a.struct_name
if isinstance(ir, Member):
continue
if isinstance(var_read, Constant):
all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type)))
if isinstance(var_read, StateVariable):
if var_read.node_initialization:
if var_read.node_initialization.irs:
if var_read.node_initialization in context_explored:
continue
context_explored.add(var_read.node_initialization)
_extract_constants_from_irs(
var_read.node_initialization.irs,
all_cst_used,
all_cst_used_in_binary,
context_explored,
)
_extract_constant_from_read(
ir, r, all_cst_used, all_cst_used_in_binary, context_explored
)


def _extract_constants(
Expand Down
172 changes: 166 additions & 6 deletions slither/visitors/expression/constants_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
TupleExpression,
TypeConversion,
CallExpression,
MemberAccess,
)
from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression
from slither.core.variables import Variable
from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int
from slither.visitors.expression.expression import ExpressionVisitor
Expand All @@ -27,7 +29,13 @@ class NotConstant(Exception):
KEY = "ConstantFolding"

CONSTANT_TYPES_OPERATIONS = Union[
Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
]


Expand Down Expand Up @@ -69,6 +77,9 @@ def result(self) -> "Literal":
# pylint: disable=import-outside-toplevel
def _post_identifier(self, expression: Identifier) -> None:
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.declarations.enum import Enum
from slither.core.solidity_types.type_alias import TypeAlias
from slither.core.declarations.contract import Contract

if isinstance(expression.value, Variable):
if expression.value.is_constant:
Expand All @@ -77,7 +88,14 @@ def _post_identifier(self, expression: Identifier) -> None:
# Everything outside of literal
if isinstance(
expr,
(BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
cf = ConstantFolding(expr, self._type)
expr = cf.result()
Expand All @@ -88,20 +106,41 @@ def _post_identifier(self, expression: Identifier) -> None:
elif isinstance(expression.value, SolidityFunction):
set_val(expression, expression.value)
else:
raise NotConstant
# Enum: We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value
# We can't handle it here because we don't have the field accessed so we do it in _post_member_access
# TypeAlias: Support when a .wrap() is done with a constant
# Contract: Support when a constatn is use from a different contract
if not isinstance(expression.value, (Enum, TypeAlias, Contract)):
raise NotConstant

# pylint: disable=too-many-branches,too-many-statements
def _post_binary_operation(self, expression: BinaryOperation) -> None:
expression_left = expression.expression_left
expression_right = expression.expression_right
if not isinstance(
expression_left,
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
raise NotConstant
if not isinstance(
expression_right,
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
raise NotConstant
left = get_val(expression_left)
Expand Down Expand Up @@ -205,6 +244,34 @@ def _post_assignement_operation(self, expression: expressions.AssignmentOperatio
raise NotConstant

def _post_call_expression(self, expression: expressions.CallExpression) -> None:
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.declarations.enum import Enum
from slither.core.solidity_types import TypeAlias

# pylint: disable=too-many-boolean-expressions
if (
isinstance(expression.called, Identifier)
and expression.called.value == SolidityFunction("type()")
and len(expression.arguments) == 1
and (
isinstance(expression.arguments[0], ElementaryTypeNameExpression)
or isinstance(expression.arguments[0], Identifier)
and isinstance(expression.arguments[0].value, Enum)
)
):
# Returning early to support type(ElemType).max/min or type(MyEnum).max/min
return
if (
isinstance(expression.called.expression, Identifier)
and isinstance(expression.called.expression.value, TypeAlias)
and isinstance(expression.called, MemberAccess)
and expression.called.member_name == "wrap"
and len(expression.arguments) == 1
):
# Handle constants in .wrap of user defined type
set_val(expression, get_val(expression.arguments[0]))
return

called = get_val(expression.called)
args = [get_val(arg) for arg in expression.arguments]
if called.name == "keccak256(bytes)":
Expand All @@ -220,12 +287,104 @@ def _post_conditional_expression(self, expression: expressions.ConditionalExpres
def _post_elementary_type_name_expression(
self, expression: expressions.ElementaryTypeNameExpression
) -> None:
raise NotConstant
# We don't have to raise an exception to support type(uint112).max or similar
pass

def _post_index_access(self, expression: expressions.IndexAccess) -> None:
raise NotConstant

# pylint: disable=too-many-locals
def _post_member_access(self, expression: expressions.MemberAccess) -> None:
from slither.core.declarations import (
SolidityFunction,
Contract,
EnumContract,
EnumTopLevel,
Enum,
)
from slither.core.solidity_types import UserDefinedType, TypeAlias

# pylint: disable=too-many-nested-blocks
if isinstance(expression.expression, CallExpression) and expression.member_name in [
"min",
"max",
]:
if isinstance(expression.expression.called, Identifier):
if expression.expression.called.value == SolidityFunction("type()"):
assert len(expression.expression.arguments) == 1
type_expression_found = expression.expression.arguments[0]
type_found: Union[ElementaryType, UserDefinedType]
if isinstance(type_expression_found, ElementaryTypeNameExpression):
type_expression_found_type = type_expression_found.type
assert isinstance(type_expression_found_type, ElementaryType)
type_found = type_expression_found_type
value = (
type_found.max if expression.member_name == "max" else type_found.min
)
set_val(expression, value)
return
# type(enum).max/min
# Case when enum is in another contract e.g. type(C.E).max
if isinstance(type_expression_found, MemberAccess):
contract = type_expression_found.expression.value
assert isinstance(contract, Contract)
for enum in contract.enums:
if enum.name == type_expression_found.member_name:
type_found_in_expression = enum
type_found = UserDefinedType(enum)
break
else:
assert isinstance(type_expression_found, Identifier)
type_found_in_expression = type_expression_found.value
assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel))
type_found = UserDefinedType(type_found_in_expression)
value = (
type_found_in_expression.max
if expression.member_name == "max"
else type_found_in_expression.min
)
set_val(expression, value)
return
elif isinstance(expression.expression, Identifier) and isinstance(
expression.expression.value, Enum
):
# Handle direct access to enum field
set_val(expression, expression.expression.value.values.index(expression.member_name))
return
elif isinstance(expression.expression, Identifier) and isinstance(
expression.expression.value, TypeAlias
):
# User defined type .wrap call handled in _post_call_expression
return
elif (
isinstance(expression.expression.value, Contract)
and expression.member_name in expression.expression.value.variables_as_dict
and expression.expression.value.variables_as_dict[expression.member_name].is_constant
):
# Handles when a constant is accessed on another contract
variables = expression.expression.value.variables_as_dict
if isinstance(variables[expression.member_name].expression, MemberAccess):
self._post_member_access(variables[expression.member_name].expression)
set_val(expression, get_val(variables[expression.member_name].expression))
return

# If the variable is a Literal we convert its value to int
if isinstance(variables[expression.member_name].expression, Literal):
value = convert_string_to_int(
variables[expression.member_name].expression.converted_value
)
# If the variable is a UnaryOperation we need convert its value to int
# and replacing possible spaces
elif isinstance(variables[expression.member_name].expression, UnaryOperation):
value = convert_string_to_int(
str(variables[expression.member_name].expression).replace(" ", "")
)
else:
value = variables[expression.member_name].expression

set_val(expression, value)
return

raise NotConstant

def _post_new_array(self, expression: expressions.NewArray) -> None:
Expand Down Expand Up @@ -272,6 +431,7 @@ def _post_type_conversion(self, expression: expressions.TypeConversion) -> None:
TupleExpression,
TypeConversion,
CallExpression,
MemberAccess,
),
):
raise NotConstant
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/slithir/test_constantfolding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pathlib import Path

from slither import Slither
from slither.printers.guidance.echidna import _extract_constants, ConstantValue

TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"


def test_enum_max_min(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.19")
slither = Slither(Path(TEST_DATA_DIR, "constantfolding.sol").as_posix(), solc=solc_path)

contracts = slither.get_contract_from_name("A")

constants = _extract_constants(contracts)[0]["A"]["use()"]

assert set(constants) == {
ConstantValue(value="2", type="uint256"),
ConstantValue(value="10", type="uint256"),
ConstantValue(value="100", type="uint256"),
ConstantValue(value="4294967295", type="uint32"),
}
Loading

0 comments on commit 83e5fca

Please sign in to comment.