Skip to content

Commit

Permalink
feat: compare arc4.Address against Account
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Mar 22, 2024
1 parent c514753 commit 3888220
Show file tree
Hide file tree
Showing 19 changed files with 325 additions and 180 deletions.
2 changes: 1 addition & 1 deletion examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
arc4_types/Arc4MutableParams 362 222 140 220 2
arc4_types/Arc4Mutation 2803 1448 1355 1447 1
arc4_types/Arc4NumericTypes 571 8 563 8 0
arc4_types/Arc4RefTypes 47 39 8 39 0
arc4_types/Arc4RefTypes 92 47 45 47 0
arc4_types/Arc4StringTypes 304 8 296 8 0
arc4_types/Arc4StructsFromAnotherModule 67 12 55 12 0
arc4_types/Arc4StructsType 311 247 64 247 0
Expand Down
20 changes: 20 additions & 0 deletions src/puya/awst_build/eb/arc4/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,23 @@ def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> Expres
)

return var_expression(cmp_with_zero_expr)

def compare(
self, other: ExpressionBuilder | Literal, op: BuilderComparisonOp, location: SourceLocation
) -> ExpressionBuilder:
if self.wtype.alias != "address":
return super().compare(other, op=op, location=location)
match other:
case Literal(value=str(str_value), source_location=literal_loc):
rhs = get_bytes_expr(AddressConstant(value=str_value, source_location=literal_loc))
case ExpressionBuilder(value_type=wtypes.account_wtype):
rhs = get_bytes_expr(other.rvalue())
case _:
return super().compare(other, op=op, location=location)
cmp_expr = BytesComparisonExpression(
source_location=location,
lhs=get_bytes_expr(self.expr),
operator=EqualityComparison(op.value),
rhs=rhs,
)
return var_expression(cmp_expr)
6 changes: 6 additions & 0 deletions src/puyapy-stubs/arc4.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,12 @@ class Address(StaticArray[Byte, typing.Literal[32]]):
"""Return the Account representation of the address after ARC4 decoding"""
def __bool__(self) -> bool:
"""Returns `True` if not equal to the zero address"""
def __eq__(self, other: Address | puyapy.Account | str) -> bool: # type: ignore[override]
"""Address equality is determined by the address of another
`arc4.Address`, `Account` or `str`"""
def __ne__(self, other: Address | puyapy.Account | str) -> bool: # type: ignore[override]
"""Address equality is determined by the address of another
`arc4.Address`, `Account` or `str`"""

class DynamicBytes(DynamicArray[Byte]):
"""A variable sized array of bytes"""
Expand Down
80 changes: 44 additions & 36 deletions test_cases/arc4_types/out/Arc4RefTypesContract.approval.mir

Large diffs are not rendered by default.

31 changes: 16 additions & 15 deletions test_cases/arc4_types/out/Arc4RefTypesContract.approval.teal
Original file line number Diff line number Diff line change
@@ -1,34 +1,35 @@
#pragma version 10

test_cases.arc4_types.reference_types.Arc4RefTypesContract.approval_program:
// arc4_types/reference_types.py:9-11
// # When creating an address from bytes, we check the length is 32 as we don't know the
// # source of the bytes
// checked_address = arc4.Address(op.Txn.sender.bytes)
txn Sender
// arc4_types/reference_types.py:6-11
// arc4_types/reference_types.py:6-8
// # When creating an address from an account no need to check the length as we assume the
// # Account is valid
// sender_address = arc4.Address(op.Txn.sender)
txn Sender
// arc4_types/reference_types.py:9
// assert sender_address == op.Txn.sender
dupn 2
==
assert
// arc4_types/reference_types.py:10-12
// # When creating an address from bytes, we check the length is 32 as we don't know the
// # source of the bytes
// checked_address = arc4.Address(op.Txn.sender.bytes)
dupn 3
// arc4_types/reference_types.py:9-11
// # When creating an address from bytes, we check the length is 32 as we don't know the
// # source of the bytes
// checked_address = arc4.Address(op.Txn.sender.bytes)
txn Sender
dup
cover 2
dup
len
int 32
==
assert // Address length is 32 bytes
// arc4_types/reference_types.py:12-14
// arc4_types/reference_types.py:13-15
// # When using from_bytes, no validation is performed as per all implementations of
// # from_bytes
// unchecked_address = arc4.Address.from_bytes(op.Txn.sender.bytes)
txn Sender
cover 2
// arc4_types/reference_types.py:15
// arc4_types/reference_types.py:16
// assert sender_address == checked_address and checked_address == unchecked_address
==
bz main_bool_false@3
Expand All @@ -42,10 +43,10 @@ main_bool_false@3:
int 0

