Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[tool]: include structs in -f interface output #4294

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/functional/codegen/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)


# TODO CMC 2024-10-13: this should probably be in tests/unit/compiler/
def test_basic_extract_interface():
code = """
# Events
Expand All @@ -22,6 +23,7 @@ def test_basic_extract_interface():
_to: address
_value: uint256


# Functions

@view
Expand All @@ -37,6 +39,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256):
assert code_pass.strip() == out.strip()


# TODO CMC 2024-10-13: this should probably be in tests/unit/compiler/
def test_basic_extract_external_interface():
code = """
@view
Expand Down Expand Up @@ -68,6 +71,7 @@ def test(_owner: address): nonpayable
assert interface.strip() == out.strip()


# TODO CMC 2024-10-13: should probably be in syntax tests
def test_basic_interface_implements(assert_compile_failed):
code = """
from ethereum.ercs import IERC20
Expand All @@ -82,6 +86,7 @@ def test() -> bool:
assert_compile_failed(lambda: compile_code(code), InterfaceViolation)


# TODO CMC 2024-10-13: should probably be in syntax tests
def test_external_interface_parsing(make_input_bundle, assert_compile_failed):
interface_code = """
@external
Expand Down Expand Up @@ -126,6 +131,7 @@ def foo() -> uint256:
compile_code(not_implemented_code, input_bundle=input_bundle)


# TODO CMC 2024-10-13: should probably be in syntax tests
def test_log_interface_event(make_input_bundle, assert_compile_failed):
interface_code = """
event Foo:
Expand Down Expand Up @@ -160,6 +166,7 @@ def bar() -> uint256:
]


# TODO CMC 2024-10-13: should probably be in syntax tests
@pytest.mark.parametrize("code,filename", VALID_IMPORT_CODE)
def test_extract_file_interface_imports(code, filename, make_input_bundle):
input_bundle = make_input_bundle({filename: ""})
Expand All @@ -177,6 +184,7 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle):
]


# TODO CMC 2024-10-13: should probably be in syntax tests
@pytest.mark.parametrize("code,exception_type", BAD_IMPORT_CODE)
def test_extract_file_interface_imports_raises(
code, exception_type, assert_compile_failed, make_input_bundle
Expand Down Expand Up @@ -726,3 +734,43 @@ def bar() -> uint256:
c = get_contract(code, input_bundle=input_bundle)

assert c.foo() == c.bar() == 1


def test_interface_with_structures():
code = """
struct MyStruct:
a: address
b: uint256

event Transfer:
sender: indexed(address)
receiver: indexed(address)
value: uint256

struct Voter:
weight: int128
voted: bool
delegate: address
vote: int128

@external
def bar():
pass

event Buy:
buyer: indexed(address)
buy_order: uint256

@external
@view
def foo(s: MyStruct) -> MyStruct:
return s
"""

out = compile_code(code, contract_path="code.vy", output_formats=["interface"])["interface"]

assert "# Structs" in out
assert "struct MyStruct:" in out
assert "b: uint256" in out
assert "struct Voter:" in out
assert "voted: bool" in out
25 changes: 18 additions & 7 deletions vyper/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,33 @@ def build_interface_output(compiler_data: CompilerData) -> str:
interface = compiler_data.annotated_vyper_module._metadata["type"].interface
out = ""

if interface.events:
out = "# Events\n\n"
if len(interface.structs) > 0:
out += "# Structs\n\n"
for struct in interface.structs.values():
out += f"struct {struct.name}:\n"
for member_name, member_type in struct.members.items():
out += f" {member_name}: {member_type}\n"
out += "\n\n"

if len(interface.events) > 0:
out += "# Events\n\n"
for event in interface.events.values():
encoded_args = "\n ".join(f"{name}: {typ}" for name, typ in event.arguments.items())
out = f"{out}event {event.name}:\n {encoded_args if event.arguments else 'pass'}\n"
out += f"event {event.name}:\n {encoded_args if event.arguments else 'pass'}\n\n\n"

if interface.functions:
out = f"{out}\n# Functions\n\n"
if len(interface.functions) > 0:
out += "# Functions\n\n"
for func in interface.functions.values():
if func.visibility == FunctionVisibility.INTERNAL or func.name == "__init__":
continue
if func.mutability != StateMutability.NONPAYABLE:
out = f"{out}@{func.mutability.value}\n"
out += f"@{func.mutability.value}\n"
args = ", ".join([f"{arg.name}: {arg.typ}" for arg in func.arguments])
return_value = f" -> {func.return_type}" if func.return_type is not None else ""
out = f"{out}@external\ndef {func.name}({args}){return_value}:\n ...\n\n"
out += f"@external\ndef {func.name}({args}){return_value}:\n ...\n\n\n"

out = out.rstrip("\n")
out += "\n"

return out

Expand Down
Loading