Skip to content

Commit

Permalink
Merge branch 'main' into matt-fix-constructor-args
Browse files Browse the repository at this point in the history
  • Loading branch information
sentilesdal authored Dec 15, 2023
2 parents ece66bc + df75d87 commit 339b898
Show file tree
Hide file tree
Showing 20 changed files with 872 additions and 755 deletions.
109 changes: 10 additions & 99 deletions example/types/ExampleContract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

from __future__ import annotations

from dataclasses import fields, is_dataclass
from typing import Any, NamedTuple, Tuple, Type, TypeVar, cast
from typing import Any, NamedTuple, Type, cast

from eth_account.signers.local import LocalAccount
from eth_typing import ChecksumAddress, HexStr
Expand All @@ -35,8 +34,7 @@
from web3.types import ABI, BlockIdentifier, CallOverride, TxParams

from .ExampleTypes import InnerStruct, NestedStruct, SimpleStruct

T = TypeVar("T")
from .utilities import dataclass_to_tuple, rename_returned_types

structs = {
"SimpleStruct": SimpleStruct,
Expand All @@ -45,93 +43,6 @@
}


def tuple_to_dataclass(cls: type[T], tuple_data: Any | Tuple[Any, ...]) -> T:
"""
Converts a tuple (including nested tuples) to a dataclass instance. If cls is not a dataclass,
then the data will just be passed through this function.
Arguments
---------
cls: type[T]
The dataclass type to which the tuple data is to be converted.
tuple_data: Any | Tuple[Any, ...]
A tuple (or nested tuple) of values to convert into a dataclass instance.
Returns
-------
T
Either an instance of cls populated with data from tuple_data or tuple_data itself.
"""
if not is_dataclass(cls):
return cast(T, tuple_data)

field_types = {field.name: field.type for field in fields(cls)}
field_values = {}

for (field_name, field_type), value in zip(field_types.items(), tuple_data):
field_type = structs.get(field_type, field_type)
if is_dataclass(field_type):
# Recursively convert nested tuples to nested dataclasses
field_values[field_name] = tuple_to_dataclass(field_type, value)
elif isinstance(value, tuple) and not getattr(field_type, "_name", None) == "Tuple":
# If it's a tuple and the field is not intended to be a tuple, assume it's a nested dataclass
field_values[field_name] = tuple_to_dataclass(field_type, value)
else:
# Otherwise, set the primitive value directly
field_values[field_name] = value

return cls(**field_values)


def dataclass_to_tuple(instance: Any) -> Any:
"""Convert a dataclass instance to a tuple, handling nested dataclasses.
If the input is not a dataclass, return the original value.
"""
if not is_dataclass(instance):
return instance

def convert_value(value: Any) -> Any:
"""Convert nested dataclasses to tuples recursively, or return the original value."""
if is_dataclass(value):
return dataclass_to_tuple(value)
return value

return tuple(convert_value(getattr(instance, field.name)) for field in fields(instance))


def rename_returned_types(return_types, raw_values) -> Any:
"""_summary_
Parameters
----------
return_types : _type_
_description_
raw_values : _type_
_description_
Returns
-------
tuple
_description_
"""
# cover case of multiple return values
if isinstance(return_types, list):
# Ensure raw_values is a tuple for consistency
if not isinstance(raw_values, list):
raw_values = (raw_values,)

# Convert the tuple to the dataclass instance using the utility function
converted_values = tuple(
tuple_to_dataclass(return_type, value) for return_type, value in zip(return_types, raw_values)
)

return converted_values

# cover case of single return value
converted_value = tuple_to_dataclass(return_types, raw_values)
return converted_value


class ExampleFlipFlopContractFunction(ContractFunction):
"""ContractFunction for the flipFlop method."""

Expand Down Expand Up @@ -162,7 +73,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleMixStructsAndPrimitivesContractFunction(ContractFunction):
Expand Down Expand Up @@ -198,7 +109,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleNamedSingleStructContractFunction(ContractFunction):
Expand All @@ -225,7 +136,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return cast(SimpleStruct, rename_returned_types(return_types, raw_values))
return cast(SimpleStruct, rename_returned_types(structs, return_types, raw_values))


class ExampleNamedTwoMixedStructsContractFunction(ContractFunction):
Expand Down Expand Up @@ -258,7 +169,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleSingleNestedStructContractFunction(ContractFunction):
Expand All @@ -285,7 +196,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return cast(NestedStruct, rename_returned_types(return_types, raw_values))
return cast(NestedStruct, rename_returned_types(structs, return_types, raw_values))


class ExampleSingleSimpleStructContractFunction(ContractFunction):
Expand All @@ -312,7 +223,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return cast(SimpleStruct, rename_returned_types(return_types, raw_values))
return cast(SimpleStruct, rename_returned_types(structs, return_types, raw_values))


