Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add return types to call() methods. #32

Merged
merged 6 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pypechain/render/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
get_input_names,
get_input_names_and_values,
get_output_names,
get_output_names_and_values,
get_output_types,
is_abi_constructor,
is_abi_function,
load_abi_from_file,
Expand All @@ -25,6 +27,7 @@ class SignatureData(TypedDict):
input_names_and_types: list[str]
input_names: list[str]
outputs: list[str]
output_types: list[str]


class FunctionData(TypedDict):
Expand Down Expand Up @@ -139,6 +142,7 @@ def get_function_datas(abi: ABI) -> tuple[dict[str, FunctionData], SignatureData
"input_names_and_types": get_input_names_and_values(abi_function),
"input_names": get_input_names(abi_function),
"outputs": get_output_names(abi_function),
"output_types": get_output_names_and_values(abi_function),
}

# handle all other functions
Expand All @@ -148,6 +152,7 @@ def get_function_datas(abi: ABI) -> tuple[dict[str, FunctionData], SignatureData
"input_names_and_types": get_input_names_and_values(abi_function),
"input_names": get_input_names(abi_function),
"outputs": get_output_names(abi_function),
"output_types": get_output_types(abi_function),
}
function_data: FunctionData = {
# TODO: pass a typeguarded ABIFunction that has only required fields?
Expand Down
4 changes: 2 additions & 2 deletions pypechain/templates/contract.py/base.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
# pylint: disable=too-many-lines

from __future__ import annotations
from typing import cast
from typing import cast, NamedTuple

from eth_typing import ChecksumAddress{% if has_bytecode %}, HexStr{% endif %}
{% if has_bytecode %}from hexbytes import HexBytes{% endif %}
from web3.types import ABI
from web3.types import ABI, BlockIdentifier, CallOverride, TxParams
from web3.contract.contract import Contract, ContractFunction, ContractFunctions
from web3.exceptions import FallbackNotFound
{% if has_overloading %}
Expand Down
11 changes: 11 additions & 0 deletions pypechain/templates/contract.py/functions.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ class {{contract_name}}{{function_data.capitalized_name}}ContractFunction(Contra
def __call__(self{% if signature_data.input_names_and_types %}, {{signature_data.input_names_and_types|join(', ')}}{% endif %}) -> "{{contract_name}}{{function_data.capitalized_name}}ContractFunction":{%- if has_overloading %} #type: ignore{% endif %}
super().__call__({{signature_data.input_names|join(', ')}})
return self

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = 'latest',
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None){% if signature_data.output_types|length == 1 %} -> {{signature_data.output_types[0]}}{% elif signature_data.output_types|length > 1%} -> tuple[{{signature_data.output_types|join(', ')}}]{% endif %}:
{% if signature_data.output_types|length == 1 %}"""returns {{signature_data.output_types[0]}}"""{% elif signature_data.output_types|length > 1%}"""returns ({{signature_data.output_types|join(', ')}})"""{% else %}"""No return value"""{% endif %}
return super().call(transaction, block_identifier, state_override, ccip_read_enabled)


{% endfor %}
{% endfor %}
class {{contract_name}}ContractFunctions(ContractFunctions):
Expand Down
83 changes: 73 additions & 10 deletions pypechain/utilities/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Literal, NamedTuple, Sequence, TypeGuard, cast
from typing import List, Literal, NamedTuple, Sequence, TypedDict, TypeGuard, cast

from web3 import Web3
from web3.types import ABI, ABIElement, ABIEvent, ABIFunction, ABIFunctionComponents, ABIFunctionParams
Expand All @@ -16,33 +16,32 @@
from pypechain.utilities.types import solidity_to_python_type


class Input(NamedTuple):
class Input(TypedDict):
"""An input of a function or event."""

internalType: str
name: str
type: str
indexed: bool | None = None
indexed: bool | None


class Output(NamedTuple):
class Output(TypedDict):
sentilesdal marked this conversation as resolved.
Show resolved Hide resolved
"""An output of a function or event."""

internalType: str
internalType: str
name: str
type: str


class AbiItem(NamedTuple):
class AbiItem(TypedDict):
"""An item of an ABI, can be an event, function or struct."""

type: str
inputs: List[Input]
stateMutability: str | None = None
anonymous: bool | None = None
name: str | None = None
outputs: List[Output] | None = None
stateMutability: str | None
anonymous: bool | None
name: str | None
outputs: List[Output] | None


class AbiJson(NamedTuple):
Expand Down Expand Up @@ -557,6 +556,28 @@ def get_input_names_and_values(function: ABIFunction) -> list[str]:
return _get_names_and_values(function, "inputs")


def get_output_types(function: ABIFunction) -> list[str]:
"""Returns function output type strings for jinja templating.

i.e. for the solidity function signature: function doThing(address who, uint256 amount, bool
flag, bytes extraData)

the following list would be returned: ['who: str', 'amount: int', 'flag: bool', 'extraData:
bytes']

Arguments
---------
function : ABIFunction
A web3 dict of an ABI function description.

Returns
-------
list[str]
A list of function python values, i.e. ['str', 'bool']
sentilesdal marked this conversation as resolved.
Show resolved Hide resolved
"""
return _get_param_types(function, "outputs")


def get_output_names_and_values(function: ABIFunction) -> list[str]:
"""Returns function input name/type strings for jinja templating.

Expand Down Expand Up @@ -610,6 +631,48 @@ def _get_names_and_values(function: ABIFunction, parameters_type: Literal["input
return stringified_function_parameters


def _get_param_types(function: ABIFunction, parameters_type: Literal["inputs", "outputs"]) -> list[str]:
"""Returns function input or output type strings for jinja templating.

i.e. for the solidity function signature: function doThing(address who, uint256 amount, bool
flag, bytes extraData)

the following list would be returned: ['who: str', 'amount: int', 'flag: bool', 'extraData:
bytes']

Arguments
---------
function : ABIFunction
A web3 dict of an ABI function description.
parameters_type : Literal["inputs", "outputs"]
If we are looking at the inputs or outputs of a function.

Returns
-------
list[str]
A list of function parameter python types, i.e. ['str', 'bool']
"""
stringified_function_parameters: list[str] = []
inputs_or_outputs = function.get(parameters_type, [])
inputs_or_outputs = cast(list[ABIFunctionParams], inputs_or_outputs)

for param in inputs_or_outputs:
python_type = get_param_type(param)
stringified_function_parameters.append(f"{python_type}")
return stringified_function_parameters


def get_param_type(param: ABIFunctionParams):
"""Gets the associated python type, including generated dataclasses"""
internal_type = cast(str, param.get("internalType", ""))
# if we find a struct, we'll add it to the dict of StructInfo's
if is_struct(internal_type):
python_type = get_struct_name(param)
else:
python_type = solidity_to_python_type(param.get("type", "unknown"))
return python_type


def get_abi_from_json(json_abi: FoundryJson | SolcJson | ABI) -> ABI:
"""Gets the ABI from a supported json format."""
if is_foundry_json(json_abi):
Expand Down
5 changes: 4 additions & 1 deletion pypechain/utilities/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,7 @@ def apply_black_formatting(code: str, line_length: int = 80) -> str:
try:
return black.format_file_contents(code, fast=False, mode=black.Mode(line_length=line_length))
except ValueError as exc:
raise ValueError(f"cannot format with Black\n code:\n{code}") from exc
print(f"cannot format with Black\n code:\n{code}")
print(f"{exc=}")
return code
# raise ValueError(f"cannot format with Black\n code:\n{code}") from exc
sentilesdal marked this conversation as resolved.
Show resolved Hide resolved
22 changes: 22 additions & 0 deletions snapshots/expected_not_overloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ def __call__(self) -> "OverloadedBalanceOfContractFunction":
super().__call__()
return self

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = 'latest',
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None) -> int:
"""returns int"""
return super().call(transaction, block_identifier, state_override, ccip_read_enabled)



class OverloadedBalanceOfWhoContractFunction(ContractFunction):
"""ContractFunction for the balanceOfWho method."""
# super() call methods are generic, while our version adds values & types
Expand All @@ -16,6 +27,17 @@ def __call__(self, who: str) -> "OverloadedBalanceOfWhoContractFunction":
super().__call__(who)
return self

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = 'latest',
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None) -> bool:
"""returns bool"""
return super().call(transaction, block_identifier, state_override, ccip_read_enabled)




class OverloadedContractFunctions(ContractFunctions):
"""ContractFunctions for the Overloaded contract."""
Expand Down
22 changes: 22 additions & 0 deletions snapshots/expected_overloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,33 @@ def __call__(self) -> "OverloadedBalanceOfContractFunction":
super().__call__()
return self

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = 'latest',
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None) -> int:
"""returns int"""
return super().call(transaction, block_identifier, state_override, ccip_read_enabled)




def __call__(self, who: str) -> "OverloadedBalanceOfContractFunction":
super().__call__(who)
return self

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = 'latest',
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None) -> int:
"""returns int"""
return super().call(transaction, block_identifier, state_override, ccip_read_enabled)




class OverloadedContractFunctions(ContractFunctions):
"""ContractFunctions for the Overloaded contract."""
Expand Down