Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multiple return values with types #66

Merged
merged 27 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test fixture for deploying local anvil chain."""
from __future__ import annotations

import os
import subprocess
import time
from typing import Iterator
Expand All @@ -13,6 +14,34 @@

# pylint: disable=redefined-outer-name

# IMPORTANT NOTE!!!!!
# If you end up using this debugging method, this will catch exceptions before teardown of fixtures

# Use this in conjunction with the following launch.json configuration:
# {
# "name": "Debug Current Test",
# "type": "python",
# "request": "launch",
# "module": "pytest",
# "args": ["${file}", "-vs"],
# "console": "integratedTerminal",
# "justMyCode": true,
# "env": {
# "_PYTEST_RAISE": "1"
# },
# },
if os.getenv("_PYTEST_RAISE", "0") != "0":

@pytest.hookimpl(tryfirst=True)
def pytest_exception_interact(call):
"""Allows you to set breakpoints in pytest."""
raise call.excinfo.value

@pytest.hookimpl(tryfirst=True)
def pytest_internalerror(excinfo):
"""Allows you to set breakpoints in pytest."""
raise excinfo.value


@pytest.fixture(scope="session")
def local_chain() -> Iterator[str]:
Expand Down
41 changes: 39 additions & 2 deletions pypechain/render/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,19 @@ def render_contract_file(contract_name: str, abi_file_path: Path) -> str:
functions=function_datas,
)

# if any function has overloading
has_overloading = any(function_data["has_overloading"] for function_data in function_datas.values())
has_multiple_return_values = any(
function_data["has_multiple_return_values"] for function_data in function_datas.values()
)

# Render the template
return templates.base_template.render(
contract_name=contract_name,
structs_used=structs_used,
structs_for_abi=structs_for_abi,
has_overloading=has_overloading,
has_multiple_return_values=has_multiple_return_values,
has_bytecode=has_bytecode,
has_events=has_events,
functions_block=functions_block,
Expand Down Expand Up @@ -122,11 +129,36 @@ def get_has_multiple_return_signatures(signature_datas: list[SignatureData]) ->
if first_output_types is None:
first_output_types = signature_data["output_types"]
else:
lists_equal = all(item[0] == item[1] for item in zip(first_output_types, signature_data["output_types"]))
lists_equal = all(
output_types_to_compare[0] == output_types_to_compare[1]
for output_types_to_compare in zip(first_output_types, signature_data["output_types"])
)
if not lists_equal:
break

return not lists_equal


def get_has_multiple_return_values(signature_datas: list[SignatureData]) -> bool:
"""If there are multiple return values for a smart contract function, we'll need to overload
the call() method. This method compares the output types of all the values of a method.

Parameters
----------
signature_datas : list[SignatureData]
a list of SignatureData's to compare.

Returns
-------
bool
If there are multiple return signatures or not.
"""
for signature_data in signature_datas:
if len(signature_data["outputs"]) > 1:
return True
return False


class ContractTemplates(NamedTuple):
"""Templates for the generated contract file."""

Expand Down Expand Up @@ -173,7 +205,7 @@ def get_function_datas(abi: ABI) -> GetFunctionDatasReturnValue:
constructor_data: SignatureData | None = None
for abi_function in get_abi_items(abi):
if is_abi_function(abi_function):
# hanndle constructor
# handle constructor
if is_abi_constructor(abi_function):
constructor_data = {
"input_names_and_types": get_input_names_and_types(abi_function),
Expand All @@ -200,16 +232,21 @@ def get_function_datas(abi: ABI) -> GetFunctionDatasReturnValue:
"signature_datas": [signature_data],
"has_overloading": False,
"has_multiple_return_signatures": False,
"has_multiple_return_values": False,
}
if not function_datas.get(name):
function_datas[name] = function_data
function_datas[name]["has_multiple_return_values"] = get_has_multiple_return_values(
[signature_data]
)
else:
signature_datas = function_datas[name]["signature_datas"]
signature_datas.append(signature_data)
function_datas[name]["has_overloading"] = len(signature_datas) > 1
function_datas[name]["has_multiple_return_signatures"] = get_has_multiple_return_signatures(
signature_datas
)
function_datas[name]["has_multiple_return_values"] = get_has_multiple_return_values(signature_datas)
return GetFunctionDatasReturnValue(function_datas, constructor_data)


Expand Down
36 changes: 33 additions & 3 deletions pypechain/templates/contract.py/base.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ https://github.com/delvtech/pypechain"""
from __future__ import annotations

from dataclasses import fields, is_dataclass
from typing import Any, Tuple, Type, TypeVar, cast
{% if has_events %}from typing import Iterable, Sequence{% endif %}
from typing import Any, {% if has_multiple_return_values %}NamedTuple, {% endif %}Tuple, Type, TypeVar, cast{% if has_overloading %}, overload{% endif %}