class ExampleTwoMixedStructsContractFunction(ContractFunction):
Expand Down Expand Up @@ -345,7 +256,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleTwoSimpleStructsContractFunction(ContractFunction):
Expand Down Expand Up @@ -378,7 +289,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleContractFunctions(ContractFunctions):
Expand Down
117 changes: 117 additions & 0 deletions example/types/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""A web3.py Contract class for the {{contract_name}} contract.
DO NOT EDIT. This file was generated by pypechain. See documentation at
https://github.com/delvtech/pypechain"""

# contracts have PascalCase names
# pylint: disable=invalid-name

# contracts control how many attributes and arguments we have in generated code
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments

# we don't need else statement if the other conditionals all have return,
# but it's easier to generate
# pylint: disable=no-else-return

# This file is bound to get very long depending on contract sizes.
# pylint: disable=too-many-lines

# methods are overriden with specific arguments instead of generic *args, **kwargs
# pylint: disable=arguments-differ

from __future__ import annotations

from dataclasses import fields, is_dataclass
from typing import Any, Tuple, TypeVar, cast

T = TypeVar("T")


def tuple_to_dataclass(cls: type[T], structs: dict[str, Any], tuple_data: Any | Tuple[Any, ...]) -> T:
"""
Converts a tuple (including nested tuples) to a dataclass instance. If cls is not a dataclass,
then the data will just be passed through this function.
Arguments
---------
cls: type[T]
The dataclass type to which the tuple data is to be converted.
tuple_data: Any | Tuple[Any, ...]
A tuple (or nested tuple) of values to convert into a dataclass instance.
Returns
-------
T
Either an instance of cls populated with data from tuple_data or tuple_data itself.
"""
if not is_dataclass(cls):
return cast(T, tuple_data)

field_types = {field.name: field.type for field in fields(cls)}
field_values = {}

for (field_name, field_type), value in zip(field_types.items(), tuple_data):
field_type = structs.get(field_type, field_type)
if is_dataclass(field_type):
# Recursively convert nested tuples to nested dataclasses
field_values[field_name] = tuple_to_dataclass(field_type, structs, value)
elif isinstance(value, tuple) and not getattr(field_type, "_name", None) == "Tuple":
# If it's a tuple and the field is not intended to be a tuple, assume it's a nested dataclass
field_values[field_name] = tuple_to_dataclass(field_type, structs, value)
else:
# Otherwise, set the primitive value directly
field_values[field_name] = value

return cls(**field_values)


def dataclass_to_tuple(instance: Any) -> Any:
"""Convert a dataclass instance to a tuple, handling nested dataclasses.
If the input is not a dataclass, return the original value.
"""
if not is_dataclass(instance):
return instance

def convert_value(value: Any) -> Any:
"""Convert nested dataclasses to tuples recursively, or return the original value."""
if is_dataclass(value):
return dataclass_to_tuple(value)
return value

return tuple(convert_value(getattr(instance, field.name)) for field in fields(instance))


def rename_returned_types(
structs: dict[str, Any], return_types: list[Any] | Any, raw_values: list[str | int | tuple] | str | int | tuple
) -> Any:
"""Convert structs in the return value to known dataclasses.
Parameters
----------
return_types : list[str] | str
The type or list of types returned from a contract.
raw_values : list[str | int | tuple] | str | int | tuple
The actual returned values from the contract.
Returns
-------
Any
_description_
"""
# cover case of multiple return values
if isinstance(return_types, list):
# Ensure raw_values is a tuple for consistency
if not isinstance(raw_values, list):
raw_values = (raw_values,)

# Convert the tuple to the dataclass instance using the utility function
converted_values = tuple(
tuple_to_dataclass(return_type, structs, value) for return_type, value in zip(return_types, raw_values)
)

return converted_values

# cover case of single return value
converted_value = tuple_to_dataclass(return_types, structs, raw_values)
return converted_value
9 changes: 8 additions & 1 deletion pypechain/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import shutil
import sys
from pathlib import Path
from shutil import copy2
from typing import NamedTuple, Sequence

from web3.exceptions import NoABIFunctionsFound
Expand Down Expand Up @@ -58,9 +59,15 @@ def main(argv: Sequence[str] | None = None) -> None:
print(f"Error creating types for {json_file}")
raise err

# Finally, render the __init__.py file
# Render the __init__.py file
render_init_file(output_dir, file_names, line_length)

# Copy utilities.py to the output_dir
# Get the path to `utilities.py` (assuming it's in the same directory as your script)
utilities_path = Path(__file__).parent / "templates/utilities.py"
# Copy the file to the output directory
copy2(utilities_path, output_dir)


def gather_json_files(directory: str) -> list:
"""Gathers all JSON files in the specified directory and its subdirectories."""
Expand Down
2 changes: 1 addition & 1 deletion pypechain/render/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class GetFunctionDatasReturnValue(NamedTuple):


def get_function_datas(abi: ABI) -> GetFunctionDatasReturnValue:
"""TODO fill me in
"""Gets the information needed for the generated Contract file.
Arguments
---------
Expand Down
Loading

0 comments on commit 339b898

Please sign in to comment.