Skip to content

Commit

Permalink
fix: raise an error when attempting to modify immutable arrays such a…
Browse files Browse the repository at this point in the history
…s `algopy.arc4.Address`

BREAKING CHANGE: modifying an `algopy.arc4.Address` will now raise an error
  • Loading branch information
daniel-makerx committed Nov 7, 2024
1 parent e53fd59 commit 9450c7a
Show file tree
Hide file tree
Showing 23 changed files with 121 additions and 122 deletions.
4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
application/Reference 177 167 - | 92 83 -
arc4_dynamic_arrays/DynamicArray 2695 1931 - | 1733 1138 -
arc4_numeric_comparisons/UIntNOrdering 1100 908 - | 786 597 -
arc4_types/Arc4Address 85 62 - | 37 18 -
arc4_types/Arc4Address 79 18 - | 34 11 -
arc4_types/Arc4Arrays 623 376 - | 368 182 -
arc4_types/Arc4BoolEval 751 14 - | 167 8 -
arc4_types/Arc4BoolType 381 69 - | 307 46 -
Expand Down Expand Up @@ -130,4 +130,4 @@
unssa/UnSSA 432 368 - | 241 204 -
voting/VotingRoundApp 1593 1483 - | 734 649 -
with_reentrancy/WithReentrancy 255 242 - | 132 122 -
Total 69200 53576 53517 | 32843 21764 21720
Total 69194 53532 53473 | 32840 21757 21713
49 changes: 49 additions & 0 deletions src/puya/awst/validation/immutable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from puya import log
from puya.awst import nodes as awst_nodes
from puya.awst.awst_traverser import AWSTTraverser

logger = log.get_logger(__name__)


class ImmutableValidator(AWSTTraverser):
@classmethod
def validate(cls, module: awst_nodes.AWST) -> None:
validator = cls()
for module_statement in module:
module_statement.accept(validator)

def visit_assignment_expression(self, expr: awst_nodes.AssignmentExpression) -> None:
super().visit_assignment_expression(expr)
_validate_lvalue(expr.target)

def visit_assignment_statement(self, statement: awst_nodes.AssignmentStatement) -> None:
super().visit_assignment_statement(statement)
_validate_lvalue(statement.target)

def visit_array_pop(self, expr: awst_nodes.ArrayPop) -> None:
super().visit_array_pop(expr)
if expr.base.wtype.immutable:
logger.error(
"cannot modify - object is immutable",
location=expr.source_location,
)

def visit_array_extend(self, expr: awst_nodes.ArrayExtend) -> None:
super().visit_array_extend(expr)
if expr.base.wtype.immutable:
logger.error(
"cannot modify - object is immutable",
location=expr.source_location,
)