from eth_typing import ChecksumAddress{% if has_bytecode %}, HexStr{% endif %}
{% if has_events %}from eth_utils.decorators import combomethod{% endif %}
from hexbytes import HexBytes
{% if has_overloading %}from multimethod import multimethod {% endif %}
from typing_extensions import Self
from web3 import Web3
from web3.contract.contract import Contract, ContractFunction, ContractFunctions
Expand Down Expand Up @@ -81,6 +79,38 @@ def tuple_to_dataclass(cls: type[T], tuple_data: Any | Tuple[Any, ...]) -> T:

return cls(**field_values)

def rename_returned_types(return_types, raw_values) -> Any:
"""_summary_

Parameters
----------
return_types : _type_
_description_
raw_values : _type_
_description_

Returns
-------
tuple
_description_
"""
# cover case of multiple return values
if isinstance(return_types, list):
# Ensure raw_values is a tuple for consistency
if not isinstance(raw_values, list):
raw_values = (raw_values,)

# Convert the tuple to the dataclass instance using the utility function
converted_values = tuple(
tuple_to_dataclass(return_type, value) for return_type, value in zip(return_types, raw_values)
)

return converted_values

# cover case of single return value
converted_value = tuple_to_dataclass(return_types, raw_values)
return converted_value

{{functions_block}}

