Skip to content

Commit

Permalink
- more tuple coverage
Browse files Browse the repository at this point in the history
- fix failing test for empty box keys
- rerun tests
- simplify arc4 assignment by forwarding on lvalue target assignments. Rename new parameter to better indicate it's purpose
- update arc4 example to remove stale comments
- reduce arc4 .copy() calls when being passed to intrinsic functions
  • Loading branch information
daniel-makerx authored and achidlow committed Jun 25, 2024
1 parent 64776c6 commit 8faeac7
Show file tree
Hide file tree
Showing 58 changed files with 2,797 additions and 1,412 deletions.
4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
arc4_types/Arc4BoolType 412 76 336 76 0
arc4_types/Arc4DynamicBytes 389 190 199 190 0
arc4_types/Arc4DynamicStringArray 285 112 173 112 0
arc4_types/Arc4MutableParams 418 233 185 230 3
arc4_types/Arc4MutableParams 511 303 208 295 8
arc4_types/Arc4Mutation 3260 1464 1796 1464 0
arc4_types/Arc4NumericTypes 752 201 551 201 0
arc4_types/Arc4RefTypes 94 47 47 47 0
Expand Down Expand Up @@ -93,7 +93,7 @@
too_many_permutations 112 106 6 106 0
transaction/Transaction 996 932 64 932 0
tuple_support/TupleComparisons 136 75 61 75 0
tuple_support/TupleSupport 401 292 109 292 0
tuple_support/TupleSupport 434 299 135 299 0
typed_abi_call/Greeter 2861 2446 415 2446 0
typed_abi_call/Logger 896 762 134 762 0
unary/Unary 136 96 40 96 0
Expand Down
30 changes: 24 additions & 6 deletions src/puya/awst_build/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def type_to_pytype(
*,
source_location: SourceLocation | mypy.nodes.Context,
in_type_args: bool = False,
in_func_sig: bool = False,
) -> pytypes.PyType:
loc = self._maybe_convert_location(source_location)
proper_type_or_alias: mypy.types.ProperType | mypy.types.TypeAliasType
Expand All @@ -198,7 +199,10 @@ def type_to_pytype(
else:
proper_type_or_alias = mypy.types.get_proper_type(mypy_type)
recurse = functools.partial(
self.type_to_pytype, source_location=loc, in_type_args=in_type_args
self.type_to_pytype,
source_location=loc,
in_type_args=in_type_args,
in_func_sig=in_func_sig,
)
match proper_type_or_alias:
case mypy.types.TypeAliasType(alias=alias, args=args):
Expand All @@ -218,9 +222,9 @@ def type_to_pytype(
raise InternalError(
f"mypy tuple type as instance had unrecognised args: {args}", loc
) from None
return pytypes.VariadicTupleType(
items=recurse(arg)
) # TODO(frist): only in function args
if not in_func_sig:
raise CodeError("variadic tuples are not supported", loc)
return pytypes.VariadicTupleType(items=recurse(arg))
case mypy.types.Instance(args=args) as inst:
fullname = inst.type.fullname
result = self._pytypes.get(fullname)
Expand Down Expand Up @@ -291,7 +295,12 @@ def type_to_pytype(
func_like.arg_types, func_like.arg_names, func_like.arg_kinds, strict=True
):
try:
pt = recurse(at)
pt = self.type_to_pytype(
at,
source_location=loc,
in_type_args=in_type_args,
in_func_sig=True,
)
except TypeUnionError as union:
pts = union.types
else:
Expand All @@ -301,7 +310,16 @@ def type_to_pytype(
logger.debug(
"None contained in bound args for function reference", location=loc
)
bound_args = [recurse(ba) for ba in func_like.bound_args if ba is not None]
bound_args = [
self.type_to_pytype(
ba,
source_location=loc,
in_type_args=in_type_args,
in_func_sig=True,
)
for ba in func_like.bound_args
if ba is not None
]
if func_like.definition is not None:
name = func_like.definition.fullname
else:
Expand Down
26 changes: 24 additions & 2 deletions src/puya/awst_build/validation/arc4_copy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Iterator

import attrs

from puya import log
from puya.awst import (
nodes as awst_nodes,
Expand All @@ -21,9 +23,15 @@ def __init__(self) -> None:
super().__init__()
self._for_items: awst_nodes.Lvalue | None = None

# for nodes that can't modify the input don't need to check for copies unless an assignment
# expression is being used
def visit_submit_inner_transaction(self, call: awst_nodes.SubmitInnerTransaction) -> None:
# values passed to an inner transaction do not need to be copied
pass
if _HasAssignmentVisitor.check(call):
super().visit_submit_inner_transaction(call)

def visit_intrinsic_call(self, call: awst_nodes.IntrinsicCall) -> None:
if _HasAssignmentVisitor.check(call):
super().visit_intrinsic_call(call)

def visit_assignment_statement(self, statement: awst_nodes.AssignmentStatement) -> None:
_check_assignment(statement.target, statement.value)
Expand Down Expand Up @@ -138,3 +146,17 @@ def _is_arc4_mutable(wtype: wtypes.WType) -> bool:
case wtypes.WTuple(types=types):
return any(_is_arc4_mutable(t) for t in types)
return False


@attrs.define
class _HasAssignmentVisitor(AWSTTraverser):
has_assignment: bool = False

@classmethod
def check(cls, expr: awst_nodes.Expression) -> bool:
visitor = _HasAssignmentVisitor()
expr.accept(visitor)
return visitor.has_assignment

def visit_assignment_expression(self, _: awst_nodes.AssignmentExpression) -> None:
self.has_assignment = True
25 changes: 9 additions & 16 deletions src/puya/ir/builder/arc4.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def handle_arc4_assign(
value: ValueProvider,
source_location: SourceLocation,
*,
is_recursive_assign: bool = False,
is_mutation: bool = False,
) -> Value:
result: Value
match target:
Expand All @@ -298,7 +298,7 @@ def handle_arc4_assign(
source_location=source_location,
),
source_location=source_location,
is_recursive_assign=True,
is_mutation=True,
)
case awst_nodes.FieldExpression(
base=awst_nodes.Expression(wtype=wtypes.ARC4Struct() as struct_wtype) as base_expr,
Expand All @@ -316,7 +316,7 @@ def handle_arc4_assign(
source_location=source_location,
),
source_location=source_location,
is_recursive_assign=True,
is_mutation=True,
)
case awst_nodes.TupleItemExpression(
base=awst_nodes.Expression(wtype=wtypes.ARC4Tuple() as tuple_wtype) as base_expr,
Expand All @@ -334,8 +334,11 @@ def handle_arc4_assign(
source_location=source_location,
),
source_location=source_location,
is_recursive_assign=True,
is_mutation=True,
)
# this function is sometimes invoked outside an assignment expr/stmt, which
# is how a non l-value expression can be possible
# TODO: refactor this so that this special case is handled where it originates
case awst_nodes.TupleItemExpression(
base=awst_nodes.VarExpression(wtype=wtypes.WTuple(types=items_types)) as base_expr,
index=index_value,
Expand All @@ -347,25 +350,15 @@ def handle_arc4_assign(
source_location=source_location,
)
return result
case (
awst_nodes.VarExpression()
| awst_nodes.FieldExpression()
| awst_nodes.IndexExpression()
| awst_nodes.TupleExpression()
| awst_nodes.AppStateExpression()
| awst_nodes.AppAccountStateExpression()
| awst_nodes.BoxValueExpression()
) as target:
case _:
(result,) = handle_assignment(
context,
target,
value=value,
assignment_location=source_location,
is_recursive_assign=is_recursive_assign,
is_mutation=is_mutation,
)
return result
case _:
raise CodeError(f"Invalid assignment target {target}", source_location)


def concat_values(
Expand Down
31 changes: 26 additions & 5 deletions src/puya/ir/builder/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,40 @@ def handle_assignment_expr(


def handle_assignment(
context: IRFunctionBuildContext,
target: awst_nodes.Expression,
value: ValueProvider,
assignment_location: SourceLocation,
*,
is_mutation: bool = False,
) -> Sequence[Value]:
# separating out the target LValue check allows the _handle_assignment to statically assert
# all LValue types are covered
if not isinstance(target, awst_nodes.Lvalue): # type: ignore[arg-type]
raise CodeError("expression is not valid as an assignment target", target.source_location)
return _handle_assignment(
context,
typing.cast(awst_nodes.Lvalue, target),
value,
assignment_location,
is_mutation=is_mutation,
)


def _handle_assignment(
context: IRFunctionBuildContext,
target: awst_nodes.Lvalue,
value: ValueProvider,
assignment_location: SourceLocation,
*,
is_recursive_assign: bool = False, # TODO: why is this needed
is_mutation: bool,
) -> Sequence[Value]:
match target:
case awst_nodes.VarExpression(name=var_name, source_location=var_loc):
if (
var_name in (p.name for p in context.subroutine.parameters if p.implicit_return)
and not is_recursive_assign
):
is_implicit_return = var_name in (
p.name for p in context.subroutine.parameters if p.implicit_return
)
if is_implicit_return and not is_mutation:
raise CodeError(
f"Cannot reassign mutable parameter {var_name!r}"
" which is being passed by reference",
Expand Down
11 changes: 10 additions & 1 deletion src/puya/mir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import attrs

from puya import log
from puya.errors import InternalError
from puya.errors import CodeError, InternalError
from puya.ir import models as ir
from puya.ir.types_ import AVMBytesEncoding
from puya.ir.visitor import IRVisitor
Expand Down Expand Up @@ -132,6 +132,15 @@ def visit_method_constant(self, const: ir.MethodConstant) -> None:

def visit_intrinsic_op(self, intrinsic: ir.Intrinsic) -> None:
discard_results = intrinsic is self.active_op

if intrinsic.op.code.startswith("box_"):
try:
box_key = intrinsic.args[0]
except ValueError:
raise InternalError("box key arg not found", intrinsic.source_location) from None
if isinstance(box_key, ir.BytesConstant) and not box_key.value:
raise CodeError("AVM does not support empty box keys", intrinsic.source_location)

for arg in intrinsic.args:
arg.accept(self)
produces = len(intrinsic.op_signature.returns)
Expand Down
10 changes: 3 additions & 7 deletions test_cases/arc4_types/mutable_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,8 @@ def mutating_copies(self) -> None:
self.other_routine_2(my_array_copy_2)
assert my_array_copy_2[0] == UInt8(10), "my_array_copy_2 should have mutated value"

# This is commented out because mutable params do not work with tuples
# self.other_routine_3((my_array, my_array_copy_2, my_array_copy_2))
#
# assert my_array[0] == my_array_copy_2[0] == my_array_copy[0] == UInt8(99), \
# "All arrays should be mutated"
# tuples of mutable types only work with a .copy()
self.other_routine_3((my_array.copy(), my_array_copy_2.copy(), my_array_copy_2.copy()))

# Nested array items should still require a copy
nested = StructWithArray(test_array=my_array.copy())
Expand All @@ -103,11 +100,10 @@ def other_routine_2(self, array: TestArray) -> TestArray:

@subroutine
def other_routine_3(self, arrays: tuple[TestArray, TestArray, TestArray]) -> None:
# This doesn't mutate the params
# this modifies the local copy
for array in arrays:
array[0] = UInt8(99)

# This also doesn't work
arrays[0][0] = UInt8(99)
arrays[1][0] = UInt8(99)
arrays[2][0] = UInt8(99)
Expand Down
Loading

0 comments on commit 8faeac7

Please sign in to comment.