Skip to content

Commit

Permalink
fix: fix bug with resolving super/direct base method invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Jun 25, 2024
1 parent 8a1dfb6 commit a0618cb
Show file tree
Hide file tree
Showing 83 changed files with 1,771 additions and 1,517 deletions.
1 change: 1 addition & 0 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
hello_world_arc4/HelloWorld 102 88 14 88 0
inheritance/Child 88 88 0 88 0
inheritance/GrandParent 51 51 0 51 0
inheritance/GreatGrandParent 51 51 0 51 0
inheritance/Parent 89 89 0 89 0
inner_transactions 1245 1193 52 1193 0
inner_transactions/ArrayAccess 212 195 17 195 0
Expand Down
27 changes: 3 additions & 24 deletions src/puya/awst_build/eb/arc4/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,22 @@
import typing

import attrs
import mypy.nodes
import mypy.types

from puya import arc4_util, log
from puya.awst import (
nodes as awst_nodes,
wtypes,
)
from puya.awst.nodes import DecimalConstant, Expression, Literal
from puya.awst_build import constants, pytypes
from puya.awst_build.arc4_utils import arc4_encode, get_arc4_method_config, get_func_types
from puya.awst_build import pytypes
from puya.awst_build.arc4_utils import arc4_encode
from puya.awst_build.eb.base import ExpressionBuilder
from puya.awst_build.utils import convert_literal, get_decorators_by_fullname
from puya.awst_build.utils import convert_literal
from puya.errors import CodeError

if typing.TYPE_CHECKING:
from collections.abc import Sequence

from puya.awst_build.context import ASTConversionModuleContext
from puya.parse import SourceLocation

logger = log.get_logger(__name__)
Expand Down Expand Up @@ -147,24 +144,6 @@ def method_selector(self) -> str:
return f"{self.method_name}({args}){arc4_util.pytype_to_arc4(return_type)}"


def get_arc4_signature(
context: ASTConversionModuleContext,
type_info: mypy.nodes.TypeInfo,
member_name: str,
location: SourceLocation,
) -> ARC4Signature:
dec = type_info.get_method(member_name)
if isinstance(dec, mypy.nodes.Decorator):
decorators = get_decorators_by_fullname(context, dec)
abimethod_dec = decorators.get(constants.ABIMETHOD_DECORATOR)
if abimethod_dec is not None:
func_def = dec.func
arc4_method_config = get_arc4_method_config(context, abimethod_dec, func_def)
*arg_types, return_type = get_func_types(context, func_def, location).values()
return ARC4Signature(arc4_method_config.name, arg_types, return_type)
raise CodeError(f"'{type_info.fullname}.{member_name}' is not a valid ARC4 method", location)