{% if has_events %}{{ events_block }}{% endif %}
Expand Down
166 changes: 113 additions & 53 deletions pypechain/templates/contract.py/functions.py.jinja2
Original file line number Diff line number Diff line change
@@ -1,54 +1,113 @@
{# loop over all functions and create types for each #}
{%- for name, function_data in functions.items() -%}
{# check if the function is overloaded#}
{% if function_data.has_overloading %}
{# go through signatures, create a class per signature #}
{% for i in range(function_data.signature_datas|length) %}
class {{contract_name}}{{function_data.capitalized_name}}ContractFunction{{i}}(ContractFunction):
"""ContractFunction for the {{function_data.name}} method."""

def __call__(self{% if function_data.signature_datas[i].input_names_and_types %}, {{function_data.signature_datas[i].input_names_and_types|join(', ')}}{% endif %}) -> {{contract_name}}{{function_data.capitalized_name}}ContractFunction:{%- if function_data.signature_datas|length > 1 %} #type: ignore{% endif %}
super().__call__({% if function_data.signature_datas[i].input_names%}{{function_data.signature_datas[i].input_names|join(', ')}}{% endif %})
return cast({{contract_name}}{{function_data.capitalized_name}}ContractFunction, self)

{% set output_names = function_data.signature_datas[i].outputs %}
{% set output_types = function_data.signature_datas[i].output_types %}
{% set return_type = output_types[0] if output_types|length == 1 else "ReturnValues" if output_types|length > 1 else "None" %}

{% if output_types|length > 1%}
class ReturnValues(NamedTuple):
"""The return named tuple for {{function_data.capitalized_name}}."""
{% for j in range(output_types | length) -%}
{{output_names[j]}}: {{output_types[j]}}
{% endfor %}
{% endif %}

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = "latest",
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None,
) -> {{return_type}}:
"""returns {{return_type}}."""
# Define the expected return types from the smart contract call
{% if output_types|length == 1 %}
return_types = {{return_type}}
{% elif output_types|length > 1%}
return_types = [{{output_types|join(', ')}}]
{% endif %}
# Call the function
raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
{% if output_types|length == 1 %}
return cast({{output_types[0]}}, rename_returned_types(return_types, raw_values))
{% elif output_types|length > 1%}
return self.{{return_type}}(*rename_returned_types(return_types, raw_values))
{% endif %}
{% endfor %}
class {{contract_name}}{{function_data.capitalized_name}}ContractFunction(ContractFunction):
"""ContractFunction for the {{function_data.name}} method."""
# super() call methods are generic, while our version adds values & types
# pylint: disable=arguments-differ
{%- if function_data.signature_datas|length > 1-%}
# disable this warning when there is overloading
# pylint: disable=arguments-differ# disable this warning when there is overloading
# pylint: disable=function-redefined
{%- endif -%}
{% for signature_data in function_data.signature_datas %}
{% if function_data.has_overloading %} @multimethod{% endif %}
def __call__(self{% if signature_data.input_names_and_types %}, {{signature_data.input_names_and_types|join(', ')}}{% endif %}) -> "{{contract_name}}{{function_data.capitalized_name}}ContractFunction":{%- if function_data.signature_datas|length > 1 %} #type: ignore{% endif %}

{# go through signatures, create a class per signature #}
{% for i in range(function_data.signature_datas|length) %}
@overload
def __call__(self{% if function_data.signature_datas[i].input_names_and_types %}, {{function_data.signature_datas[i].input_names_and_types|join(', ')}}{% endif %}) -> {{contract_name}}{{function_data.capitalized_name}}ContractFunction{{i}}: # type: ignore
...
{% endfor %}

def __call__(self, *args) -> {{contract_name}}{{function_data.capitalized_name}}ContractFunction: # type: ignore
clone = super().__call__(*args)
self.kwargs = clone.kwargs
self.args = clone.args
return self # type: ignore
{% else %} {# no overloading #}
class {{contract_name}}{{function_data.capitalized_name}}ContractFunction(ContractFunction):
"""ContractFunction for the {{function_data.name}} method."""

{% set signature_data = function_data.signature_datas[0] %}
{% set output_names = signature_data.outputs %}
{% set output_types = signature_data.output_types %}
{% set return_type = output_types[0] if output_types|length == 1 else "ReturnValues" if output_types|length > 1 else "None" %}
{% if output_types|length > 1%}
class ReturnValues(NamedTuple):
"""The return named tuple for {{function_data.capitalized_name}}."""
{% for j in range(output_types | length) -%}
{{output_names[j]}}: {{output_types[j]}}
{% endfor %}
{% endif %}

def __call__(self{% if signature_data.input_names_and_types %}, {{signature_data.input_names_and_types|join(', ')}}{% endif %}) -> {{contract_name}}{{function_data.capitalized_name}}ContractFunction:{%- if signature_datas|length > 1 %} #type: ignore{% endif %}
clone = super().__call__({{signature_data.input_names|join(', ')}})
self.kwargs = clone.kwargs
self.args = clone.args
return self
{% endfor %}
{% set output_types = function_data.signature_datas[0].output_types %}

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = 'latest',
block_identifier: BlockIdentifier = "latest",
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None){% if function_data.has_multiple_return_signatures %} -> Any{% elif output_types|length == 1 %} -> {{output_types[0]}}{% elif output_types|length > 1%} -> tuple[{{output_types|join(', ')}}]{% else %} -> None{% endif %}:
{% if output_types|length == 1 %}"""returns {{output_types[0]}}"""{% elif output_types|length > 1%}"""returns ({{output_types|join(', ')}})"""{% else %}"""No return value"""{% endif %}
{% if output_types|length > 0 %}raw_values = {% endif %}super().call(transaction, block_identifier, state_override, ccip_read_enabled)
# Define the expected return types from the smart contract call
{% if output_types|length == 1 %}return_types = {{output_types[0]}}{% elif output_types|length > 1%}return_types = [{{output_types|join(', ')}}]{% endif %}
{% if output_types|length == 1 %}
return cast({{output_types[0]}}, self._call(return_types, raw_values))
{% elif output_types|length > 1 %}
return cast(tuple[{{output_types|join(', ')}}], self._call(return_types, raw_values))
{% endif %}
def _call(self, return_types, raw_values):
# cover case of multiple return values
if isinstance(return_types, list):
# Ensure raw_values is a tuple for consistency
if not isinstance(raw_values, list):
raw_values = (raw_values,)

# Convert the tuple to the dataclass instance using the utility function
converted_values = tuple(
(tuple_to_dataclass(return_type, value) for return_type, value in zip(return_types, raw_values))
)
ccip_read_enabled: bool | None = None,
) -> {{return_type}}:
"""returns {{return_type}}."""
# Define the expected return types from the smart contract call
{% if output_types|length == 1 %}
return_types = {{return_type}}
{% elif output_types|length > 1%}
return_types = [{{output_types|join(', ')}}]
{% endif %}
# Call the function
raw_values = super().call(transaction, block_identifier, state_override, ccip_read_enabled)
{% if output_types|length == 1 %}
return cast({{output_types[0]}}, rename_returned_types(return_types, raw_values))
{% elif output_types|length > 1%}
return self.{{return_type}}(*rename_returned_types(return_types, raw_values))
{% endif %}
{% endif %} {# if function_data.has_overloading #}

return converted_values

# cover case of single return value
converted_value = tuple_to_dataclass(return_types, raw_values)
return converted_value
{% endfor %}

class {{contract_name}}ContractFunctions(ContractFunctions):
Expand All @@ -57,20 +116,21 @@ class {{contract_name}}ContractFunctions(ContractFunctions):
{{function.name}}: {{contract_name}}{{function.capitalized_name}}ContractFunction
{% endfor %}
def __init__(
self,
abi: ABI,
w3: "Web3",
address: ChecksumAddress | None = None,
decode_tuples: bool | None = False,
) -> None:
super().__init__(abi, w3, address, decode_tuples)
{% for function in functions.values() -%}
self.{{function.name}} = {{contract_name}}{{function.capitalized_name}}ContractFunction.factory(
"{{function.name}}",
w3=w3,
contract_abi=abi,
address=address,
decode_tuples=decode_tuples,
function_identifier="{{function.name}}",
)
{% endfor %}
self,
abi: ABI,
w3: "Web3",
address: ChecksumAddress | None = None,
decode_tuples: bool | None = False,
) -> None:
super().__init__(abi, w3, address, decode_tuples)
{% for function in functions.values() -%}
self.{{function.name}} = {{contract_name}}{{function.capitalized_name}}ContractFunction.factory(
"{{function.name}}",
w3=w3,
contract_abi=abi,
address=address,
decode_tuples=decode_tuples,
function_identifier="{{function.name}}",
)
{% endfor %}

1 change: 1 addition & 0 deletions pypechain/test/overloading/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## Generate ABIs
Install [`solc`](https://docs.soliditylang.org/en/latest/installing-solidity.html).

from this directory run:

Expand Down
Loading