main_bool_merge@4:
// arc4_types/reference_types.py:15
// arc4_types/reference_types.py:16
// assert sender_address == checked_address and checked_address == unchecked_address
assert
// arc4_types/reference_types.py:17
// arc4_types/reference_types.py:19
// return True
int 1
return
4 changes: 2 additions & 2 deletions test_cases/arc4_types/out/Arc4RefTypesContract.clear.mir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@

// test_cases.arc4_types.reference_types.Arc4RefTypesContract.clear_state_program() -> uint64:
main_block@0:
int 1 // 1 True arc4_types/reference_types.py:20
return // return True arc4_types/reference_types.py:20
int 1 // 1 True arc4_types/reference_types.py:22
return // return True arc4_types/reference_types.py:22

2 changes: 1 addition & 1 deletion test_cases/arc4_types/out/Arc4RefTypesContract.clear.teal
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma version 10

test_cases.arc4_types.reference_types.Arc4RefTypesContract.clear_state_program:
// arc4_types/reference_types.py:20
// arc4_types/reference_types.py:22
// return True
int 1
return
33 changes: 18 additions & 15 deletions test_cases/arc4_types/out/Arc4RefTypesContract.destructured.ir
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,30 @@ contract test_cases.arc4_types.reference_types.Arc4RefTypesContract:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.approval_program() -> uint64:
block@0: // L5
let sender_address#0: bytes = (txn Sender)
let tmp%0#0: bytes = (txn Sender)
let tmp%1#0: uint64 = (== sender_address#0 tmp%0#0)
(assert tmp%1#0)
let checked_address#0: bytes = (txn Sender)
let tmp%1#0: uint64 = (len checked_address#0)
let tmp%2#0: uint64 = (== 32u tmp%1#0)
(assert tmp%2#0) // Address length is 32 bytes
let tmp%3#0: uint64 = (len checked_address#0)
let tmp%4#0: uint64 = (== 32u tmp%3#0)
(assert tmp%4#0) // Address length is 32 bytes
let unchecked_address#0: bytes = (txn Sender)
let tmp%5#0: uint64 = (== sender_address#0 checked_address#0)
goto tmp%5#0 ? block@1 : block@3
block@1: // and_contd_L15
let tmp%6#0: uint64 = (== checked_address#0 unchecked_address#0)
goto tmp%6#0 ? block@2 : block@3
block@2: // bool_true_L15
let and_result%7#0: uint64 = 1u
let tmp%7#0: uint64 = (== sender_address#0 checked_address#0)
goto tmp%7#0 ? block@1 : block@3
block@1: // and_contd_L16
let tmp%8#0: uint64 = (== checked_address#0 unchecked_address#0)
goto tmp%8#0 ? block@2 : block@3
block@2: // bool_true_L16
let and_result%9#0: uint64 = 1u
goto block@4
block@3: // bool_false_L15
let and_result%7#0: uint64 = 0u
block@3: // bool_false_L16
let and_result%9#0: uint64 = 0u
goto block@4
block@4: // bool_merge_L15
(assert and_result%7#0)
block@4: // bool_merge_L16
(assert and_result%9#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.clear_state_program() -> uint64:
block@0: // L19
block@0: // L21
return 1u
44 changes: 25 additions & 19 deletions test_cases/arc4_types/out/Arc4RefTypesContract.ssa.ir
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,36 @@ contract test_cases.arc4_types.reference_types.Arc4RefTypesContract:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.approval_program() -> uint64:
block@0: // L5
let sender_address#0: bytes = (txn Sender)
let awst_tmp%0#0: bytes = (txn Sender)
let tmp%1#0: uint64 = (len awst_tmp%0#0)
let tmp%2#0: uint64 = (== 32u tmp%1#0)
let (value%3#0: bytes, check%4#0: uint64) = (awst_tmp%0#0, tmp%2#0)
(assert check%4#0) // Address length is 32 bytes
let checked_address#0: bytes = value%3#0
let tmp%0#0: bytes = (txn Sender)
let tmp%1#0: uint64 = (== sender_address#0 tmp%0#0)
(assert tmp%1#0)
let awst_tmp%2#0: bytes = (txn Sender)
let tmp%3#0: uint64 = (len awst_tmp%2#0)
let tmp%4#0: uint64 = (== 32u tmp%3#0)
let (value%5#0: bytes, check%6#0: uint64) = (awst_tmp%2#0, tmp%4#0)
(assert check%6#0) // Address length is 32 bytes
let checked_address#0: bytes = value%5#0
let unchecked_address#0: bytes = (txn Sender)
let tmp%5#0: uint64 = (== sender_address#0 checked_address#0)
goto tmp%5#0 ? block@1 : block@3
block@1: // and_contd_L15
let tmp%6#0: uint64 = (== checked_address#0 unchecked_address#0)
goto tmp%6#0 ? block@2 : block@3
block@2: // bool_true_L15
let and_result%7#0: uint64 = 1u
let tmp%7#0: uint64 = (== sender_address#0 checked_address#0)
goto tmp%7#0 ? block@1 : block@3
block@1: // and_contd_L16
let tmp%8#0: uint64 = (== checked_address#0 unchecked_address#0)
goto tmp%8#0 ? block@2 : block@3
block@2: // bool_true_L16
let and_result%9#0: uint64 = 1u
goto block@4
block@3: // bool_false_L15
let and_result%7#1: uint64 = 0u
block@3: // bool_false_L16
let and_result%9#1: uint64 = 0u
goto block@4
block@4: // bool_merge_L15
let and_result%7#2: uint64 = φ(and_result%7#0 <- block@2, and_result%7#1 <- block@3)
(assert and_result%7#2)
block@4: // bool_merge_L16
let and_result%9#2: uint64 = φ(and_result%9#0 <- block@2, and_result%9#1 <- block@3)
(assert and_result%9#2)
let tmp%10#0: bytes = (global ZeroAddress)
let tmp%11#0: uint64 = (== tmp%10#0 addr AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAY5HFKQ)
(assert tmp%11#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.clear_state_program() -> uint64:
block@0: // L19
block@0: // L21
return 1u
37 changes: 21 additions & 16 deletions test_cases/arc4_types/out/Arc4RefTypesContract.ssa.opt_pass_1.ir
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,33 @@ contract test_cases.arc4_types.reference_types.Arc4RefTypesContract:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.approval_program() -> uint64:
block@0: // L5
let sender_address#0: bytes = (txn Sender)
let tmp%0#0: bytes = (txn Sender)
let tmp%1#0: uint64 = (== sender_address#0 tmp%0#0)
(assert tmp%1#0)
let checked_address#0: bytes = (txn Sender)
let tmp%1#0: uint64 = (len checked_address#0)
let tmp%2#0: uint64 = (== 32u tmp%1#0)
(assert tmp%2#0) // Address length is 32 bytes
let tmp%3#0: uint64 = (len checked_address#0)
let tmp%4#0: uint64 = (== 32u tmp%3#0)
(assert tmp%4#0) // Address length is 32 bytes
let unchecked_address#0: bytes = (txn Sender)
let tmp%5#0: uint64 = (== sender_address#0 checked_address#0)
goto tmp%5#0 ? block@1 : block@3
block@1: // and_contd_L15
let tmp%6#0: uint64 = (== checked_address#0 unchecked_address#0)
goto tmp%6#0 ? block@2 : block@3
block@2: // bool_true_L15
let and_result%7#0: uint64 = 1u
let tmp%7#0: uint64 = (== sender_address#0 checked_address#0)
goto tmp%7#0 ? block@1 : block@3
block@1: // and_contd_L16
let tmp%8#0: uint64 = (== checked_address#0 unchecked_address#0)
goto tmp%8#0 ? block@2 : block@3
block@2: // bool_true_L16
let and_result%9#0: uint64 = 1u
goto block@4
block@3: // bool_false_L15
let and_result%7#1: uint64 = 0u
block@3: // bool_false_L16
let and_result%9#1: uint64 = 0u
goto block@4
block@4: // bool_merge_L15
let and_result%7#2: uint64 = φ(and_result%7#0 <- block@2, and_result%7#1 <- block@3)
(assert and_result%7#2)
block@4: // bool_merge_L16
let and_result%9#2: uint64 = φ(and_result%9#0 <- block@2, and_result%9#1 <- block@3)
(assert and_result%9#2)
let tmp%11#0: uint64 = 1u
(assert tmp%11#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.clear_state_program() -> uint64:
block@0: // L19
block@0: // L21
return 1u
33 changes: 33 additions & 0 deletions test_cases/arc4_types/out/Arc4RefTypesContract.ssa.opt_pass_2.ir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
contract test_cases.arc4_types.reference_types.Arc4RefTypesContract:
program approval:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.approval_program() -> uint64:
block@0: // L5
let sender_address#0: bytes = (txn Sender)
let tmp%0#0: bytes = (txn Sender)
let tmp%1#0: uint64 = (== sender_address#0 tmp%0#0)
(assert tmp%1#0)
let checked_address#0: bytes = (txn Sender)
let tmp%3#0: uint64 = (len checked_address#0)
let tmp%4#0: uint64 = (== 32u tmp%3#0)
(assert tmp%4#0) // Address length is 32 bytes
let unchecked_address#0: bytes = (txn Sender)
let tmp%7#0: uint64 = (== sender_address#0 checked_address#0)
goto tmp%7#0 ? block@1 : block@3
block@1: // and_contd_L16
let tmp%8#0: uint64 = (== checked_address#0 unchecked_address#0)
goto tmp%8#0 ? block@2 : block@3
block@2: // bool_true_L16
let and_result%9#0: uint64 = 1u
goto block@4
block@3: // bool_false_L16
let and_result%9#1: uint64 = 0u
goto block@4
block@4: // bool_merge_L16
let and_result%9#2: uint64 = φ(and_result%9#0 <- block@2, and_result%9#1 <- block@3)
(assert and_result%9#2)
return 1u

program clear-state:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.clear_state_program() -> uint64:
block@0: // L21
return 1u
2 changes: 2 additions & 0 deletions test_cases/arc4_types/out/reference_types.awst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ contract Arc4RefTypesContract
approval_program(): bool
{
sender_address: puyapy.arc4.Address = reinterpret_cast<puyapy.arc4.Address>(reinterpret_cast<puyapy.Bytes>(txn<Sender>()))
assert(reinterpret_cast<puyapy.Bytes>(sender_address) == reinterpret_cast<puyapy.Bytes>(txn<Sender>()))
checked_address: puyapy.arc4.Address = reinterpret_cast<puyapy.arc4.Address>(checked_maybe((SINGLE_EVAL(id=0, source=reinterpret_cast<puyapy.Bytes>(txn<Sender>())), 32u == len(SINGLE_EVAL(id=0, source=reinterpret_cast<puyapy.Bytes>(txn<Sender>()))))))
unchecked_address: puyapy.arc4.Address = reinterpret_cast<puyapy.arc4.Address>(reinterpret_cast<puyapy.Bytes>(txn<Sender>()))
assert(reinterpret_cast<puyapy.Bytes>(sender_address) == reinterpret_cast<puyapy.Bytes>(checked_address) and reinterpret_cast<puyapy.Bytes>(checked_address) == reinterpret_cast<puyapy.Bytes>(unchecked_address))
assert(reinterpret_cast<puyapy.Bytes>(global<ZeroAddress>()) == reinterpret_cast<puyapy.Bytes>(Address("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAY5HFKQ")))
return true
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

test_cases.arc4_types.reference_types.Arc4RefTypesContract.approval_program:
txn Sender
dupn 3
dupn 2
==
assert
txn Sender
dup
cover 2
dup
len
int 32
==
Expand Down
33 changes: 18 additions & 15 deletions test_cases/arc4_types/out_O2/Arc4RefTypesContract.destructured.ir
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,30 @@ contract test_cases.arc4_types.reference_types.Arc4RefTypesContract:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.approval_program() -> uint64:
block@0: // L5
let sender_address#0: bytes = (txn Sender)
let tmp%0#0: bytes = (txn Sender)
let tmp%1#0: uint64 = (== sender_address#0 tmp%0#0)
(assert tmp%1#0)
let checked_address#0: bytes = (txn Sender)
let tmp%1#0: uint64 = (len checked_address#0)
let tmp%2#0: uint64 = (== 32u tmp%1#0)
(assert tmp%2#0) // Address length is 32 bytes
let tmp%3#0: uint64 = (len checked_address#0)
let tmp%4#0: uint64 = (== 32u tmp%3#0)
(assert tmp%4#0) // Address length is 32 bytes
let unchecked_address#0: bytes = (txn Sender)
let tmp%5#0: uint64 = (== sender_address#0 checked_address#0)
goto tmp%5#0 ? block@1 : block@3
block@1: // and_contd_L15
let tmp%6#0: uint64 = (== checked_address#0 unchecked_address#0)
goto tmp%6#0 ? block@2 : block@3
block@2: // bool_true_L15
let and_result%7#0: uint64 = 1u
let tmp%7#0: uint64 = (== sender_address#0 checked_address#0)
goto tmp%7#0 ? block@1 : block@3
block@1: // and_contd_L16
let tmp%8#0: uint64 = (== checked_address#0 unchecked_address#0)
goto tmp%8#0 ? block@2 : block@3
block@2: // bool_true_L16
let and_result%9#0: uint64 = 1u
goto block@4
block@3: // bool_false_L15
let and_result%7#0: uint64 = 0u
block@3: // bool_false_L16
let and_result%9#0: uint64 = 0u
goto block@4
block@4: // bool_merge_L15
(assert and_result%7#0)
block@4: // bool_merge_L16
(assert and_result%9#0)
return 1u

program clear-state:
subroutine test_cases.arc4_types.reference_types.Arc4RefTypesContract.clear_state_program() -> uint64:
block@0: // L19
block@0: // L21
return 1u
Loading

0 comments on commit 3888220

Please sign in to comment.