Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move utilities to a separate file. #82

Merged
merged 4 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 10 additions & 99 deletions example/types/ExampleContract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

from __future__ import annotations

from dataclasses import fields, is_dataclass
from typing import Any, NamedTuple, Tuple, Type, TypeVar, cast
from typing import Any, NamedTuple, Type, cast

from eth_account.signers.local import LocalAccount
from eth_typing import ChecksumAddress, HexStr
Expand All @@ -35,8 +34,7 @@
from web3.types import ABI, BlockIdentifier, CallOverride, TxParams

from .ExampleTypes import InnerStruct, NestedStruct, SimpleStruct

T = TypeVar("T")
from .utilities import dataclass_to_tuple, rename_returned_types

structs = {
"SimpleStruct": SimpleStruct,
Expand All @@ -45,93 +43,6 @@
}


def tuple_to_dataclass(cls: type[T], tuple_data: Any | Tuple[Any, ...]) -> T:
"""
Converts a tuple (including nested tuples) to a dataclass instance. If cls is not a dataclass,
then the data will just be passed through this function.
Arguments
---------
cls: type[T]
The dataclass type to which the tuple data is to be converted.
tuple_data: Any | Tuple[Any, ...]
A tuple (or nested tuple) of values to convert into a dataclass instance.
Returns
-------
T
Either an instance of cls populated with data from tuple_data or tuple_data itself.
"""
if not is_dataclass(cls):
return cast(T, tuple_data)

field_types = {field.name: field.type for field in fields(cls)}
field_values = {}

for (field_name, field_type), value in zip(field_types.items(), tuple_data):
field_type = structs.get(field_type, field_type)
if is_dataclass(field_type):
# Recursively convert nested tuples to nested dataclasses
field_values[field_name] = tuple_to_dataclass(field_type, value)
elif isinstance(value, tuple) and not getattr(field_type, "_name", None) == "Tuple":
# If it's a tuple and the field is not intended to be a tuple, assume it's a nested dataclass
field_values[field_name] = tuple_to_dataclass(field_type, value)
else:
# Otherwise, set the primitive value directly
field_values[field_name] = value

return cls(**field_values)


def dataclass_to_tuple(instance: Any) -> Any:
"""Convert a dataclass instance to a tuple, handling nested dataclasses.
If the input is not a dataclass, return the original value.
"""
if not is_dataclass(instance):
return instance

def convert_value(value: Any) -> Any:
"""Convert nested dataclasses to tuples recursively, or return the original value."""
if is_dataclass(value):
return dataclass_to_tuple(value)
return value

return tuple(convert_value(getattr(instance, field.name)) for field in fields(instance))


def rename_returned_types(return_types, raw_values) -> Any:
"""_summary_
Parameters
----------
return_types : _type_
_description_
raw_values : _type_
_description_
Returns
-------
tuple
_description_
"""
# cover case of multiple return values
if isinstance(return_types, list):
# Ensure raw_values is a tuple for consistency
if not isinstance(raw_values, list):
raw_values = (raw_values,)

# Convert the tuple to the dataclass instance using the utility function
converted_values = tuple(
tuple_to_dataclass(return_type, value) for return_type, value in zip(return_types, raw_values)
)

return converted_values

# cover case of single return value
converted_value = tuple_to_dataclass(return_types, raw_values)
return converted_value


class ExampleFlipFlopContractFunction(ContractFunction):
"""ContractFunction for the flipFlop method."""

Expand Down Expand Up @@ -162,7 +73,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleMixStructsAndPrimitivesContractFunction(ContractFunction):
Expand Down Expand Up @@ -198,7 +109,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleNamedSingleStructContractFunction(ContractFunction):
Expand All @@ -225,7 +136,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return cast(SimpleStruct, rename_returned_types(return_types, raw_values))
return cast(SimpleStruct, rename_returned_types(structs, return_types, raw_values))


class ExampleNamedTwoMixedStructsContractFunction(ContractFunction):
Expand Down Expand Up @@ -258,7 +169,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleSingleNestedStructContractFunction(ContractFunction):
Expand All @@ -285,7 +196,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return cast(NestedStruct, rename_returned_types(return_types, raw_values))
return cast(NestedStruct, rename_returned_types(structs, return_types, raw_values))


class ExampleSingleSimpleStructContractFunction(ContractFunction):
Expand All @@ -312,7 +223,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return cast(SimpleStruct, rename_returned_types(return_types, raw_values))
return cast(SimpleStruct, rename_returned_types(structs, return_types, raw_values))


