Skip to content

Commit

Permalink
feat: empty constructor for arc4 numeric types, defaults to zero
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Mar 22, 2024
1 parent ad89fbd commit c514753
Show file tree
Hide file tree
Showing 23 changed files with 182 additions and 86 deletions.
2 changes: 1 addition & 1 deletion examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
arc4_types/Arc4DynamicStringArray 230 112 118 112 0
arc4_types/Arc4MutableParams 362 222 140 220 2
arc4_types/Arc4Mutation 2803 1448 1355 1447 1
arc4_types/Arc4NumericTypes 472 8 464 8 0
arc4_types/Arc4NumericTypes 571 8 563 8 0
arc4_types/Arc4RefTypes 47 39 8 39 0
arc4_types/Arc4StringTypes 304 8 296 8 0
arc4_types/Arc4StructsFromAnotherModule 67 12 55 12 0
Expand Down
33 changes: 23 additions & 10 deletions src/puya/awst_build/eb/arc4/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from puya.awst import wtypes
from puya.awst.nodes import (
ARC4Encode,
ConstantValue,
Expression,
Literal,
NumericComparison,
Expand Down Expand Up @@ -36,12 +37,12 @@
logger: structlog.types.FilteringBoundLogger = structlog.get_logger(__name__)


class NumericARC4ClassExpressionBuilder(ARC4ClassExpressionBuilder):
class NumericARC4ClassExpressionBuilder(ARC4ClassExpressionBuilder, abc.ABC):
def __init__(self, location: SourceLocation):
super().__init__(location)
self.wtype: wtypes.ARC4UIntN | wtypes.ARC4UFixedNxM | None = None

def produces(self) -> wtypes.WType:
def produces(self) -> wtypes.ARC4Type:
if self.wtype is None:
raise InternalError(
"Cannot resolve wtype of generic EB until the index method is called with the "
Expand All @@ -56,14 +57,13 @@ def call(
arg_names: list[str | None],
location: SourceLocation,
) -> ExpressionBuilder:
if not self.wtype:
raise InternalError(
"Cannot resolve wtype of generic EB until the index method is called with"
" the generic type parameter."
)
wtype = self.produces()
match args:
case []:
zero_literal = Literal(value=self.zero_literal(), source_location=location)
return var_expression(convert_arc4_literal(zero_literal, wtype, location))
case [Literal() as lit]:
return var_expression(convert_arc4_literal(lit, self.wtype, location))
return var_expression(convert_arc4_literal(lit, wtype, location))
case [ExpressionBuilder(value_type=wtypes.WType() as value_type) as eb]:
value = eb.rvalue()
if value_type not in (
Expand All @@ -72,24 +72,31 @@ def call(
wtypes.biguint_wtype,
):
raise CodeError(
f"{self.wtype} constructor expects an int literal or a "
f"{wtype} constructor expects an int literal or a "
"uint64 expression or a biguint expression"
)
return var_expression(
ARC4Encode(value=value, source_location=location, wtype=self.wtype)
ARC4Encode(value=value, source_location=location, wtype=wtype)
)
case _:
raise CodeError(
"Invalid/unhandled arguments",
location,
)

@abc.abstractmethod
def zero_literal(self) -> ConstantValue:
...


class ByteClassExpressionBuilder(NumericARC4ClassExpressionBuilder):
def __init__(self, location: SourceLocation):
super().__init__(location)
self.wtype = wtypes.arc4_byte_type

def zero_literal(self) -> ConstantValue:
return 0


class _UIntNClassExpressionBuilder(NumericARC4ClassExpressionBuilder, abc.ABC):
def index(
Expand All @@ -104,6 +111,9 @@ def index(
def check_bitsize(self, n: int, location: SourceLocation) -> None:
...

def zero_literal(self) -> ConstantValue:
return 0


class UIntNClassExpressionBuilder(_UIntNClassExpressionBuilder):
def check_bitsize(self, n: int, location: SourceLocation) -> None:
Expand Down Expand Up @@ -144,6 +154,9 @@ def index_multiple(
def check_bitsize(self, n: int, location: SourceLocation) -> None:
...

def zero_literal(self) -> ConstantValue:
return "0.0"


class UFixedNxMClassExpressionBuilder(_UFixedNxMClassExpressionBuilder):
def check_bitsize(self, n: int, location: SourceLocation) -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/puyapy-stubs/arc4.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class String(_ABIEncoded):
_TBitSize = typing.TypeVar("_TBitSize", bound=int)

class _UIntN(_ABIEncoded, typing.Protocol):
def __init__(self, value: puyapy.BigUInt | puyapy.UInt64 | int, /) -> None: ...
def __init__(self, value: puyapy.BigUInt | puyapy.UInt64 | int = 0, /) -> None: ...

# ~~~ https://docs.python.org/3/reference/datamodel.html#basic-customization ~~~
# TODO: mypy suggests due to Liskov below should be other: object
Expand Down Expand Up @@ -143,7 +143,7 @@ class UFixedNxM(_ABIEncoded, typing.Generic[_TBitSize, _TDecimalPlaces]):
Max size: 64 bits"""

def __init__(self, value: str, /):
def __init__(self, value: str = "0.0", /):
"""
Construct an instance of UFixedNxM where value (v) is determined from the original
decimal value (d) by the formula v = round(d * (10^M))
Expand All @@ -156,7 +156,7 @@ class BigUFixedNxM(_ABIEncoded, typing.Generic[_TBitSize, _TDecimalPlaces]):
Max size: 512 bits"""

def __init__(self, value: str, /):
def __init__(self, value: str = "0.0", /):
"""
Construct an instance of UFixedNxM where value (v) is determined from the original
decimal value (d) by the formula v = round(d * (10^M))
Expand Down
8 changes: 8 additions & 0 deletions test_cases/arc4_types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from puyapy.arc4 import (
BigUFixedNxM,
BigUIntN,
Byte,
UFixedNxM,
UInt8,
UInt16,
UInt32,
UInt64 as ARC4UInt64,
UInt512,
UIntN,
)

Expand Down Expand Up @@ -96,4 +98,10 @@ def approval_program(self) -> bool:
return True

def clear_state_program(self) -> bool:
assert BigUInt.from_bytes(Decimal().bytes) == 0
assert BigUInt.from_bytes(BigUFixedNxM[t.Literal[512], t.Literal[5]]().bytes) == 0
assert Byte() == 0
assert ARC4UInt64() == 0
assert UInt512() == 0

return True
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@

// test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
main_block@0:
int 1 // 1 True arc4_types/numeric.py:96
return // return True arc4_types/numeric.py:96
int 1 // 1 True arc4_types/numeric.py:98
return // return True arc4_types/numeric.py:98

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma version 10

test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program:
// arc4_types/numeric.py:96
// arc4_types/numeric.py:98
// return True
int 1
return
4 changes: 2 additions & 2 deletions test_cases/arc4_types/out/Arc4NumericTypesContract.clear.mir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@

// test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
main_block@0:
int 1 // 1 True arc4_types/numeric.py:99
return // return True arc4_types/numeric.py:99
int 1 // 1 True arc4_types/numeric.py:107
return // return True arc4_types/numeric.py:107

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma version 10

test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program:
// arc4_types/numeric.py:99
// arc4_types/numeric.py:107
// return True
int 1
return
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
return 1u

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
return 1u
14 changes: 12 additions & 2 deletions test_cases/arc4_types/out/Arc4NumericTypesContract.ssa.ir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
let uint8#0: uint64 = 255u
let val_as_bytes%0#0: bytes = (itob uint8#0)
let int8_encoded#0: bytes = ((extract 7 1) val_as_bytes%0#0)
Expand Down Expand Up @@ -117,5 +117,15 @@ contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
let tmp%0#0: uint64 = (b== 0x0000000000000000 0b)
(assert tmp%0#0)
let tmp%1#0: uint64 = (b== 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 0b)
(assert tmp%1#0)
let tmp%2#0: uint64 = (b== 0x00 0x00)
(assert tmp%2#0)
let tmp%3#0: uint64 = (b== 0x0000000000000000 0x0000000000000000)
(assert tmp%3#0)
let tmp%4#0: uint64 = (b== 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000)
(assert tmp%4#0)
return 1u
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
let int8_encoded#0: bytes = 0xff
let int8_decoded#0: uint64 = (btoi int8_encoded#0)
let tmp%1#0: uint64 = (== 255u int8_decoded#0)
Expand Down Expand Up @@ -87,5 +87,15 @@ contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
let tmp%0#0: uint64 = 1u
(assert tmp%0#0)
let tmp%1#0: uint64 = 1u
(assert tmp%1#0)
let tmp%2#0: uint64 = 1u
(assert tmp%2#0)
let tmp%3#0: uint64 = 1u
(assert tmp%3#0)
let tmp%4#0: uint64 = 1u
(assert tmp%4#0)
return 1u
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
let int8_decoded#0: uint64 = 255u
let tmp%1#0: uint64 = (== 255u int8_decoded#0)
(assert tmp%1#0)
Expand Down Expand Up @@ -50,5 +50,5 @@ contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
return 1u
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
let tmp%1#0: uint64 = 1u
(assert tmp%1#0)
let tmp%6#0: uint64 = 1u
Expand Down Expand Up @@ -38,5 +38,5 @@ contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
return 1u
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
let tmp%7#0: bytes = 0x7f
let tmp%8#0: uint64 = (btoi tmp%7#0)
let tmp%9#0: uint64 = (== tmp%8#0 127u)
Expand All @@ -28,5 +28,5 @@ contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
return 1u
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
let tmp%8#0: uint64 = 127u
let tmp%9#0: uint64 = (== tmp%8#0 127u)
(assert tmp%9#0)
Expand All @@ -21,5 +21,5 @@ contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
return 1u
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
let tmp%9#0: uint64 = 1u
(assert tmp%9#0)
let tmp%17#0: uint64 = 1u
Expand All @@ -16,5 +16,5 @@ contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
return 1u
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
return 1u

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
return 1u
5 changes: 5 additions & 0 deletions test_cases/arc4_types/out/numeric.awst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ contract Arc4NumericTypesContract

clear_state_program(): bool
{
assert(reinterpret_cast<puyapy.BigUInt>(reinterpret_cast<puyapy.Bytes>(0E-10arc4u64x10)) == 0n)
assert(reinterpret_cast<puyapy.BigUInt>(reinterpret_cast<puyapy.Bytes>(0.00000arc4n512x5)) == 0n)
assert(reinterpret_cast<puyapy.BigUInt>(0arc4u8) == reinterpret_cast<puyapy.BigUInt>(0arc4u8))
assert(reinterpret_cast<puyapy.BigUInt>(0arc4u64) == reinterpret_cast<puyapy.BigUInt>(0arc4u64))
assert(reinterpret_cast<puyapy.BigUInt>(0arc4n512) == reinterpret_cast<puyapy.BigUInt>(0arc4n512))
return true
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
contract test_cases.arc4_types.numeric.Arc4NumericTypesContract:
program approval:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.approval_program() -> uint64:
block@0: // L25
block@0: // L27
return 1u

program clear-state:
subroutine test_cases.arc4_types.numeric.Arc4NumericTypesContract.clear_state_program() -> uint64:
block@0: // L98
block@0: // L100
return 1u
Loading

0 comments on commit c514753

Please sign in to comment.