diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 35a609193..7e76cec0d 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -13,6 +13,7 @@ from slither.core.expressions import NewContract from slither.core.slither_core import SlitherCore from slither.core.solidity_types import TypeAlias +from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.printers.abstract_printer import AbstractPrinter @@ -179,7 +180,66 @@ class ConstantValue(NamedTuple): # pylint: disable=inherit-non-class,too-few-pu type: str -def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-nested-blocks +def _extract_constant_from_read( + ir: Operation, + r: SourceMapping, + all_cst_used: List[ConstantValue], + all_cst_used_in_binary: Dict[str, List[ConstantValue]], + context_explored: Set[Node], +) -> None: + var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r + # Do not report struct_name in a.struct_name + if isinstance(ir, Member): + return + if isinstance(var_read, Variable) and var_read.is_constant: + # In case of type conversion we use the destination type + if isinstance(ir, TypeConversion): + if isinstance(ir.type, TypeAlias): + value_type = ir.type.type + else: + value_type = ir.type + else: + value_type = var_read.type + try: + value = ConstantFolding(var_read.expression, value_type).result() + all_cst_used.append(ConstantValue(str(value), str(value_type))) + except NotConstant: + pass + if isinstance(var_read, Constant): + all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type))) + if isinstance(var_read, StateVariable): + if var_read.node_initialization: + if var_read.node_initialization.irs: + if var_read.node_initialization in context_explored: + return + context_explored.add(var_read.node_initialization) + _extract_constants_from_irs( + var_read.node_initialization.irs, + all_cst_used, + all_cst_used_in_binary, + context_explored, + ) + + +def _extract_constant_from_binary( + ir: Binary, + all_cst_used: List[ConstantValue], + all_cst_used_in_binary: Dict[str, List[ConstantValue]], +): + for r in ir.read: + if isinstance(r, Constant): + all_cst_used_in_binary[str(ir.type)].append(ConstantValue(str(r.value), str(r.type))) + if isinstance(ir.variable_left, Constant) or isinstance(ir.variable_right, Constant): + if ir.lvalue: + try: + type_ = ir.lvalue.type + cst = ConstantFolding(ir.expression, type_).result() + all_cst_used.append(ConstantValue(str(cst.value), str(type_))) + except NotConstant: + pass + + +def _extract_constants_from_irs( irs: List[Operation], all_cst_used: List[ConstantValue], all_cst_used_in_binary: Dict[str, List[ConstantValue]], @@ -187,21 +247,7 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n ) -> None: for ir in irs: if isinstance(ir, Binary): - for r in ir.read: - if isinstance(r, Constant): - all_cst_used_in_binary[str(ir.type)].append( - ConstantValue(str(r.value), str(r.type)) - ) - if isinstance(ir.variable_left, Constant) or isinstance( - ir.variable_right, Constant - ): - if ir.lvalue: - try: - type_ = ir.lvalue.type - cst = ConstantFolding(ir.expression, type_).result() - all_cst_used.append(ConstantValue(str(cst.value), str(type_))) - except NotConstant: - pass + _extract_constant_from_binary(ir, all_cst_used, all_cst_used_in_binary) if isinstance(ir, TypeConversion): if isinstance(ir.variable, Constant): if isinstance(ir.type, TypeAlias): @@ -222,24 +268,9 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n except ValueError: # index could fail; should never happen in working solidity code pass for r in ir.read: - var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r - # Do not report struct_name in a.struct_name - if isinstance(ir, Member): - continue - if isinstance(var_read, Constant): - all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type))) - if isinstance(var_read, StateVariable): - if var_read.node_initialization: - if var_read.node_initialization.irs: - if var_read.node_initialization in context_explored: - continue - context_explored.add(var_read.node_initialization) - _extract_constants_from_irs( - var_read.node_initialization.irs, - all_cst_used, - all_cst_used_in_binary, - context_explored, - ) + _extract_constant_from_read( + ir, r, all_cst_used, all_cst_used_in_binary, context_explored + ) def _extract_constants( diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index b1fa570c6..ddadb77a1 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -13,7 +13,9 @@ TupleExpression, TypeConversion, CallExpression, + MemberAccess, ) +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.variables import Variable from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor @@ -27,7 +29,13 @@ class NotConstant(Exception): KEY = "ConstantFolding" CONSTANT_TYPES_OPERATIONS = Union[ - Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, ] @@ -69,6 +77,9 @@ def result(self) -> "Literal": # pylint: disable=import-outside-toplevel def _post_identifier(self, expression: Identifier) -> None: from slither.core.declarations.solidity_variables import SolidityFunction + from slither.core.declarations.enum import Enum + from slither.core.solidity_types.type_alias import TypeAlias + from slither.core.declarations.contract import Contract if isinstance(expression.value, Variable): if expression.value.is_constant: @@ -77,7 +88,14 @@ def _post_identifier(self, expression: Identifier) -> None: # Everything outside of literal if isinstance( expr, - (BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): cf = ConstantFolding(expr, self._type) expr = cf.result() @@ -88,7 +106,12 @@ def _post_identifier(self, expression: Identifier) -> None: elif isinstance(expression.value, SolidityFunction): set_val(expression, expression.value) else: - raise NotConstant + # Enum: We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value + # We can't handle it here because we don't have the field accessed so we do it in _post_member_access + # TypeAlias: Support when a .wrap() is done with a constant + # Contract: Support when a constatn is use from a different contract + if not isinstance(expression.value, (Enum, TypeAlias, Contract)): + raise NotConstant # pylint: disable=too-many-branches,too-many-statements def _post_binary_operation(self, expression: BinaryOperation) -> None: @@ -96,12 +119,28 @@ def _post_binary_operation(self, expression: BinaryOperation) -> None: expression_right = expression.expression_right if not isinstance( expression_left, - (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): raise NotConstant if not isinstance( expression_right, - (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): raise NotConstant left = get_val(expression_left) @@ -205,6 +244,34 @@ def _post_assignement_operation(self, expression: expressions.AssignmentOperatio raise NotConstant def _post_call_expression(self, expression: expressions.CallExpression) -> None: + from slither.core.declarations.solidity_variables import SolidityFunction + from slither.core.declarations.enum import Enum + from slither.core.solidity_types import TypeAlias + + # pylint: disable=too-many-boolean-expressions + if ( + isinstance(expression.called, Identifier) + and expression.called.value == SolidityFunction("type()") + and len(expression.arguments) == 1 + and ( + isinstance(expression.arguments[0], ElementaryTypeNameExpression) + or isinstance(expression.arguments[0], Identifier) + and isinstance(expression.arguments[0].value, Enum) + ) + ): + # Returning early to support type(ElemType).max/min or type(MyEnum).max/min + return + if ( + isinstance(expression.called.expression, Identifier) + and isinstance(expression.called.expression.value, TypeAlias) + and isinstance(expression.called, MemberAccess) + and expression.called.member_name == "wrap" + and len(expression.arguments) == 1 + ): + # Handle constants in .wrap of user defined type + set_val(expression, get_val(expression.arguments[0])) + return + called = get_val(expression.called) args = [get_val(arg) for arg in expression.arguments] if called.name == "keccak256(bytes)": @@ -220,12 +287,104 @@ def _post_conditional_expression(self, expression: expressions.ConditionalExpres def _post_elementary_type_name_expression( self, expression: expressions.ElementaryTypeNameExpression ) -> None: - raise NotConstant + # We don't have to raise an exception to support type(uint112).max or similar + pass def _post_index_access(self, expression: expressions.IndexAccess) -> None: raise NotConstant + # pylint: disable=too-many-locals def _post_member_access(self, expression: expressions.MemberAccess) -> None: + from slither.core.declarations import ( + SolidityFunction, + Contract, + EnumContract, + EnumTopLevel, + Enum, + ) + from slither.core.solidity_types import UserDefinedType, TypeAlias + + # pylint: disable=too-many-nested-blocks + if isinstance(expression.expression, CallExpression) and expression.member_name in [ + "min", + "max", + ]: + if isinstance(expression.expression.called, Identifier): + if expression.expression.called.value == SolidityFunction("type()"): + assert len(expression.expression.arguments) == 1 + type_expression_found = expression.expression.arguments[0] + type_found: Union[ElementaryType, UserDefinedType] + if isinstance(type_expression_found, ElementaryTypeNameExpression): + type_expression_found_type = type_expression_found.type + assert isinstance(type_expression_found_type, ElementaryType) + type_found = type_expression_found_type + value = ( + type_found.max if expression.member_name == "max" else type_found.min + ) + set_val(expression, value) + return + # type(enum).max/min + # Case when enum is in another contract e.g. type(C.E).max + if isinstance(type_expression_found, MemberAccess): + contract = type_expression_found.expression.value + assert isinstance(contract, Contract) + for enum in contract.enums: + if enum.name == type_expression_found.member_name: + type_found_in_expression = enum + type_found = UserDefinedType(enum) + break + else: + assert isinstance(type_expression_found, Identifier) + type_found_in_expression = type_expression_found.value + assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel)) + type_found = UserDefinedType(type_found_in_expression) + value = ( + type_found_in_expression.max + if expression.member_name == "max" + else type_found_in_expression.min + ) + set_val(expression, value) + return + elif isinstance(expression.expression, Identifier) and isinstance( + expression.expression.value, Enum + ): + # Handle direct access to enum field + set_val(expression, expression.expression.value.values.index(expression.member_name)) + return + elif isinstance(expression.expression, Identifier) and isinstance( + expression.expression.value, TypeAlias + ): + # User defined type .wrap call handled in _post_call_expression + return + elif ( + isinstance(expression.expression.value, Contract) + and expression.member_name in expression.expression.value.variables_as_dict + and expression.expression.value.variables_as_dict[expression.member_name].is_constant + ): + # Handles when a constant is accessed on another contract + variables = expression.expression.value.variables_as_dict + if isinstance(variables[expression.member_name].expression, MemberAccess): + self._post_member_access(variables[expression.member_name].expression) + set_val(expression, get_val(variables[expression.member_name].expression)) + return + + # If the variable is a Literal we convert its value to int + if isinstance(variables[expression.member_name].expression, Literal): + value = convert_string_to_int( + variables[expression.member_name].expression.converted_value + ) + # If the variable is a UnaryOperation we need convert its value to int + # and replacing possible spaces + elif isinstance(variables[expression.member_name].expression, UnaryOperation): + value = convert_string_to_int( + str(variables[expression.member_name].expression).replace(" ", "") + ) + else: + value = variables[expression.member_name].expression + + set_val(expression, value) + return + raise NotConstant def _post_new_array(self, expression: expressions.NewArray) -> None: @@ -272,6 +431,7 @@ def _post_type_conversion(self, expression: expressions.TypeConversion) -> None: TupleExpression, TypeConversion, CallExpression, + MemberAccess, ), ): raise NotConstant diff --git a/tests/unit/slithir/test_constantfolding.py b/tests/unit/slithir/test_constantfolding.py new file mode 100644 index 000000000..fcf00035b --- /dev/null +++ b/tests/unit/slithir/test_constantfolding.py @@ -0,0 +1,22 @@ +from pathlib import Path + +from slither import Slither +from slither.printers.guidance.echidna import _extract_constants, ConstantValue + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +def test_enum_max_min(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.19") + slither = Slither(Path(TEST_DATA_DIR, "constantfolding.sol").as_posix(), solc=solc_path) + + contracts = slither.get_contract_from_name("A") + + constants = _extract_constants(contracts)[0]["A"]["use()"] + + assert set(constants) == { + ConstantValue(value="2", type="uint256"), + ConstantValue(value="10", type="uint256"), + ConstantValue(value="100", type="uint256"), + ConstantValue(value="4294967295", type="uint32"), + } diff --git a/tests/unit/slithir/test_data/constantfolding.sol b/tests/unit/slithir/test_data/constantfolding.sol new file mode 100644 index 000000000..aef4a2427 --- /dev/null +++ b/tests/unit/slithir/test_data/constantfolding.sol @@ -0,0 +1,19 @@ +type MyType is uint256; + +contract A{ + + enum E{ + a,b,c + } + + + uint a = 10; + E b = type(E).max; + uint c = type(uint32).max; + MyType d = MyType.wrap(100); + + function use() public returns(uint){ + E e = b; + return a +c + MyType.unwrap(d); + } +} \ No newline at end of file