Skip to content

Commit

Permalink
Better struct handling in code generation util
Browse files Browse the repository at this point in the history
  • Loading branch information
webthethird committed Jul 27, 2023
1 parent 8f1a875 commit baa7fb5
Showing 1 changed file with 100 additions and 12 deletions.
112 changes: 100 additions & 12 deletions slither/utils/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
MappingType,
ArrayType,
ElementaryType,
TypeAlias
)
from slither.core.declarations import Structure, Enum, Contract
from slither.core.declarations import Structure, StructureContract, Enum, Contract

if TYPE_CHECKING:
from slither.core.declarations import FunctionContract, CustomErrorContract
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.structure_variable import StructureVariable


# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments,too-many-locals,too-many-branches
def generate_interface(
contract: "Contract",
unroll_structs: bool = True,
Expand Down Expand Up @@ -56,12 +58,47 @@ def generate_interface(
for enum in contract.enums:
interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n"
if include_structs:
for struct in contract.structures:
# Include structures defined in this contract and at the top level
structs = contract.structures + contract.compilation_unit.structures_top_level
# Function signatures may reference other structures as well
# Include structures defined in libraries used for them
for _for in contract.using_for.keys():
if (
isinstance(_for, UserDefinedType)
and isinstance(_for.type, StructureContract)
and _for.type not in structs
):
structs.append(_for.type)
# Include any other structures used as function arguments/returns
for func in contract.functions_entry_points:
for arg in func.parameters + func.returns:
_type = arg.type
if isinstance(_type, ArrayType):
_type = _type.type
while isinstance(_type, MappingType):
_type = _type.type_to
if isinstance(_type, UserDefinedType):
_type = _type.type
if isinstance(_type, Structure) and _type not in structs:
structs.append(_type)
for struct in structs:
interface += generate_struct_interface_str(struct, indent=4)
for elem in struct.elems_ordered:
if (
isinstance(elem.type, UserDefinedType)
and isinstance(elem.type.type, StructureContract)
and elem.type.type not in structs
):
structs.append(elem.type.type)
for var in contract.state_variables_entry_points:
interface += f" function {generate_interface_variable_signature(var, unroll_structs)};\n"
# if any(func.name == var.name for func in contract.functions_entry_points):
# # ignore public variables that override a public function
# continue
var_sig = generate_interface_variable_signature(var, unroll_structs)
if var_sig is not None and var_sig != "":
interface += f" function {var_sig};\n"
for func in contract.functions_entry_points:
if func.is_constructor or func.is_fallback or func.is_receive:
if func.is_constructor or func.is_fallback or func.is_receive or not func.is_implemented:
continue
interface += (
f" function {generate_interface_function_signature(func, unroll_structs)};\n"
Expand All @@ -75,6 +112,10 @@ def generate_interface_variable_signature(
) -> Optional[str]:
if var.visibility in ["private", "internal"]:
return None
if isinstance(var.type, UserDefinedType) and isinstance(var.type.type, Structure):
for elem in var.type.type.elems_ordered:
if isinstance(elem.type, MappingType):
return ""
if unroll_structs:
params = [
convert_type_for_solidity_signature_to_string(x).replace("(", "").replace(")", "")
Expand All @@ -93,6 +134,11 @@ def generate_interface_variable_signature(
_type = _type.type_to
while isinstance(_type, (ArrayType, UserDefinedType)):
_type = _type.type
if isinstance(_type, TypeAlias):
_type = _type.type
if isinstance(_type, Structure):
if any(isinstance(elem.type, MappingType) for elem in _type.elems_ordered):
return ""
ret = str(_type)
if isinstance(_type, Structure) or (isinstance(_type, Type) and _type.is_dynamic):
ret += " memory"
Expand Down Expand Up @@ -125,6 +171,8 @@ def format_var(var: "LocalVariable", unroll: bool) -> str:
.replace("(", "")
.replace(")", "")
)
if var.type.is_dynamic:
return f"{_handle_dynamic_struct_elem(var.type)} {var.location}"
if isinstance(var.type, ArrayType) and isinstance(
var.type.type, (UserDefinedType, ElementaryType)
):
Expand All @@ -135,12 +183,14 @@ def format_var(var: "LocalVariable", unroll: bool) -> str:
+ f" {var.location}"
)
if isinstance(var.type, UserDefinedType):
if isinstance(var.type.type, (Structure, Enum)):
if isinstance(var.type.type, Structure):
return f"{str(var.type.type)} memory"
if isinstance(var.type.type, Enum):
return str(var.type.type)
if isinstance(var.type.type, Contract):
return "address"
if var.type.is_dynamic:
return f"{var.type} {var.location}"
if isinstance(var.type, TypeAlias):
return str(var.type.type)
return str(var.type)

name, _, _ = func.signature
Expand All @@ -154,6 +204,12 @@ def format_var(var: "LocalVariable", unroll: bool) -> str:
view = " view" if func.view and not func.pure else ""
pure = " pure" if func.pure else ""
payable = " payable" if func.payable else ""
# Make sure the function doesn't return a struct with nested mappings
for ret in func.returns:
if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Structure):
for elem in ret.type.type.elems_ordered:
if isinstance(elem.type, MappingType):
return ""
returns = [format_var(ret, unroll_structs) for ret in func.returns]
parameters = [format_var(param, unroll_structs) for param in func.parameters]
_interface_signature_str = (
Expand Down Expand Up @@ -184,17 +240,49 @@ def generate_struct_interface_str(struct: "Structure", indent: int = 0) -> str:
spaces += " "
definition = f"{spaces}struct {struct.name} {{\n"
for elem in struct.elems_ordered:
if isinstance(elem.type, UserDefinedType):
if isinstance(elem.type.type, (Structure, Enum)):
if elem.type.is_dynamic:
definition += f"{spaces} {_handle_dynamic_struct_elem(elem.type)} {elem.name};\n"
elif isinstance(elem.type, UserDefinedType):
if isinstance(elem.type.type, Structure):
definition += f"{spaces} {elem.type.type} {elem.name};\n"
elif isinstance(elem.type.type, Contract):
definition += f"{spaces} address {elem.name};\n"
else:
definition += f"{spaces} {convert_type_for_solidity_signature_to_string(elem.type)} {elem.name};\n"
elif isinstance(elem.type, TypeAlias):
definition += f"{spaces} {elem.type.type} {elem.name};\n"
else:
definition += f"{spaces} {elem.type} {elem.name};\n"
definition += f"{spaces}}}\n"
return definition


def _handle_dynamic_struct_elem(elem_type: Type) -> str:
assert elem_type.is_dynamic
if isinstance(elem_type, ElementaryType):
return f"{elem_type}"
if isinstance(elem_type, ArrayType):
base_type = elem_type.type
if isinstance(base_type, UserDefinedType):
if isinstance(base_type.type, Contract):
return "address[]"
if isinstance(base_type.type, Enum):
return convert_type_for_solidity_signature_to_string(elem_type)
return f"{base_type.type.name}[]"
return f"{base_type}[]"
if isinstance(elem_type, MappingType):
type_to = elem_type.type_to
type_from = elem_type.type_from
if isinstance(type_from, UserDefinedType) and isinstance(type_from.type, Contract):
type_from = ElementaryType("address")
if isinstance(type_to, MappingType):
return f"mapping({type_from} => {_handle_dynamic_struct_elem(type_to)})"
if isinstance(type_to, UserDefinedType):
if isinstance(type_to.type, Contract):
return f"mapping({type_from} => address)"
return f"mapping({type_from} => {type_to.type.name})"
return f"{elem_type}"
return ""


def generate_custom_error_interface(
error: "CustomErrorContract", unroll_structs: bool = True
) -> str:
Expand Down

0 comments on commit baa7fb5

Please sign in to comment.