Skip to content

Commit

Permalink
feat: optimisation of extract_uint16/32/64 with constants
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Mar 6, 2024
1 parent e3c6253 commit bf00f0d
Show file tree
Hide file tree
Showing 49 changed files with 9,018 additions and 9,089 deletions.
4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ Name O0 size O1 size O2 size
abi_routing/Reference 1179 1019 1019
amm/ConstantProductAMM 1213 1114 1114
application/Reference 168 161 161
arc4_types/Arc4Arrays 588 466 466
arc4_types/Arc4Arrays 588 387 387
arc4_types/Arc4BoolEval 569 20 20
arc4_types/Arc4BoolType 329 57 57
arc4_types/Arc4DynamicStringArray 230 112 112
arc4_types/Arc4Mutation 2803 1488 1487
arc4_types/Arc4Mutation 2803 1452 1451
arc4_types/Arc4NumericTypes 364 216 216
arc4_types/Arc4RefTypes 47 43 43
arc4_types/Arc4StringTypes 336 8 8
Expand Down
20 changes: 20 additions & 0 deletions src/puya/ir/optimize/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ def try_simplify_arithmetic_ops(
return models.BytesConstant(
source_location=op_loc, encoding=byte_const.encoding, value=extracted
)
case models.Intrinsic(
op=(
AVMOp.extract_uint16 | AVMOp.extract_uint32 | AVMOp.extract_uint64
) as extract_uint_op,
args=[
models.BytesConstant(value=bytes_value),
models.UInt64Constant(value=offset),
],
source_location=op_loc,
):
bit_size = int(extract_uint_op.code.removeprefix("extract_uint"))
byte_size = bit_size // 8
extracted = bytes_value[offset : offset + byte_size]
if len(extracted) != byte_size:
raise CodeError(f"{extract_uint_op.code} would fail at runtime", op_loc)
uint64_result = int.from_bytes(extracted, byteorder="big", signed=False)
return models.UInt64Constant(
value=uint64_result,
source_location=op_loc,
)
case models.Intrinsic(
op=AVMOp.concat,
args=[models.Value(atype=AVMType.bytes) as ba, models.BytesConstant(value=b"")],
Expand Down
596 changes: 274 additions & 322 deletions test_cases/arc4_types/out/Arc4ArraysContract.approval.mir

Large diffs are not rendered by default.

131 changes: 44 additions & 87 deletions test_cases/arc4_types/out/Arc4ArraysContract.approval.teal
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,17 @@
test_cases.arc4_types.array.Arc4ArraysContract.approval_program:
int 0
byte ""
dupn 4
dupn 2
// arc4_types/array.py:26
// total = UInt64(0)
int 0
// arc4_types/array.py:25
// dynamic_uint8_array = DynamicArray[UInt8](UInt8(1), UInt8(2))
byte 0x00020102
// arc4_types/array.py:27
// for uint8_item in dynamic_uint8_array:
int 0
extract_uint16
int 0

main_for_header@1:
// arc4_types/array.py:27
// for uint8_item in dynamic_uint8_array:
dup
dig 2
int 2
<
bz main_after_for@4
byte 0x0102
Expand All @@ -32,9 +25,9 @@ main_for_header@1:
// arc4_types/array.py:28
// total += uint8_item.decode()
btoi
dig 4
dig 3
+
bury 4
bury 3
int 1
+
bury 1
Expand All @@ -43,29 +36,21 @@ main_for_header@1:
main_after_for@4:
// arc4_types/array.py:30
// assert total == 3, "Total should be sum of dynamic_uint8_array items"
dig 2
dig 1
int 3
==
assert // Total should be sum of dynamic_uint8_array items
// arc4_types/array.py:31
// aliased_dynamic = AliasedDynamicArray(UInt16(1))
byte 0x00010001
// arc4_types/array.py:32
// for uint16_item in aliased_dynamic:
int 0
extract_uint16
bury 8
int 0
bury 6
bury 4

main_for_header@5:
// arc4_types/array.py:32
// for uint16_item in aliased_dynamic:
dig 5
dig 8
dig 3
int 1
<
bz main_after_for@8
dig 5
dig 3
dup
int 2
*
Expand All @@ -76,68 +61,40 @@ main_for_header@5:
// arc4_types/array.py:33
// total += uint16_item.decode()
btoi
dig 4
dig 3
+
bury 4
bury 3
int 1
+
bury 6
bury 4
b main_for_header@5

main_after_for@8:
// arc4_types/array.py:34
// assert total == 4, "Total should now include sum of aliased_dynamic items"
dig 2
dig 1
int 4
==
assert // Total should now include sum of aliased_dynamic items
// arc4_types/array.py:35
// dynamic_string_array = DynamicArray[String](String("Hello"), String("World"))
byte base64 AAIABAALAAVIZWxsbwAFV29ybGQ=
// arc4_types/array.py:36
// assert dynamic_string_array.length == 2
int 0
extract_uint16
dup
bury 5
dup
int 2
==
assert
// arc4_types/array.py:37
// assert dynamic_string_array[0] == String("Hello")
int 0
>
assert // Index access is out of bounds
byte base64 AAQACwAFSGVsbG8ABVdvcmxk
int 0
extract_uint16
byte base64 AAQACwAFSGVsbG8ABVdvcmxk
dig 1
extract_uint16
int 2
+
byte base64 AAQACwAFSGVsbG8ABVdvcmxk
cover 2
extract3
byte "\x00\x05Hello"
==
assert
// arc4_types/array.py:38
// result = Bytes(b"")
byte ""
bury 9
bury 6
int 0
bury 7
bury 5

main_for_header@9:
// arc4_types/array.py:39
// for index, string_item in uenumerate(dynamic_string_array):
dig 6
dig 4
// arc4_types/array.py:36
// assert dynamic_string_array.length == 2
int 2
// arc4_types/array.py:39
// for index, string_item in uenumerate(dynamic_string_array):
<
bz main_after_for@15
dig 6
dig 4
dup
int 2
*
Expand Down Expand Up @@ -171,7 +128,7 @@ main_for_header@9:
// arc4_types/array.py:41
// result = string_item.decode()
extract 2 0
bury 9
bury 6
b main_after_if_else@13

main_else_body@12:
Expand All @@ -181,36 +138,36 @@ main_else_body@12:
byte " "
swap
concat
dig 9
dig 6
swap
concat
bury 9
bury 6

main_after_if_else@13:
dig 6
dig 4
int 1
+
bury 7
bury 5
b main_for_header@9

main_after_for@15:
// arc4_types/array.py:45
// assert result == b"Hello World"
dig 8
dig 5
byte "Hello World"
==
assert
int 0
bury 5
bury 3

main_for_header@16:
// arc4_types/array.py:49
// for uint32_item in static_uint32_array:
dig 4
dig 2
int 4
<
bz main_after_for@19
dig 4
dig 2
dup
int 4
*
Expand All @@ -225,18 +182,18 @@ main_for_header@16:
// arc4_types/array.py:50
// total += uint32_item.decode()
btoi
dig 4
dig 3
+
bury 4
bury 3
int 1
+
bury 5
bury 3
b main_for_header@16

main_after_for@19:
// arc4_types/array.py:52
// assert total == 4 + 1 + 10 + 255 + 128
dig 2
dig 1
int 398
==
assert
Expand All @@ -252,18 +209,18 @@ main_after_for@19:
// arc4_types/array.py:62
// result = Bytes(b"")
byte ""
bury 9
bury 6
int 0
bury 7
bury 5

main_for_header@20:
// arc4_types/array.py:63
// for index, string_item in uenumerate(static_string_array):
dig 6
dig 4
int 2
<
bz main_after_for@26
dig 6
dig 4
dup
int 2
*
Expand Down Expand Up @@ -297,7 +254,7 @@ main_for_header@20:
// arc4_types/array.py:65
// result = string_item.decode()
extract 2 0
bury 9
bury 6
b main_after_if_else@24

main_else_body@23:
Expand All @@ -307,22 +264,22 @@ main_else_body@23:
byte " "
swap
concat
dig 9
dig 6
swap
concat
bury 9
bury 6

main_after_if_else@24:
dig 6
dig 4
int 1
+
bury 7
bury 5
b main_for_header@20

main_after_for@26:
// arc4_types/array.py:69
// assert result == b"Ping Pong"
dig 8
dig 5
byte "Ping Pong"
==
assert
Expand Down
19 changes: 3 additions & 16 deletions test_cases/arc4_types/out/Arc4ArraysContract.destructured.ir
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ contract test_cases.arc4_types.array.Arc4ArraysContract:
subroutine test_cases.arc4_types.array.Arc4ArraysContract.approval_program() -> uint64:
block@0: // L24
let total#0: uint64 = 0u
let array_length%1#0: uint64 = (extract_uint16 0x00020102 0u)
let item_index_internal%3#0: uint64 = 0u
goto block@1
block@1: // for_header_L27
let continue_looping%5#0: uint64 = (< item_index_internal%3#0 array_length%1#0)
let continue_looping%5#0: uint64 = (< item_index_internal%3#0 2u)
goto continue_looping%5#0 ? block@2 : block@4
block@2: // for_body_L27
let uint8_item#0: bytes = (extract3 0x0102 item_index_internal%3#0 1u)
Expand All @@ -18,11 +17,10 @@ contract test_cases.arc4_types.array.Arc4ArraysContract:
block@4: // after_for_L27
let tmp%8#0: uint64 = (== total#0 3u)
(assert tmp%8#0) // Total should be sum of dynamic_uint8_array items
let array_length%10#0: uint64 = (extract_uint16 0x00010001 0u)
let item_index_internal%12#0: uint64 = 0u
goto block@5
block@5: // for_header_L32
let continue_looping%14#0: uint64 = (< item_index_internal%12#0 array_length%10#0)
let continue_looping%14#0: uint64 = (< item_index_internal%12#0 1u)
goto continue_looping%14#0 ? block@6 : block@8
block@6: // for_body_L32
let item_index%15#0: uint64 = (* item_index_internal%12#0 2u)
Expand All @@ -34,23 +32,12 @@ contract test_cases.arc4_types.array.Arc4ArraysContract:
block@8: // after_for_L32
let tmp%17#0: uint64 = (== total#0 4u)
(assert tmp%17#0) // Total should now include sum of aliased_dynamic items
let tmp%26#0: uint64 = (extract_uint16 AAIABAALAAVIZWxsbwAFV29ybGQ= 0u)
let tmp%27#0: uint64 = (== tmp%26#0 2u)
(assert tmp%27#0)
let index_is_in_bounds%29#0: uint64 = (< 0u tmp%26#0)
(assert index_is_in_bounds%29#0) // Index access is out of bounds
let item_index%32#0: uint64 = (extract_uint16 AAQACwAFSGVsbG8ABVdvcmxk 0u)
let item_length%33#0: uint64 = (extract_uint16 AAQACwAFSGVsbG8ABVdvcmxk item_index%32#0)
let item_length_plus_2%34#0: uint64 = (+ item_length%33#0 2u)
let tmp%35#0: bytes = (extract3 AAQACwAFSGVsbG8ABVdvcmxk item_index%32#0 item_length_plus_2%34#0)
let tmp%36#0: uint64 = (== tmp%35#0 "\x00\x05Hello")
(assert tmp%36#0)
let result#0: bytes = ""
let item_index_internal%39#0: uint64 = 0u
let index#0: uint64 = item_index_internal%39#0
goto block@9
block@9: // for_header_L39
let continue_looping%41#0: uint64 = (< index#0 tmp%26#0)
let continue_looping%41#0: uint64 = (< index#0 2u)
goto continue_looping%41#0 ? block@10 : block@15
block@10: // for_body_L39
let item_index_index%42#0: uint64 = (* index#0 2u)
Expand Down
Loading

0 comments on commit bf00f0d

Please sign in to comment.