def _validate_lvalue(lvalue: awst_nodes.Expression) -> None:
if isinstance(lvalue, awst_nodes.FieldExpression | awst_nodes.IndexExpression):
if lvalue.base.wtype.immutable:
logger.error(
"expression is not valid as an assignment target - object is immutable",
location=lvalue.source_location,
)
elif isinstance(lvalue, awst_nodes.TupleExpression):
for item in lvalue.items:
_validate_lvalue(item)
2 changes: 2 additions & 0 deletions src/puya/awst/validation/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from puya.awst import nodes as awst_nodes
from puya.awst.validation.arc4_copy import ARC4CopyValidator
from puya.awst.validation.base_invoker import BaseInvokerValidator
from puya.awst.validation.immutable import ImmutableValidator
from puya.awst.validation.inner_transactions import (
InnerTransactionsValidator,
InnerTransactionUsedInALoopValidator,
Expand All @@ -20,3 +21,4 @@ def validate_awst(module: awst_nodes.AWST) -> None:
BaseInvokerValidator.validate(module)
StorageTypesValidator.validate(module)
LabelsValidator.validate(module)
ImmutableValidator.validate(module)
9 changes: 1 addition & 8 deletions src/puyapy/awst_build/eb/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
BinaryBooleanOperator,
CompileTimeConstantExpression,
Expression,
FieldExpression,
Lvalue,
SingleEvaluation,
Statement,
Expand Down Expand Up @@ -220,13 +219,7 @@ def _validate_lvalue(typ: pytypes.PyType, resolved: Expression) -> Lvalue:
raise CodeError(
"expression is not valid as an assignment target", resolved.source_location
)
if isinstance(resolved, FieldExpression):
if resolved.base.wtype.immutable:
raise CodeError(
"expression is not valid as an assignment target - object is immutable",
resolved.source_location,
)
elif isinstance(resolved, TupleExpression):
if isinstance(resolved, TupleExpression):
assert isinstance(typ, pytypes.TupleLikeType)
for item_typ, item in zip(typ.items, resolved.items, strict=True):
_validate_lvalue(item_typ, item)
Expand Down
15 changes: 8 additions & 7 deletions src/puyapy/awst_build/eb/arc4/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
ARC4Decode,
ARC4Encode,
ArrayConcat,
ArrayExtend,
AssignmentStatement,
Expression,
ExpressionStatement,
Statement,
StringConstant,
)
Expand Down Expand Up @@ -104,13 +103,15 @@ def augmented_assignment(
else:
value = expect.argument_of_type_else_dummy(rhs, self.pytype).resolve()

return ExpressionStatement(
ArrayExtend(
base=self.resolve(),
other=value,
return AssignmentStatement(
target=self.resolve_lvalue(),
value=ArrayConcat(
left=self.resolve(),
right=value,
wtype=wtypes.arc4_string_alias,
source_location=location,
)
),
source_location=location,
)

@typing.override
Expand Down
5 changes: 3 additions & 2 deletions test_cases/arc4_types/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ def approval_program(self) -> bool:
some_address = arc4.Address(SOME_ADDRESS)
assert some_address == SOME_ADDRESS

some_address[0] = arc4.Byte(123)
assert some_address != SOME_ADDRESS
address_copy = some_address

assert some_address == address_copy
return True

def clear_state_program(self) -> bool:
Expand Down
38 changes: 13 additions & 25 deletions test_cases/arc4_types/out/Arc4AddressContract.approval.mir
Original file line number Diff line number Diff line change
@@ -1,38 +1,26 @@
// Op Stack (out)
// Op Stack (out)
// test_cases.arc4_types.address.Arc4AddressContract.approval_program() -> uint64:
main_block@0:
// arc4_types/address.py:8
// address = arc4.Address(Txn.sender)
txn Sender address#0
txn Sender address#0
// arc4_types/address.py:9
// assert address == Txn.sender
txn Sender address#0,tmp%0#0
l-load-copy address#0 1 address#0,tmp%0#0,address#0 (copy)
l-load tmp%0#0 1 address#0,address#0 (copy),tmp%0#0
== address#0,tmp%1#0
assert address#0
txn Sender address#0,tmp%0#0
l-load-copy address#0 1 address#0,tmp%0#0,address#0 (copy)
l-load tmp%0#0 1 address#0,address#0 (copy),tmp%0#0
== address#0,tmp%1#0
assert address#0
// arc4_types/address.py:11
// assert address.native == Txn.sender
txn Sender address#0,tmp%3#0
l-load address#0 1 tmp%3#0,address#0
l-load tmp%3#0 1 address#0,tmp%3#0
== tmp%4#0
txn Sender address#0,tmp%3#0
l-load address#0 1 tmp%3#0,address#0
l-load tmp%3#0 1 address#0,tmp%3#0
== tmp%4#0
assert
// arc4_types/address.py:16
// some_address = arc4.Address(SOME_ADDRESS)
addr "VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA" Address(VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA)
// arc4_types/address.py:19
// some_address[0] = arc4.Byte(123)
byte 0x7b Address(VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA),0x7b
replace2 0 some_address#1
// arc4_types/address.py:20
// assert some_address != SOME_ADDRESS
addr "VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA" some_address#1,Address(VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA)
!= tmp%10#0
assert
// arc4_types/address.py:21
// arc4_types/address.py:22
// return True
int 1 1
int 1 1
return


15 changes: 1 addition & 14 deletions test_cases/arc4_types/out/Arc4AddressContract.approval.teal
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma version 10

test_cases.arc4_types.address.Arc4AddressContract.approval_program:
bytecblock base32(VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJPQ)
// arc4_types/address.py:8
// address = arc4.Address(Txn.sender)
txn Sender
Expand All @@ -15,19 +14,7 @@ test_cases.arc4_types.address.Arc4AddressContract.approval_program:
txn Sender
==
assert
// arc4_types/address.py:16
// some_address = arc4.Address(SOME_ADDRESS)
bytec_0 // addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA
// arc4_types/address.py:19
// some_address[0] = arc4.Byte(123)
pushbytes 0x7b
replace2 0
// arc4_types/address.py:20
// assert some_address != SOME_ADDRESS
bytec_0 // addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA
!=
assert
// arc4_types/address.py:21
// arc4_types/address.py:22
// return True
pushint 1 // 1
return
2 changes: 1 addition & 1 deletion test_cases/arc4_types/out/Arc4AddressContract.clear.mir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.address.Arc4AddressContract.clear_state_program() -> uint64:
main_block@0:
// arc4_types/address.py:24
// arc4_types/address.py:25
// return True
int 1 1
return
Expand Down
2 changes: 1 addition & 1 deletion test_cases/arc4_types/out/Arc4AddressContract.clear.teal
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma version 10

test_cases.arc4_types.address.Arc4AddressContract.clear_state_program:
// arc4_types/address.py:24
// arc4_types/address.py:25
// return True
pushint 1 // 1
return
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@ contract test_cases.arc4_types.address.Arc4AddressContract:
let tmp%3#0: bytes = (txn Sender)
let tmp%4#0: bool = (== address#0 tmp%3#0)
(assert tmp%4#0)
let some_address#1: bytes = ((replace2 0) addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA 0x7b)
let tmp%10#0: bool = (!= some_address#1 addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA)
(assert tmp%10#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.address.Arc4AddressContract.clear_state_program() -> bool:
block@0: // L23
block@0: // L24
return 1u
8 changes: 3 additions & 5 deletions test_cases/arc4_types/out/Arc4AddressContract.ssa.ir
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,12 @@ contract test_cases.arc4_types.address.Arc4AddressContract:
let some_address#0: bytes = addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA
let tmp%9#0: bool = (== some_address#0 addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA)
(assert tmp%9#0)
let assigned_value%0#0: bytes = 0x7b
let updated_target%0#0: bytes = (replace3 some_address#0 0u assigned_value%0#0)
let some_address#1: bytes = updated_target%0#0
let tmp%10#0: bool = (!= some_address#1 addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA)
let address_copy#0: bytes = some_address#0
let tmp%10#0: bool = (== some_address#0 address_copy#0)
(assert tmp%10#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.address.Arc4AddressContract.clear_state_program() -> bool:
block@0: // L23
block@0: // L24
return 1u
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ contract test_cases.arc4_types.address.Arc4AddressContract:
(assert tmp%8#0)
let tmp%9#0: bool = 1u
(assert tmp%9#0)
let some_address#1: bytes = ((replace2 0) addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA 0x7b)
let tmp%10#0: bool = (!= some_address#1 addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA)
let tmp%10#0: bool = 1u
(assert tmp%10#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.address.Arc4AddressContract.clear_state_program() -> bool:
block@0: // L23
block@0: // L24
return 1u
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@ contract test_cases.arc4_types.address.Arc4AddressContract:
(assert tmp%4#0)
let tmp%6#0: bool = 1u
(assert tmp%6#0) // Address length is 32 bytes
let some_address#1: bytes = ((replace2 0) addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA 0x7b)
let tmp%10#0: bool = (!= some_address#1 addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA)
(assert tmp%10#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.address.Arc4AddressContract.clear_state_program() -> bool:
block@0: // L23
block@0: // L24
return 1u
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@ contract test_cases.arc4_types.address.Arc4AddressContract:
let tmp%3#0: bytes = (txn Sender)
let tmp%4#0: bool = (== address#0 tmp%3#0)
(assert tmp%4#0)
let some_address#1: bytes = ((replace2 0) addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA 0x7b)
let tmp%10#0: bool = (!= some_address#1 addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA)
(assert tmp%10#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.address.Arc4AddressContract.clear_state_program() -> bool:
block@0: // L23
block@0: // L24
return 1u
12 changes: 6 additions & 6 deletions test_cases/arc4_types/out/module.awst
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ contract Arc4StringTypesContract
world: arc4.dynamic_array<arc4.uint8> = arc4_encode('World!', arc4.dynamic_array<arc4.uint8>)
assert(arc4_encode('Hello World!', arc4.dynamic_array<arc4.uint8>) == hello + space + world)
thing: arc4.dynamic_array<arc4.uint8> = arc4_encode('hi', arc4.dynamic_array<arc4.uint8>)
thing.extend(thing)
thing: arc4.dynamic_array<arc4.uint8> = thing + thing
assert(thing == arc4_encode('hihi', arc4.dynamic_array<arc4.uint8>))
value: arc4.dynamic_array<arc4.uint8> = arc4_encode('a', arc4.dynamic_array<arc4.uint8>) + arc4_encode('b', arc4.dynamic_array<arc4.uint8>) + arc4_encode('cd', arc4.dynamic_array<arc4.uint8>)
value.extend(arc4_encode('e', arc4.dynamic_array<arc4.uint8>))
value.extend(arc4_encode('f', arc4.dynamic_array<arc4.uint8>))
value.extend(arc4_encode('g', arc4.dynamic_array<arc4.uint8>))
value: arc4.dynamic_array<arc4.uint8> = value + arc4_encode('e', arc4.dynamic_array<arc4.uint8>)
value: arc4.dynamic_array<arc4.uint8> = value + arc4_encode('f', arc4.dynamic_array<arc4.uint8>)
value: arc4.dynamic_array<arc4.uint8> = value + arc4_encode('g', arc4.dynamic_array<arc4.uint8>)
assert(arc4_encode('abcdefg', arc4.dynamic_array<arc4.uint8>) == value)
assert(arc4_decode(arc4_encode('', arc4.dynamic_array<arc4.uint8>), string) == '')
assert(arc4_decode(arc4_encode('hello', arc4.dynamic_array<arc4.uint8>), string) == 'hello')
Expand Down Expand Up @@ -667,8 +667,8 @@ contract Arc4AddressContract
assert(reinterpret_cast<bytes>(zero_address) == reinterpret_cast<bytes>(global<ZeroAddress>()))
some_address: arc4.static_array<arc4.uint8, 32> = Address("VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA")
assert(some_address == Address("VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA"))
some_address[0u]: arc4.uint8 = 123_arc4u8
assert(some_address != Address("VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA"))
address_copy: arc4.static_array<arc4.uint8, 32> = some_address
assert(some_address == address_copy)
return true
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
#pragma version 10

test_cases.arc4_types.address.Arc4AddressContract.approval_program:
bytecblock base32(VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJPQ)
txn Sender
dupn 2
==
assert
txn Sender
==
assert
bytec_0 // addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA
pushbytes 0x7b
replace2 0
bytec_0 // addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA
!=
assert
pushint 1 // 1
return
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@ contract test_cases.arc4_types.address.Arc4AddressContract:
let tmp%3#0: bytes = (txn Sender)
let tmp%4#0: bool = (== address#0 tmp%3#0)
(assert tmp%4#0)
let some_address#1: bytes = ((replace2 0) addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA 0x7b)
let tmp%10#0: bool = (!= some_address#1 addr VCMJKWOY5P5P7SKMZFFOCEROPJCZOTIJMNIYNUCKH7LRO45JMJP6UYBIJA)
(assert tmp%10#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.address.Arc4AddressContract.clear_state_program() -> bool:
block@0: // L23
block@0: // L24
return 1u
Loading

0 comments on commit 9450c7a

Please sign in to comment.