class ExampleTwoMixedStructsContractFunction(ContractFunction):
Expand Down Expand Up @@ -345,7 +256,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleTwoSimpleStructsContractFunction(ContractFunction):
Expand Down Expand Up @@ -378,7 +289,7 @@ def call(
# Call the function

raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
return self.ReturnValues(*rename_returned_types(return_types, raw_values))
return self.ReturnValues(*rename_returned_types(structs, return_types, raw_values))


class ExampleContractFunctions(ContractFunctions):
Expand Down
117 changes: 117 additions & 0 deletions example/types/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""A web3.py Contract class for the {{contract_name}} contract.
DO NOT EDIT. This file was generated by pypechain. See documentation at
https://github.com/delvtech/pypechain"""

# contracts have PascalCase names
# pylint: disable=invalid-name

# contracts control how many attributes and arguments we have in generated code
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments

# we don't need else statement if the other conditionals all have return,
# but it's easier to generate
# pylint: disable=no-else-return

# This file is bound to get very long depending on contract sizes.
# pylint: disable=too-many-lines

# methods are overriden with specific arguments instead of generic *args, **kwargs
# pylint: disable=arguments-differ

from __future__ import annotations

from dataclasses import fields, is_dataclass
from typing import Any, Tuple, TypeVar, cast

T = TypeVar("T")


def tuple_to_dataclass(cls: type[T], structs: dict[str, Any], tuple_data: Any | Tuple[Any, ...]) -> T:
"""
Converts a tuple (including nested tuples) to a dataclass instance. If cls is not a dataclass,
then the data will just be passed through this function.
Arguments
---------
cls: type[T]
The dataclass type to which the tuple data is to be converted.
tuple_data: Any | Tuple[Any, ...]
A tuple (or nested tuple) of values to convert into a dataclass instance.
Returns
-------
T
Either an instance of cls populated with data from tuple_data or tuple_data itself.
"""
if not is_dataclass(cls):
return cast(T, tuple_data)

field_types = {field.name: field.type for field in fields(cls)}
field_values = {}

for (field_name, field_type), value in zip(field_types.items(), tuple_data):
field_type = structs.get(field_type, field_type)
if is_dataclass(field_type):
# Recursively convert nested tuples to nested dataclasses
field_values[field_name] = tuple_to_dataclass(field_type, structs, value)
elif isinstance(value, tuple) and not getattr(field_type, "_name", None) == "Tuple":
# If it's a tuple and the field is not intended to be a tuple, assume it's a nested dataclass
field_values[field_name] = tuple_to_dataclass(field_type, structs, value)
else:
# Otherwise, set the primitive value directly
field_values[field_name] = value

return cls(**field_values)


def dataclass_to_tuple(instance: Any) -> Any:
"""Convert a dataclass instance to a tuple, handling nested dataclasses.
If the input is not a dataclass, return the original value.
"""
if not is_dataclass(instance):
return instance

def convert_value(value: Any) -> Any:
"""Convert nested dataclasses to tuples recursively, or return the original value."""
if is_dataclass(value):
return dataclass_to_tuple(value)
return value

return tuple(convert_value(getattr(instance, field.name)) for field in fields(instance))


def rename_returned_types(
structs: dict[str, Any], return_types: list[Any] | Any, raw_values: list[str | int | tuple] | str | int | tuple
) -> Any:
"""Convert structs in the return value to known dataclasses.
Parameters
----------
return_types : list[str] | str
The type or list of types returned from a contract.
raw_values : list[str | int | tuple] | str | int | tuple
The actual returned values from the contract.
Returns
-------
Any
_description_
"""
# cover case of multiple return values
if isinstance(return_types, list):
# Ensure raw_values is a tuple for consistency
if not isinstance(raw_values, list):
raw_values = (raw_values,)

# Convert the tuple to the dataclass instance using the utility function
converted_values = tuple(
tuple_to_dataclass(return_type, structs, value) for return_type, value in zip(return_types, raw_values)
)

return converted_values

# cover case of single return value
converted_value = tuple_to_dataclass(return_types, structs, raw_values)
return converted_value
9 changes: 8 additions & 1 deletion pypechain/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import shutil
import sys
from pathlib import Path
from shutil import copy2
from typing import NamedTuple, Sequence

from web3.exceptions import NoABIFunctionsFound
Expand Down Expand Up @@ -58,9 +59,15 @@ def main(argv: Sequence[str] | None = None) -> None:
print(f"Error creating types for {json_file}")
raise err

# Finally, render the __init__.py file
# Render the __init__.py file
render_init_file(output_dir, file_names, line_length)

# Copy utilities.py to the output_dir
# Get the path to `utilities.py` (assuming it's in the same directory as your script)
utilities_path = Path(__file__).parent / "templates/utilities.py"
# Copy the file to the output directory
copy2(utilities_path, output_dir)


def gather_json_files(directory: str) -> list:
"""Gathers all JSON files in the specified directory and its subdirectories."""
Expand Down
2 changes: 1 addition & 1 deletion pypechain/render/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class GetFunctionDatasReturnValue(NamedTuple):


def get_function_datas(abi: ABI) -> GetFunctionDatasReturnValue:
"""TODO fill me in
"""Gets the information needed for the generated Contract file.
Arguments
---------
Expand Down
Loading