def get_arc4_args_and_signature(
method_sig: str,
arg_typs: Sequence[pytypes.PyType],
Expand Down
40 changes: 29 additions & 11 deletions src/puya/awst_build/eb/arc4/abi_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,20 @@
UInt64Constant,
)
from puya.awst_build import constants, pytypes
from puya.awst_build.arc4_utils import get_arc4_method_config, get_func_types
from puya.awst_build.eb.arc4._utils import (
ARC4Signature,
arc4_tuple_from_items,
expect_arc4_operand_wtype,
get_arc4_args_and_signature,
get_arc4_signature,
)
from puya.awst_build.eb.arc4.base import ARC4FromLogBuilder
from puya.awst_build.eb.base import ExpressionBuilder, IntermediateExpressionBuilder
from puya.awst_build.eb.subroutine import BaseClassSubroutineInvokerExpressionBuilder
from puya.awst_build.eb.transaction.fields import get_field_python_name
from puya.awst_build.eb.transaction.inner_params import get_field_expr
from puya.awst_build.eb.var_factory import var_expression
from puya.awst_build.utils import get_decorators_by_fullname, resolve_method_from_type_info
from puya.errors import CodeError, InternalError

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -97,21 +98,22 @@ def call(

@typing.override
def member_access(self, name: str, location: SourceLocation) -> ExpressionBuilder | Literal:
return ARC4ClientMethodExpressionBuilder(self.context, self.type_info, name, location)
func_or_dec = resolve_method_from_type_info(self.type_info, name, location)
if func_or_dec is None:
raise CodeError(f"Unknown member {name!r} of {self.type_info.fullname!r}", location)
return ARC4ClientMethodExpressionBuilder(self.context, func_or_dec, location)


class ARC4ClientMethodExpressionBuilder(IntermediateExpressionBuilder):
def __init__(
self,
context: ASTConversionModuleContext,
type_info: mypy.nodes.TypeInfo,
name: str,
context: ASTConversionModuleContext, # TODO: yeet me
node: mypy.nodes.FuncBase | mypy.nodes.Decorator,
location: SourceLocation,
):
super().__init__(location)
self.context = context
self.type_info = type_info
self.name = name
self.node = node

@typing.override
def call(
Expand Down Expand Up @@ -197,10 +199,10 @@ def _abi_call(
f"does not match provided method selector: '{method_str}'",
method.source_location,
)
case (
ARC4ClientMethodExpressionBuilder() | BaseClassSubroutineInvokerExpressionBuilder()
) as eb: # TODO: can probably use func type from arg_typs now
signature = get_arc4_signature(eb.context, eb.type_info, eb.name, location)
case ARC4ClientMethodExpressionBuilder(
context=context, node=node
) | BaseClassSubroutineInvokerExpressionBuilder(context=context, node=node):
signature = _get_arc4_signature(context, node, location)
abi_return_type = signature.return_type
num_args = len(abi_call_expr.abi_args)
num_types = len(signature.arg_types)
Expand Down Expand Up @@ -230,6 +232,22 @@ def _abi_call(
)


def _get_arc4_signature(
context: ASTConversionModuleContext,
func_or_dec: mypy.nodes.FuncBase | mypy.nodes.Decorator,
location: SourceLocation,
) -> ARC4Signature:
if isinstance(func_or_dec, mypy.nodes.Decorator):
decorators = get_decorators_by_fullname(context, func_or_dec)
abimethod_dec = decorators.get(constants.ABIMETHOD_DECORATOR)
if abimethod_dec is not None:
func_def = func_or_dec.func
arc4_method_config = get_arc4_method_config(context, abimethod_dec, func_def)
*arg_types, return_type = get_func_types(context, func_def, location).values()
return ARC4Signature(arc4_method_config.name, arg_types, return_type)
raise CodeError(f"{func_or_dec.fullname!r} is not a valid ARC4 method", location)


def _is_typed(typ: pytypes.PyType | None) -> typing.TypeGuard[pytypes.PyType]:
return typ not in (None, pytypes.NoneType)

Expand Down
41 changes: 20 additions & 21 deletions src/puya/awst_build/eb/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from puya.awst import wtypes
from puya.awst.nodes import (
AppStateExpression,
BaseClassSubroutineTarget,
BoxProxyField,
InstanceSubroutineTarget,
)
Expand All @@ -24,7 +25,7 @@
SubroutineInvokerExpressionBuilder,
)
from puya.awst_build.eb.var_factory import var_expression
from puya.awst_build.utils import qualified_class_name
from puya.awst_build.utils import qualified_class_name, resolve_method_from_type_info
from puya.errors import CodeError, InternalError
from puya.parse import SourceLocation

Expand All @@ -41,10 +42,15 @@ def __init__(
super().__init__(location)
self.context = context
self._type_info = type_info
self._cref = qualified_class_name(type_info)

def member_access(self, name: str, location: SourceLocation) -> ExpressionBuilder:
func_or_dec = resolve_method_from_type_info(self._type_info, name, location)
if func_or_dec is None:
raise CodeError(f"Unknown member {name!r} of {self._type_info.fullname!r}", location)
target = BaseClassSubroutineTarget(self._cref, name)
return BaseClassSubroutineInvokerExpressionBuilder(
context=self.context, type_info=self._type_info, name=name, location=location
context=self.context, target=target, node=func_or_dec, location=location
)


Expand All @@ -64,26 +70,19 @@ def member_access(self, name: str, location: SourceLocation) -> ExpressionBuilde
if state_decl is not None:
return _builder_for_storage_access(state_decl, location)

sym_node = self._type_info.get(name)
if sym_node is None or sym_node.node is None:
raise CodeError(f"Unknown member: {name}", location)
match sym_node.node:
# matching types taken from mypy.nodes.TypeInfo.get_method
case mypy.nodes.FuncBase() | mypy.nodes.Decorator() as func_or_dec:
func_type = func_or_dec.type
if not isinstance(func_type, mypy.types.CallableType):
raise CodeError(f"Couldn't resolve signature of {name!r}", location)
func_or_dec = resolve_method_from_type_info(self._type_info, name, location)
if func_or_dec is None:
raise CodeError(f"Unknown member {name!r} of {self._type_info.fullname!r}", location)
func_type = func_or_dec.type
if not isinstance(func_type, mypy.types.CallableType):
raise CodeError(f"Couldn't resolve signature of {name!r}", location)

return SubroutineInvokerExpressionBuilder(
context=self.context,
target=InstanceSubroutineTarget(name=name),
location=location,
func_type=func_type,
)
case _:
raise CodeError(
f"Non-storage member {name!r} has unsupported function type", location
)
return SubroutineInvokerExpressionBuilder(
context=self.context,
target=InstanceSubroutineTarget(name=name),
location=location,
func_type=func_type,
)


def _builder_for_storage_access(
Expand Down
20 changes: 6 additions & 14 deletions src/puya/awst_build/eb/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from puya.awst_build.context import ASTConversionModuleContext
from puya.awst_build.eb.base import ExpressionBuilder, IntermediateExpressionBuilder
from puya.awst_build.eb.var_factory import var_expression
from puya.awst_build.utils import qualified_class_name, require_expression_builder
from puya.awst_build.utils import require_expression_builder
from puya.errors import CodeError
from puya.parse import SourceLocation

Expand Down Expand Up @@ -78,22 +78,14 @@ class BaseClassSubroutineInvokerExpressionBuilder(SubroutineInvokerExpressionBui
def __init__(
self,
context: ASTConversionModuleContext,
type_info: mypy.nodes.TypeInfo,
name: str,
target: BaseClassSubroutineTarget,
location: SourceLocation,
node: mypy.nodes.FuncBase | mypy.nodes.Decorator,
):
self.name = name
self.type_info = type_info
cref = qualified_class_name(type_info)

func_or_dec = type_info.get_method(name)
if func_or_dec is None:
raise CodeError(f"Unknown member: {name}", location)
func_type = func_or_dec.type
self.node = node
func_type = node.type
if not isinstance(func_type, mypy.types.CallableType):
raise CodeError(f"Couldn't resolve signature of {name!r}", location)

target = BaseClassSubroutineTarget(cref, name)
raise CodeError(f"Couldn't resolve signature of {node.fullname!r}", location)
super().__init__(context, target, location, func_type)

def call(
Expand Down
42 changes: 16 additions & 26 deletions src/puya/awst_build/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
iterate_user_bases,
qualified_class_name,
require_expression_builder,
resolve_method_from_type_info,
)
from puya.errors import CodeError, InternalError, PuyaError
from puya.models import ARC4MethodConfig
Expand Down Expand Up @@ -1211,33 +1212,22 @@ def visit_super_expr(self, super_expr: mypy.nodes.SuperExpr) -> ExpressionBuilde
self._location(super_expr.call),
)
for base in iterate_user_bases(self.contract_method_info.type_info):
base_sym = base.get(super_expr.name)
if base_sym is None:
base_method = resolve_method_from_type_info(base, super_expr.name, super_loc)
if base_method is None:
continue
match base_sym.node:
case None:
raise CodeError(
f"Unable to resolve type of member {super_expr.name!r}", super_loc
)
# matching types taken from mypy.nodes.TypeInfo.get_method
case (mypy.nodes.FuncBase() | mypy.nodes.Decorator()) as base_method:
cref = qualified_class_name(base)
if not isinstance(base_method.type, mypy.types.CallableType):
# this shouldn't be hit unless there's typing.overload or weird
# decorators going on, both of which we don't allow
raise CodeError(
f"Unable to retrieve type of {cref.full_name}.{super_expr.name}",
super_loc,
)
return SubroutineInvokerExpressionBuilder(
context=self.context,
target=BaseClassSubroutineTarget(base_class=cref, name=super_expr.name),
location=super_loc,
func_type=base_method.type,
)
case _:
raise CodeError("super() is only supported for method calls", super_loc)

if not isinstance(base_method.type, mypy.types.CallableType):
# this shouldn't be hit unless there's typing.overload or weird
# decorators going on, both of which we don't allow
raise CodeError(f"Unable to retrieve type of {base_method.fullname!r}", super_loc)
super_target = BaseClassSubroutineTarget(
base_class=qualified_class_name(base), name=super_expr.name
)
return SubroutineInvokerExpressionBuilder(
context=self.context,
target=super_target,
func_type=base_method.type,
location=super_loc,
)
raise CodeError(
f"Unable to locate method {super_expr.name}"
f" in bases of {self.contract_method_info.cref.full_name}",
Expand Down
33 changes: 33 additions & 0 deletions src/puya/awst_build/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,36 @@ def eval_slice_component(
source_location=location,
)
)


def resolve_method_from_type_info(
type_info: mypy.nodes.TypeInfo, name: str, location: SourceLocation
) -> mypy.nodes.FuncBase | mypy.nodes.Decorator | None:
"""Get a function member from TypeInfo, or return None.
Differs from TypeInfo.get_method() if there are conflicting definitions of name,
one being a method and another being an attribute.
This is important for semantic compatibility.
If the found member is not a function, an exception is raised.
Also raises if the SymbolTableNode is unresolved (it shouldn't be once we can see it).
"""
member = type_info.get(name)
if member is None:
return None
match member.node:
case None:
raise InternalError(
"mypy cross reference remains unresolved:"
f" member {name!r} of {type_info.fullname!r}",
location,
)
# matching types taken from mypy.nodes.TypeInfo.get_method
case mypy.nodes.FuncBase() | mypy.nodes.Decorator() as func_or_dec:
return func_or_dec
case other_node:
logger.debug(
f"Non-function member: type={type(other_node).__name__!r}, value={other_node}",
location=location,
)
raise CodeError(f"unsupported reference to non-function member {name!r}", location)
Loading

0 comments on commit a0618cb

Please sign in to comment.