diff --git a/pypechain/templates/contract.py/base.py.jinja2 b/pypechain/templates/contract.py/base.py.jinja2 index 8476c2f9..0cb76508 100644 --- a/pypechain/templates/contract.py/base.py.jinja2 +++ b/pypechain/templates/contract.py/base.py.jinja2 @@ -46,7 +46,7 @@ from web3._utils.filters import LogFilter {% for struct_info in structs_used %} from .{{struct_info.contract_name}}Types import {{struct_info.name}} {% endfor %} -from .utilities import dataclass_to_tuple, get_abi_input_types, rename_returned_types, tuple_to_dataclass +from .utilities import dataclass_to_tuple, get_abi_input_types, rename_returned_types, try_bytecode_hexbytes, tuple_to_dataclass structs = { diff --git a/pypechain/templates/contract.py/contract.py.jinja2 b/pypechain/templates/contract.py/contract.py.jinja2 index 7eb3f3df..92b57ae2 100644 --- a/pypechain/templates/contract.py/contract.py.jinja2 +++ b/pypechain/templates/contract.py/contract.py.jinja2 @@ -3,17 +3,10 @@ class {{contract_name}}Contract(Contract): abi: ABI = {{contract_name | lower}}_abi {%- if has_bytecode %} - bytecode: bytes + bytecode: bytes | None = try_bytecode_hexbytes({{contract_name | lower}}_bytecode, "{{contract_name | lower}}") {%- endif %} - def __init__(self, address: ChecksumAddress | None = None, bytecode: str | bytes | None = None) -> None: - if bytes is None: - self.bytecode = HexBytes({{contract_name | lower}}_bytecode) - elif isinstance(bytecode, bytes): - self.bytecode = bytecode - elif isinstance(bytecode, str): - self.bytecode = HexBytes(bytecode) - + def __init__(self, address: ChecksumAddress | None = None) -> None: try: # Initialize parent Contract class super().__init__(address=address) diff --git a/pypechain/templates/utilities.py b/pypechain/templates/utilities.py index 00b87d9d..0072e85a 100644 --- a/pypechain/templates/utilities.py +++ b/pypechain/templates/utilities.py @@ -9,6 +9,7 @@ from typing import Any, Tuple, TypeVar, cast from eth_utils.abi import collapse_if_tuple +from hexbytes import HexBytes from web3.types import ABIFunction T = TypeVar("T") @@ -134,3 +135,28 @@ def get_abi_input_types(abi: ABIFunction) -> list[str]: if "inputs" not in abi and (abi.get("type") == "fallback" or abi.get("type") == "receive"): return [] return [collapse_if_tuple(cast(dict[str, Any], arg)) for arg in abi.get("inputs", [])] + + +def try_bytecode_hexbytes(in_bytecode: Any, contract_name: str | None = None) -> HexBytes | None: + """Attempts to convert bytecode input to HexBytes. Returns None if it fails. + + Parameters + ---------- + in_bytecode : Any + The bytecode to attempt to convert to HexBytes + contract_name : str | None, optional + The name of the contract being deployed. Used for better warning printing. + + Returns + ------- + HexBytes | None + The HexBytes if it succeeds, otherwise None + """ + try: + return HexBytes(in_bytecode) + except Exception as e: # pylint: disable=broad-except + if contract_name is None: + print(f"Warning: failed to convert bytecode to HexBytes: {e}") + else: + print(f"Warning: failed to convert bytecode for {contract_name} to HexBytes: {e}") + return None diff --git a/pypechain/templates/utilities_test.py b/pypechain/templates/utilities_test.py index 13113c6d..e928ac12 100644 --- a/pypechain/templates/utilities_test.py +++ b/pypechain/templates/utilities_test.py @@ -3,7 +3,14 @@ from dataclasses import dataclass from typing import TypeVar -from pypechain.templates.utilities import dataclass_to_tuple, rename_returned_types, tuple_to_dataclass +from hexbytes import HexBytes + +from pypechain.templates.utilities import ( + dataclass_to_tuple, + rename_returned_types, + try_bytecode_hexbytes, + tuple_to_dataclass, +) T = TypeVar("T") @@ -118,3 +125,17 @@ def test_non_dataclass_passthrough(self): assert result == "not a dataclass" # Add any additional test cases here + + +class TestTryHexbytes: + def test_correct_hexbytes_return(self): + """Test that try hexbytes converts to HexBytes if possible.""" + str_bytes = "1234" + expected = HexBytes(str_bytes) + out = try_bytecode_hexbytes(expected) + assert expected == out + + def test_incorrect_hexbytes_return(self): + str_bytes = "asdf" + out = try_bytecode_hexbytes(str_bytes) + assert out is None diff --git a/pyproject.toml b/pyproject.toml index 27b9b667..456a1126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "pypechain" -version = "0.0.28" +version = "0.0.29" authors = [ { name = "Matthew Brown", email = "matt@delv.tech" }, { name = "Dylan Paiton", email = "dylan@delv.tech" },