From 4c2ed035437a87496bc1a8d54c1fb7bdb3bff8b7 Mon Sep 17 00:00:00 2001 From: Daniel McGregor Date: Wed, 20 Dec 2023 15:30:25 +0800 Subject: [PATCH] perf: speed up codegen (#48) --------- Co-authored-by: Adam Chidlow --- examples/bytes_ops/puya.log | 20 +- .../contract.approval_unoptimized.debug.teal | 2 +- examples/unssa/puya.log | 4 +- .../voting.approval_unoptimized.debug.teal | 2 +- examples/with_reentrancy/puya.log | 24 +- src/puya/codegen/context.py | 23 +- src/puya/codegen/ops.py | 1 + src/puya/codegen/stack.py | 238 ++++++-------- src/puya/codegen/stack_assignment.py | 9 +- src/puya/codegen/stack_baileys.py | 24 +- src/puya/codegen/stack_frame_allocation.py | 17 +- src/puya/codegen/stack_koopmans.py | 309 +++++++++--------- src/puya/codegen/stack_simplify_teal.py | 12 +- src/puya/codegen/vla.py | 36 +- src/puya/utils.py | 30 +- 15 files changed, 378 insertions(+), 373 deletions(-) diff --git a/examples/bytes_ops/puya.log b/examples/bytes_ops/puya.log index f21fbcc11a..b5eb2e272f 100644 --- a/examples/bytes_ops/puya.log +++ b/examples/bytes_ops/puya.log @@ -217,16 +217,16 @@ debug: Inserted do_some_ops_block@0.ops[39]: 'store result#0 to l-stack (copy)' debug: Replaced do_some_ops_block@0.ops[41]: 'load result#0' with 'load result#0 from l-stack (no copy)' debug: Inserted do_some_ops_block@0.ops[44]: 'store tmp%3#0 to l-stack (copy)' debug: Replaced do_some_ops_block@0.ops[46]: 'load tmp%3#0' with 'load tmp%3#0 from l-stack (no copy)' -debug: Inserted do_augmented_assignment_ops_block@0.ops[6]: 'store tmp%0#0 to l-stack (copy)' -debug: Replaced do_augmented_assignment_ops_block@0.ops[8]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' -debug: Inserted do_augmented_assignment_ops_block@0.ops[16]: 'store tmp%1#0 to l-stack (copy)' -debug: Replaced do_augmented_assignment_ops_block@0.ops[18]: 'load tmp%1#0' with 'load tmp%1#0 from l-stack (no copy)' -debug: Inserted do_augmented_assignment_ops_block@0.ops[26]: 'store tmp%2#0 to l-stack (copy)' -debug: Replaced do_augmented_assignment_ops_block@0.ops[28]: 'load tmp%2#0' with 'load tmp%2#0 from l-stack (no copy)' -debug: Inserted do_augmented_assignment_ops_block@0.ops[36]: 'store tmp%3#0 to l-stack (copy)' -debug: Replaced do_augmented_assignment_ops_block@0.ops[38]: 'load tmp%3#0' with 'load tmp%3#0 from l-stack (no copy)' -debug: Inserted do_augmented_assignment_ops_block@0.ops[46]: 'store tmp%4#0 to l-stack (copy)' -debug: Replaced do_augmented_assignment_ops_block@0.ops[48]: 'load tmp%4#0' with 'load tmp%4#0 from l-stack (no copy)' +debug: Inserted do_augmented_assignment_ops_block@0.ops[7]: 'store tmp%0#0 to l-stack (copy)' +debug: Replaced do_augmented_assignment_ops_block@0.ops[9]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' +debug: Inserted do_augmented_assignment_ops_block@0.ops[18]: 'store tmp%1#0 to l-stack (copy)' +debug: Replaced do_augmented_assignment_ops_block@0.ops[20]: 'load tmp%1#0' with 'load tmp%1#0 from l-stack (no copy)' +debug: Inserted do_augmented_assignment_ops_block@0.ops[29]: 'store tmp%2#0 to l-stack (copy)' +debug: Replaced do_augmented_assignment_ops_block@0.ops[31]: 'load tmp%2#0' with 'load tmp%2#0 from l-stack (no copy)' +debug: Inserted do_augmented_assignment_ops_block@0.ops[40]: 'store tmp%3#0 to l-stack (copy)' +debug: Replaced do_augmented_assignment_ops_block@0.ops[42]: 'load tmp%3#0' with 'load tmp%3#0 from l-stack (no copy)' +debug: Inserted do_augmented_assignment_ops_block@0.ops[51]: 'store tmp%4#0 to l-stack (copy)' +debug: Replaced do_augmented_assignment_ops_block@0.ops[53]: 'load tmp%4#0' with 'load tmp%4#0 from l-stack (no copy)' info: Writing bytes_ops/out/contract.approval.teal info: Writing bytes_ops/out/contract.approval.debug.teal info: Writing bytes_ops/out/contract.clear.teal diff --git a/examples/everything/out/contract.approval_unoptimized.debug.teal b/examples/everything/out/contract.approval_unoptimized.debug.teal index a8c8883221..879b8ed92f 100644 --- a/examples/everything/out/contract.approval_unoptimized.debug.teal +++ b/examples/everything/out/contract.approval_unoptimized.debug.teal @@ -172,7 +172,7 @@ register_if_body@1: byte "name" // (𝕡) name#0 | 0,0,"name" self.name.maybe(account=0) File "everything/contract.py", line 43 app_local_get_ex // (𝕡) name#0 | {app_local_get_ex}.0,{app_local_get_ex}.1 self.name.maybe(account=0) File "everything/contract.py", line 43 swap // store tuple_assignment%3#0 to l-stack (no copy) (𝕡) name#0 | tuple_assignment%3#0,{app_local_get_ex}.0 self.name.maybe(account=0) File "everything/contract.py", line 43 - pop // (𝕡) name#0 | tuple_assignment%3#0 self.name.maybe(account=0) File "everything/contract.py", line 43 + pop // (𝕡) name#0 | tuple_assignment%3#0 sender_name File "everything/contract.py", line 43 bnz register_after_if_else@3 // (𝕡) name#0 | not sender_name_existed File "everything/contract.py", line 44 // Implicit fall through to register_if_body@2 // (𝕡) name#0 | not sender_name_existed File "everything/contract.py", line 44 diff --git a/examples/unssa/puya.log b/examples/unssa/puya.log index 9ce76b77f3..b7a9a2e5b8 100644 --- a/examples/unssa/puya.log +++ b/examples/unssa/puya.log @@ -927,8 +927,8 @@ debug: shared x-stack for test_swap_loop_while_top@3 -> test_swap_loop_after_whi debug: examples.unssa.contract.test_swap_loop f-stack entry: [] debug: examples.unssa.contract.test_swap_loop f-stack on first store: ['x#0', 'y#0'] debug: Simplified frame_dig 0; frame_bury 0; retsub to retsub -debug: Inserted test_param_update_with_reentrant_entry_block_while_top@1.ops[6]: 'store tmp%0#0 to l-stack (copy)' -debug: Replaced test_param_update_with_reentrant_entry_block_while_top@1.ops[8]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' +debug: Inserted test_param_update_with_reentrant_entry_block_while_top@1.ops[7]: 'store tmp%0#0 to l-stack (copy)' +debug: Replaced test_param_update_with_reentrant_entry_block_while_top@1.ops[9]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' debug: Found 2 edge set/s for examples.unssa.contract.test_param_update_with_reentrant_entry_block debug: Inserted test_param_update_with_reentrant_entry_block_v2_while_top@1.ops[3]: 'store tmp%0#0 to l-stack (copy)' debug: Replaced test_param_update_with_reentrant_entry_block_v2_while_top@1.ops[5]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' diff --git a/examples/voting/out/voting.approval_unoptimized.debug.teal b/examples/voting/out/voting.approval_unoptimized.debug.teal index 19eee52925..2b4c615434 100644 --- a/examples/voting/out/voting.approval_unoptimized.debug.teal +++ b/examples/voting/out/voting.approval_unoptimized.debug.teal @@ -788,7 +788,7 @@ already_voted_block@0: txn Sender // {txn} Transaction.sender() File "voting/voting.py", line 229 box_get // {box_get}.0,{box_get}.1 Box.get(Transaction.sender().bytes) File "voting/voting.py", line 229 swap // store tuple_assignment%2#0 to l-stack (no copy) tuple_assignment%2#0,{box_get}.0 Box.get(Transaction.sender().bytes) File "voting/voting.py", line 229 - pop // tuple_assignment%2#0 Box.get(Transaction.sender().bytes) File "voting/voting.py", line 229 + pop // tuple_assignment%2#0 votes File "voting/voting.py", line 229 retsub // exists#0 return exists File "voting/voting.py", line 230 diff --git a/examples/with_reentrancy/puya.log b/examples/with_reentrancy/puya.log index fcf44cdb89..79e41a7635 100644 --- a/examples/with_reentrancy/puya.log +++ b/examples/with_reentrancy/puya.log @@ -458,18 +458,18 @@ debug: Replaced fibonacci_after_if_else@2.ops[18]: 'load tmp%4#0' with 'load tmp debug: Inserted fibonacci_after_if_else@2.ops[7]: 'store tmp%2#0 to l-stack (copy)' debug: Replaced fibonacci_after_if_else@2.ops[18]: 'load tmp%2#0' with 'load tmp%2#0 from l-stack (no copy)' debug: Found 1 edge set/s for examples.with_reentrancy.contract.fibonacci -debug: Inserted silly_block@0.ops[12]: 'store tmp%1#0 to l-stack (copy)' -debug: Replaced silly_block@0.ops[14]: 'load tmp%1#0' with 'load tmp%1#0 from l-stack (no copy)' -debug: Inserted silly_block@0.ops[8]: 'store tmp%0#0 to l-stack (copy)' -debug: Replaced silly_block@0.ops[11]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' -debug: Inserted silly_block@0.ops[5]: 'store result#0 to l-stack (copy)' -debug: Replaced silly_block@0.ops[18]: 'load result#0' with 'load result#0 from l-stack (no copy)' -debug: Inserted silly2_block@0.ops[12]: 'store tmp%1#0 to l-stack (copy)' -debug: Replaced silly2_block@0.ops[14]: 'load tmp%1#0' with 'load tmp%1#0 from l-stack (no copy)' -debug: Inserted silly2_block@0.ops[8]: 'store tmp%0#0 to l-stack (copy)' -debug: Replaced silly2_block@0.ops[11]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' -debug: Inserted silly2_block@0.ops[5]: 'store result#0 to l-stack (copy)' -debug: Replaced silly2_block@0.ops[18]: 'load result#0' with 'load result#0 from l-stack (no copy)' +debug: Inserted silly_block@0.ops[13]: 'store tmp%1#0 to l-stack (copy)' +debug: Replaced silly_block@0.ops[15]: 'load tmp%1#0' with 'load tmp%1#0 from l-stack (no copy)' +debug: Inserted silly_block@0.ops[9]: 'store tmp%0#0 to l-stack (copy)' +debug: Replaced silly_block@0.ops[12]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' +debug: Inserted silly_block@0.ops[6]: 'store result#0 to l-stack (copy)' +debug: Replaced silly_block@0.ops[19]: 'load result#0' with 'load result#0 from l-stack (no copy)' +debug: Inserted silly2_block@0.ops[13]: 'store tmp%1#0 to l-stack (copy)' +debug: Replaced silly2_block@0.ops[15]: 'load tmp%1#0' with 'load tmp%1#0 from l-stack (no copy)' +debug: Inserted silly2_block@0.ops[9]: 'store tmp%0#0 to l-stack (copy)' +debug: Replaced silly2_block@0.ops[12]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' +debug: Inserted silly2_block@0.ops[6]: 'store result#0 to l-stack (copy)' +debug: Replaced silly2_block@0.ops[19]: 'load result#0' with 'load result#0 from l-stack (no copy)' debug: Inserted silly3_block@0.ops[3]: 'store tmp%0#0 to l-stack (copy)' debug: Replaced silly3_block@0.ops[5]: 'load tmp%0#0' with 'load tmp%0#0 from l-stack (no copy)' debug: Inserted silly3_block@0.ops[7]: 'store is_even#0 to l-stack (copy)' diff --git a/src/puya/codegen/context.py b/src/puya/codegen/context.py index cb7469abb0..edd93a2161 100644 --- a/src/puya/codegen/context.py +++ b/src/puya/codegen/context.py @@ -1,10 +1,31 @@ import attrs +from puya.codegen import ops as mir +from puya.codegen.vla import VariableLifetimeAnalysis from puya.context import CompileContext from puya.ir import models as ir +from puya.utils import attrs_extend -@attrs.frozen +@attrs.define class ProgramCodeGenContext(CompileContext): contract: ir.Contract program: ir.Program + + def for_subroutine(self, subroutine: mir.MemorySubroutine) -> "SubroutineCodeGenContext": + return attrs_extend(SubroutineCodeGenContext, self, subroutine=subroutine) + + +@attrs.define(frozen=False) +class SubroutineCodeGenContext(ProgramCodeGenContext): + subroutine: mir.MemorySubroutine + _vla: VariableLifetimeAnalysis | None = attrs.field(default=None) + + @property + def vla(self) -> VariableLifetimeAnalysis: + if self._vla is None: + self._vla = VariableLifetimeAnalysis.analyze(self.subroutine) + return self._vla + + def invalidate_vla(self) -> None: + self._vla = None diff --git a/src/puya/codegen/ops.py b/src/puya/codegen/ops.py index 61a28c0726..af23650b33 100644 --- a/src/puya/codegen/ops.py +++ b/src/puya/codegen/ops.py @@ -227,6 +227,7 @@ def __str__(self) -> str: return f"pop {self.n}" +@t.final @attrs.frozen(eq=False, init=False) class VirtualStackOp(BaseOp): original: Sequence[BaseOp] diff --git a/src/puya/codegen/stack.py b/src/puya/codegen/stack.py index ee7a3cc3af..f257bb1cc9 100644 --- a/src/puya/codegen/stack.py +++ b/src/puya/codegen/stack.py @@ -1,20 +1,20 @@ import contextlib import typing -from collections.abc import Iterable, Iterator -from copy import deepcopy +from collections.abc import Iterator import attrs from puya.codegen import ops, teal from puya.codegen.utils import format_bytes from puya.codegen.visitor import MIRVisitor -from puya.codegen.vla import VariableLifetimeAnalysis from puya.errors import InternalError from puya.ir.types_ import AVMBytesEncoding @attrs.define -class _StackState: +class Stack(MIRVisitor[list[teal.TealOp]]): + allow_virtual: bool = attrs.field(default=True) + use_frame: bool = attrs.field(default=False) parameters: list[str] = attrs.field(factory=list) f_stack: list[str] = attrs.field(factory=list) """f-stack holds variables above the current frame""" @@ -23,20 +23,15 @@ class _StackState: l_stack: list[str] = attrs.field(factory=list) """l-stack holds variables that are used within a block""" - @property - def full_stack(self) -> Iterable[str]: - yield from self.f_stack - yield from self.x_stack - yield from self.l_stack - - -@attrs.define -class Stack(MIRVisitor[list[teal.TealOp]]): - allow_virtual: bool = attrs.field(default=True) - _current_subroutine: ops.MemorySubroutine | None = attrs.field(default=None) - _use_frame: bool = attrs.field(default=False) - _vla: VariableLifetimeAnalysis | None = attrs.field(default=None) - state: _StackState = attrs.field(factory=_StackState) + def copy(self) -> "Stack": + return Stack( + allow_virtual=self.allow_virtual, + use_frame=self.use_frame, + parameters=self.parameters.copy(), + f_stack=self.f_stack.copy(), + x_stack=self.x_stack.copy(), + l_stack=self.l_stack.copy(), + ) @classmethod def for_full_stack( @@ -47,98 +42,55 @@ def for_full_stack( return stack def begin_block(self, subroutine: ops.MemorySubroutine, block: ops.MemoryBasicBlock) -> None: - if self._current_subroutine != subroutine: - self._vla = None - self._current_subroutine = subroutine - self.state = _StackState( - parameters=[p.local_id for p in subroutine.signature.parameters], - f_stack=list(block.f_stack_in), - x_stack=list(block.x_stack_in or ()), # x-stack might not be assigned yet - l_stack=[], - ) - self._use_frame = not subroutine.is_main - - @property - def _parameters(self) -> list[str]: - return self.state.parameters - - @property - def _f_stack(self) -> list[str]: - return self.state.f_stack - - @property - def _x_stack(self) -> list[str]: - return self.state.x_stack - - @property - def _l_stack(self) -> list[str]: - return self.state.l_stack - - @property - def f_x_l_stack_height(self) -> int: - return self.f_stack_height + self.x_stack_height + self.l_stack_height - - @property - def x_l_stack_height(self) -> int: - return self.x_stack_height + self.l_stack_height - - @property - def f_stack_height(self) -> int: - return len(self._f_stack) - - @property - def x_stack_height(self) -> int: - return len(self._x_stack) - - @property - def l_stack_height(self) -> int: - return len(self._l_stack) + self.parameters = [p.local_id for p in subroutine.signature.parameters] + self.f_stack = list(block.f_stack_in) + self.x_stack = list(block.x_stack_in or ()) # x-stack might not be assigned yet + self.l_stack = [] + self.use_frame = not subroutine.is_main @property def full_stack_desc(self) -> str: stack_descs = [] - if self._parameters: - stack_descs.append("(𝕡) " + ",".join(self._parameters)) # noqa: RUF001 - if self._f_stack: - stack_descs.append("(𝕗) " + ",".join(self._f_stack)) # noqa: RUF001 - if self._x_stack: - stack_descs.append("(𝕏) " + ",".join(self._x_stack)) # noqa: RUF001 - stack_descs.append(",".join(self._l_stack)) + if self.parameters: + stack_descs.append("(𝕡) " + ",".join(self.parameters)) # noqa: RUF001 + if self.f_stack: + stack_descs.append("(𝕗) " + ",".join(self.f_stack)) # noqa: RUF001 + if self.x_stack: + stack_descs.append("(𝕏) " + ",".join(self.x_stack)) # noqa: RUF001 + stack_descs.append(",".join(self.l_stack)) return " | ".join(stack_descs) def _stack_error(self, msg: str) -> typing.Never: raise InternalError(f"{msg}: {self.full_stack_desc}") def _l_stack_assign_name(self, value: str) -> None: - if not self._l_stack: + if not self.l_stack: self._stack_error(f"l-stack too small to assign name {value}") - self._l_stack[-1] = value + self.l_stack[-1] = value def _get_f_stack_dig_bury(self, value: str) -> int: - return self.f_x_l_stack_height - self._f_stack.index(value) - 1 - - def _get_x_stack_cover_n(self) -> int: - """Return n value for a (un)cover operation with the x stack""" - return self.x_l_stack_height - 1 - - def get_l_stack_cover_n(self) -> int: - """Return n value for a (un)cover operation with the l stack""" - return self.l_stack_height - 1 + return ( + len(self.f_stack) + + len(self.x_stack) + + len(self.l_stack) + - self.f_stack.index(value) + - 1 + ) def visit_push_int(self, push: ops.PushInt) -> list[teal.TealOp]: - self._l_stack.append(str(push.value)) + self.l_stack.append(str(push.value)) return [teal.PushInt(push.value)] def visit_push_bytes(self, push: ops.PushBytes) -> list[teal.TealOp]: - self._l_stack.append(format_bytes(push.value, push.encoding)) + self.l_stack.append(format_bytes(push.value, push.encoding)) return [teal.PushBytes(push.value, push.encoding)] def visit_push_address(self, addr: ops.PushAddress) -> list[teal.TealOp]: - self._l_stack.append(addr.value) + self.l_stack.append(addr.value) return [teal.PushAddress(addr.value)] def visit_push_method(self, method: ops.PushMethod) -> list[teal.TealOp]: - self._l_stack.append(f'method<"{method.value}">') + self.l_stack.append(f'method<"{method.value}">') return [teal.PushMethod(method.value)] def visit_comment(self, _comment: ops.Comment) -> list[teal.TealOp]: @@ -149,9 +101,9 @@ def visit_store_virtual(self, store: ops.StoreVirtual) -> list[teal.TealOp]: raise InternalError( "StoreVirtual op encountered during TEAL generation", store.source_location ) - if not self._l_stack: + if not self.l_stack: self._stack_error(f"l-stack too small to store {store.local_id}") - self._l_stack.pop() + self.l_stack.pop() return [] def visit_load_virtual(self, load: ops.LoadVirtual) -> list[teal.TealOp]: @@ -159,35 +111,35 @@ def visit_load_virtual(self, load: ops.LoadVirtual) -> list[teal.TealOp]: raise InternalError( "LoadVirtual op encountered during TEAL generation", load.source_location ) - self._l_stack.append(load.local_id) + self.l_stack.append(load.local_id) return [] def _store_f_stack(self, value: str) -> teal.Cover | teal.FrameBury | teal.Bury: """Updates the stack, and if insert returns the cover value, else the bury value""" - if value not in self._f_stack: + if value not in self.f_stack: self._stack_error(f"{value} not found in f-stack") - frame_bury = self._f_stack.index(value) + frame_bury = self.f_stack.index(value) bury = self._get_f_stack_dig_bury(value) - self._l_stack.pop() - if self._use_frame: + self.l_stack.pop() + if self.use_frame: return teal.FrameBury(frame_bury) return teal.Bury(bury) def _insert_f_stack(self, value: str) -> teal.Cover | teal.FrameBury | teal.Bury: """Updates the stack, and if insert returns the cover value, else the bury value""" - if value in self._f_stack: + if value in self.f_stack: raise self._stack_error(f"Could not insert {value} as it is already in f-stack") # inserting something at the top of the f-stack # is equivalent to inserting at the bottom of the x-stack - cover = self._get_x_stack_cover_n() - self._l_stack.pop() - self._f_stack.append(value) + cover = len(self.x_stack) + len(self.l_stack) - 1 + self.l_stack.pop() + self.f_stack.append(value) return teal.Cover(cover) def visit_store_f_stack(self, store: ops.StoreFStack) -> list[teal.TealOp]: - if not self._l_stack: + if not self.l_stack: self._stack_error(f"l-stack is empty, can not store {store.local_id} to f-stack") if store.insert: @@ -197,51 +149,51 @@ def visit_store_f_stack(self, store: ops.StoreFStack) -> list[teal.TealOp]: def visit_load_f_stack(self, load: ops.LoadFStack) -> list[teal.TealOp]: local_id = load.local_id - if local_id not in self._f_stack: + if local_id not in self.f_stack: self._stack_error(f"{local_id} not found in f-stack") - frame_dig = self._f_stack.index(local_id) + frame_dig = self.f_stack.index(local_id) dig = self._get_f_stack_dig_bury(local_id) - self._l_stack.append(local_id) - if self._use_frame: + self.l_stack.append(local_id) + if self.use_frame: return [teal.FrameDig(frame_dig)] return [teal.Dig(dig) if dig else teal.Dup()] def visit_store_x_stack(self, store: ops.StoreXStack) -> list[teal.TealOp]: local_id = store.local_id - if not self._l_stack: + if not self.l_stack: self._stack_error(f"l-stack too small to store {local_id} to x-stack") # re-alias top of l-stack self._l_stack_assign_name(local_id) - cover = self.x_l_stack_height - 1 - var = self._l_stack[-1] if store.copy else self._l_stack.pop() - self._x_stack.insert(0, var) + cover = len(self.x_stack) + len(self.l_stack) - 1 + var = self.l_stack[-1] if store.copy else self.l_stack.pop() + self.x_stack.insert(0, var) if store.copy: return [teal.Dup(), teal.Cover(cover)] return [teal.Cover(cover)] def visit_load_x_stack(self, load: ops.LoadXStack) -> list[teal.TealOp]: local_id = load.local_id - if local_id not in self._x_stack: + if local_id not in self.x_stack: self._stack_error(f"{local_id} not found in x-stack") - index = self._x_stack.index(local_id) - uncover = self.l_stack_height + (self.x_stack_height - index - 1) - self._x_stack.pop(index) - self._l_stack.append(local_id) + index = self.x_stack.index(local_id) + uncover = len(self.l_stack) + (len(self.x_stack) - index - 1) + self.x_stack.pop(index) + self.l_stack.append(local_id) return [teal.Uncover(uncover)] def visit_store_l_stack(self, store: ops.StoreLStack) -> list[teal.TealOp]: cover = store.cover - if cover >= self.l_stack_height: + if cover >= len(self.l_stack): self._stack_error( f"l-stack too small to store (cover {cover}) {store.local_id} to l-stack" ) # re-alias top of l-stack self._l_stack_assign_name(store.local_id) - index = self.l_stack_height - cover - 1 - var = self._l_stack[-1] if store.copy else self._l_stack.pop() - self._l_stack.insert(index, var) + index = len(self.l_stack) - cover - 1 + var = self.l_stack[-1] if store.copy else self.l_stack.pop() + self.l_stack.insert(index, var) if store.copy: result: list[teal.TealOp] = [teal.Dup()] if cover > 0: @@ -251,33 +203,33 @@ def visit_store_l_stack(self, store: ops.StoreLStack) -> list[teal.TealOp]: def visit_load_l_stack(self, load: ops.LoadLStack) -> list[teal.TealOp]: local_id = load.local_id - if local_id not in self._l_stack: + if local_id not in self.l_stack: self._stack_error(f"{local_id} not found in l-stack") - index = self._l_stack.index(local_id) - uncover = self.l_stack_height - index - 1 + index = self.l_stack.index(local_id) + uncover = len(self.l_stack) - index - 1 if load.copy: - self._l_stack.append(local_id) + self.l_stack.append(local_id) return [teal.Dup() if uncover == 0 else teal.Dig(uncover)] else: - self._l_stack.pop(index) - self._l_stack.append(local_id) + self.l_stack.pop(index) + self.l_stack.append(local_id) return [teal.Uncover(uncover)] def visit_load_param(self, load: ops.LoadParam) -> list[teal.TealOp]: - if load.local_id not in self._parameters: + if load.local_id not in self.parameters: self._stack_error(f"{load.local_id} is not a parameter") - self._l_stack.append(load.local_id) + self.l_stack.append(load.local_id) return [teal.FrameDig(load.index)] def visit_store_param(self, store: ops.StoreParam) -> list[teal.TealOp]: - if not self._l_stack: + if not self.l_stack: self._stack_error(f"l-stack too small to store param {store.local_id}") - if store.local_id not in self._parameters: + if store.local_id not in self.parameters: self._stack_error(f"{store.local_id} is not a parameter") self._l_stack_assign_name(store.local_id) if not store.copy: - self._l_stack.pop() + self.l_stack.pop() return [teal.FrameBury(store.index)] else: return [teal.Dup(), teal.FrameBury(store.index)] @@ -286,7 +238,7 @@ def visit_proto(self, proto: ops.Proto) -> list[teal.TealOp]: return [teal.Proto(proto.parameters, proto.returns)] def visit_allocate(self, allocate: ops.Allocate) -> list[teal.TealOp]: - self._f_stack.extend(allocate.allocate_on_entry) + self.f_stack.extend(allocate.allocate_on_entry) def push_n(value_op: teal.TealOp, n: int) -> list[teal.TealOp]: match n: @@ -307,14 +259,14 @@ def push_n(value_op: teal.TealOp, n: int) -> list[teal.TealOp]: ] def visit_pop(self, pop: ops.Pop) -> list[teal.TealOp]: - if self.l_stack_height < pop.n: + if len(self.l_stack) < pop.n: self._stack_error(f"l-stack too small too pop {pop.n}") for _ in range(pop.n): - self._l_stack.pop() + self.l_stack.pop() return [teal.PopN(pop.n) if pop.n > 1 else teal.Pop()] def visit_callsub(self, callsub: ops.CallSub) -> list[teal.TealOp]: - if self.l_stack_height < callsub.parameters: + if len(self.l_stack) < callsub.parameters: self._stack_error(f"l-stack too small to call {callsub}") op_code = f"{{{callsub.target}}}" if callsub.returns > 1: @@ -325,19 +277,19 @@ def visit_callsub(self, callsub: ops.CallSub) -> list[teal.TealOp]: produces = [] for _ in range(callsub.parameters): - self._l_stack.pop() - self._l_stack.extend(produces) + self.l_stack.pop() + self.l_stack.extend(produces) return [teal.CallSub(target=callsub.target)] def visit_retsub(self, retsub: ops.RetSub) -> list[teal.TealOp]: - if self.l_stack_height != retsub.returns: + if len(self.l_stack) != retsub.returns: self._stack_error( f"Inconsistent l-stack height for retsub. Expected {retsub.returns}, " - f"actual {self.l_stack_height}" + f"actual {len(self.l_stack)}" ) - sub_l_stack_height = self.f_stack_height + self.x_stack_height + sub_l_stack_height = len(self.f_stack) + len(self.x_stack) if retsub.returns < sub_l_stack_height: # move returns to base of frame in order ret_ops: list[teal.TealOp] = [ @@ -352,11 +304,13 @@ def visit_retsub(self, retsub: ops.RetSub) -> list[teal.TealOp]: # called and discards anything above. # represent this in the virtual stack with a new stack state with only the current # l-stack (i.e. discard all values in parameters, f-stack and x-stack) - self.state = _StackState(l_stack=self._l_stack) + self.parameters = [] + self.f_stack = [] + self.x_stack = [] return ret_ops def visit_intrinsic(self, intrinsic: ops.IntrinsicOp) -> list[teal.TealOp]: - if intrinsic.consumes > self.l_stack_height: + if intrinsic.consumes > len(self.l_stack): self._stack_error( f"l-stack too small to provide {intrinsic.consumes} arg/s for {intrinsic}: " ) @@ -371,8 +325,8 @@ def visit_intrinsic(self, intrinsic: ops.IntrinsicOp) -> list[teal.TealOp]: produces = [] for _ in range(intrinsic.consumes): - self._l_stack.pop() - self._l_stack.extend(produces) + self.l_stack.pop() + self.l_stack.extend(produces) return [ teal.Intrinsic( @@ -397,11 +351,5 @@ def visit_virtual_stack(self, virtual: ops.VirtualStackOp) -> list[teal.TealOp]: original.accept(self) return [*(virtual.replacement or ())] - def clone(self) -> "Stack": - # deep copy stack state - state = deepcopy(self.state) - # share instances of other fields such as current_subroutine and _vla - return attrs.evolve(self, state=state) - def __str__(self) -> str: return self.full_stack_desc diff --git a/src/puya/codegen/stack_assignment.py b/src/puya/codegen/stack_assignment.py index 2ad05cba82..bd89e08e6d 100644 --- a/src/puya/codegen/stack_assignment.py +++ b/src/puya/codegen/stack_assignment.py @@ -16,8 +16,9 @@ def global_stack_assignment( context: ProgramCodeGenContext, subroutines: list[ops.MemorySubroutine] ) -> None: for subroutine in subroutines: - koopmans(context, subroutine) - baileys(context, subroutine) - allocate_locals_on_stack(context, subroutine) + sub_ctx = context.for_subroutine(subroutine) + koopmans(sub_ctx) + baileys(sub_ctx) + allocate_locals_on_stack(sub_ctx) if context.options.optimization_level > 0: - simplify_teal_ops(context, subroutine) + simplify_teal_ops(sub_ctx) diff --git a/src/puya/codegen/stack_baileys.py b/src/puya/codegen/stack_baileys.py index f0701a99b2..6428d9929a 100644 --- a/src/puya/codegen/stack_baileys.py +++ b/src/puya/codegen/stack_baileys.py @@ -5,9 +5,8 @@ import structlog from puya.codegen import ops -from puya.codegen.context import ProgramCodeGenContext +from puya.codegen.context import SubroutineCodeGenContext from puya.codegen.stack_koopmans import peephole_optimization -from puya.codegen.vla import VariableLifetimeAnalysis from puya.errors import InternalError logger = structlog.get_logger(__name__) @@ -188,9 +187,9 @@ def get_edge_set(block: BlockRecord) -> EdgeSet | None: return EdgeSet(out_blocks, in_blocks) if in_blocks else None -def get_edge_sets( - subroutine: ops.MemorySubroutine, vla: VariableLifetimeAnalysis -) -> Sequence[EdgeSet]: +def get_edge_sets(ctx: SubroutineCodeGenContext) -> Sequence[EdgeSet]: + subroutine = ctx.subroutine + vla = ctx.vla records = { block: BlockRecord( block=block, @@ -254,11 +253,12 @@ def get_edge_sets( return list(edge_sets.keys()) -def schedule_sets(edge_sets: Sequence[EdgeSet], vla: VariableLifetimeAnalysis) -> None: +def schedule_sets(ctx: SubroutineCodeGenContext, edge_sets: Sequence[EdgeSet]) -> None: # determine all blocks referencing variables, so we can track if all references to a # variable are scheduled to x-stack stores = dict[str, set[ops.MemoryBasicBlock]]() loads = dict[str, set[ops.MemoryBasicBlock]]() + vla = ctx.vla for variable in vla.all_variables: stores[variable] = vla.get_store_blocks(variable) loads[variable] = vla.get_load_blocks(variable) @@ -315,6 +315,7 @@ def schedule_sets(edge_sets: Sequence[EdgeSet], vla: VariableLifetimeAnalysis) - ) if variables_successfully_scheduled: + ctx.invalidate_vla() logger.debug( f"Allocated {len(variables_successfully_scheduled)} " f"variable/s to x-stack: {', '.join(variables_successfully_scheduled)}" @@ -346,18 +347,17 @@ def validate_x_stacks(edge_sets: Sequence[EdgeSet]) -> bool: return ok -def baileys(_context: ProgramCodeGenContext, subroutine: ops.MemorySubroutine) -> None: - vla = VariableLifetimeAnalysis.analyze(subroutine) - edge_sets = get_edge_sets(subroutine, vla) +def baileys(ctx: SubroutineCodeGenContext) -> None: + edge_sets = get_edge_sets(ctx) if not edge_sets: # nothing to do return - logger.debug(f"Found {len(edge_sets)} edge set/s for {subroutine.signature.name}") - schedule_sets(edge_sets, vla) + logger.debug(f"Found {len(edge_sets)} edge set/s for {ctx.subroutine.signature.name}") + schedule_sets(ctx, edge_sets) if not validate_x_stacks(edge_sets): raise InternalError("Could not schedule x-stack") add_x_stack_ops_to_edge_sets(edge_sets) - peephole_optimization(subroutine) + peephole_optimization(ctx) diff --git a/src/puya/codegen/stack_frame_allocation.py b/src/puya/codegen/stack_frame_allocation.py index 1dbbe12882..fef8ad5ab0 100644 --- a/src/puya/codegen/stack_frame_allocation.py +++ b/src/puya/codegen/stack_frame_allocation.py @@ -5,9 +5,8 @@ from puya.avm_type import AVMType from puya.codegen import ops -from puya.codegen.context import ProgramCodeGenContext +from puya.codegen.context import SubroutineCodeGenContext from puya.codegen.stack_koopmans import peephole_optimization -from puya.codegen.vla import VariableLifetimeAnalysis logger = structlog.get_logger(__name__) @@ -73,14 +72,13 @@ def get_allocate_op( ) -def allocate_locals_on_stack( - _context: ProgramCodeGenContext, subroutine: ops.MemorySubroutine -) -> None: - vla = VariableLifetimeAnalysis.analyze(subroutine) +def allocate_locals_on_stack(ctx: SubroutineCodeGenContext) -> None: + vla = ctx.vla all_variables = vla.all_variables if not all_variables: return + subroutine = ctx.subroutine first_store_ops = get_lazy_fstack(subroutine) allocate_on_first_store = [op.local_id for op in first_store_ops] @@ -95,6 +93,7 @@ def allocate_locals_on_stack( for block in subroutine.body[1:]: block.f_stack_in = [*allocate_at_entry, *allocate_on_first_store] + removed_virtual = False for block in subroutine.body: for index, op in enumerate(block.ops): match op: @@ -107,6 +106,7 @@ def allocate_locals_on_stack( insert=op in first_store_ops, atype=atype, ) + removed_virtual = True case ops.LoadVirtual( local_id=local_id, source_location=src_location, @@ -117,6 +117,9 @@ def allocate_locals_on_stack( source_location=src_location, atype=atype, ) + removed_virtual = True case ops.RetSub() as retsub: block.ops[index] = attrs.evolve(retsub, f_stack_size=len(all_variables)) - peephole_optimization(subroutine) + if removed_virtual: + ctx.invalidate_vla() + peephole_optimization(ctx) diff --git a/src/puya/codegen/stack_koopmans.py b/src/puya/codegen/stack_koopmans.py index 5c5ddd8180..bd07750292 100644 --- a/src/puya/codegen/stack_koopmans.py +++ b/src/puya/codegen/stack_koopmans.py @@ -6,9 +6,8 @@ import structlog from puya.codegen import ops, teal -from puya.codegen.context import ProgramCodeGenContext +from puya.codegen.context import SubroutineCodeGenContext from puya.codegen.stack import Stack -from puya.codegen.vla import VariableLifetimeAnalysis from puya.utils import invert_ordered_binary_op logger = structlog.get_logger(__name__) @@ -92,7 +91,7 @@ def copy_usage_pairs( insert_stack_in = _get_stack_after_op(subroutine, block, insert_index - 1) dup = ops.StoreLStack( - cover=insert_stack_in.get_l_stack_cover_n(), + cover=len(insert_stack_in.l_stack) - 1, local_id=local_id, source_location=a.source_location, atype=a.atype, @@ -117,158 +116,163 @@ def copy_usage_pairs( logger.debug(f"Replaced {block.block_name}.ops[{b_index}]: '{b}' with '{uncover}'") -def _copy_and_apply_ops(stack: Stack, *maybe_ops: ops.BaseOp | None) -> Stack: - stack = stack.clone() - for op in filter(None, maybe_ops): - op.accept(stack) - return stack - - def is_stack_swap(stack_before_op: Stack, op: ops.MemoryOp) -> bool: - teal_ops = op.accept(stack_before_op.clone()) + teal_ops = op.accept(stack_before_op.copy()) match teal_ops: case [teal.Cover(1) | teal.Uncover(1)]: return True return False -def is_virtual_op(stack_before_op: Stack, op: ops.MemoryOp) -> bool: - teal_ops = op.accept(stack_before_op.clone()) - match teal_ops: - case [teal.Cover(0) | teal.Uncover(0)]: - return True - return False - - def optimize_single(stack_before_a: Stack, a: ops.BaseOp) -> ops.BaseOp | None: - match a: - case ops.MemoryOp() as mem_a if is_virtual_op(stack_before_a, mem_a): - return ops.VirtualStackOp(mem_a) + if isinstance(a, ops.MemoryOp): + teal_ops = a.accept(stack_before_a.copy()) + match teal_ops: + case [teal.Cover(0) | teal.Uncover(0)]: + return ops.VirtualStackOp(a) return a def is_redundant_rotate( - stack_before_a: Stack, a: ops.MemoryOp, stack_before_b: Stack, b: ops.MemoryOp + stack_before_a: Stack, a: ops.MemoryOp, maybe_virtual: ops.BaseOp | None, b: ops.MemoryOp ) -> bool: - a_teal = a.accept(stack_before_a.clone()) - b_teal = b.accept(stack_before_b.clone()) - match a_teal, b_teal: - case [teal.Cover(n=a_n)], [teal.Uncover(n=b_n)] if a_n == b_n: + stack = stack_before_a.copy() + a_teal = a.accept(stack) + try: + (a_op,) = a_teal + except ValueError: + return False + + # optimization: the virtual op is applied here instead of outside optimize_pair + # as it is a hot path so deferring it until it is actually required saves some time + + if maybe_virtual: + maybe_virtual.accept(stack) + b_teal = b.accept(stack) + try: + (b_op,) = b_teal + except ValueError: + return False + match a_op, b_op: + case teal.Cover(n=a_n), teal.Uncover(n=b_n) if a_n == b_n: + return True + case teal.Uncover(n=a_n), teal.Cover(n=b_n) if a_n == b_n: return True - case [teal.Uncover(n=a_n)], [teal.Cover(n=b_n)] if a_n == b_n: + case teal.Cover(n=1), teal.Cover(n=1): + return True + case teal.Uncover(n=1), teal.Uncover(n=1): return True return False -COMMUTATIVE_OPS = { - "+", - "*", - "&", - "&&", - "|", - "||", - "^", - "==", - "!=", - "b*", - "b+", - "b&", - "b|", - "b^", - "b==", - "b!=", - "addw", - "mulw", -} +COMMUTATIVE_OPS = frozenset( + [ + "+", + "*", + "&", + "&&", + "|", + "||", + "^", + "==", + "!=", + "b*", + "b+", + "b&", + "b|", + "b^", + "b==", + "b!=", + "addw", + "mulw", + ] +) +ORDERING_OPS = frozenset(["<", "<=", ">", ">=", "b<", "b<=", "b>", "b>="]) def optimize_pair( - vla: VariableLifetimeAnalysis, - stack_before_a: Stack, + ctx: SubroutineCodeGenContext, + stack: Stack, # stack state before a a: ops.BaseOp, - stack_before_b: Stack, + maybe_virtual: ops.BaseOp | None, # represents virtual ops that may be between a and b b: ops.BaseOp, ) -> tuple[()] | tuple[ops.BaseOp] | tuple[ops.BaseOp, ops.BaseOp]: """Given a pair of ops, returns which ops should be kept including replacements""" - match a, b: - case ops.StoreLStack(copy=True) | ops.StoreXStack(copy=True) as cover, ops.StoreVirtual( - local_id=local_id - ) if local_id not in vla.get_live_out_variables( - b - ): # aka dead store removal, this should handle both x-stack and l-stack cases - # StoreLStack is used to: - # 1.) store a variable for retrieval later via a load - # 2.) store a copy at the bottom of the stack for use in a later op - # If it is a dead store, then the 1st scenario is no longer needed - # and instead just need to ensure the value is copied to the bottom of the stack - return (attrs.evolve(cover, copy=False),) - case _, ops.StoreVirtual(local_id=local_id) if local_id not in vla.get_live_out_variables( - b - ): # aka dead store removal - return a, ops.Pop(n=1, source_location=b.source_location) - case ops.LoadLStack(local_id=a_local_id, copy=False) as load, ops.StoreLStack( - local_id=b_local_id, copy=True - ) if a_local_id == b_local_id: - return (attrs.evolve(load, copy=True),) - case ops.LoadLStack(copy=False) as load, ops.StoreLStack( - copy=False - ) as store if is_redundant_rotate(stack_before_a, load, stack_before_b, store): - # loading and storing to the same spot in the same stack can be removed entirely if the - # local_id does not change - if load.local_id == store.local_id: - return () - # otherwise keep around as virtual stack op - else: - return ops.VirtualStackOp(load), ops.VirtualStackOp(store) - case ops.LoadOp() as mem_a, ops.Pop(n=1) as mem_b: - return ops.VirtualStackOp(mem_a), ops.VirtualStackOp(mem_b) - case ops.LoadXStack(local_id=a_local_id), ops.StoreXStack( - local_id=b_local_id, copy=False - ) if a_local_id == b_local_id: - return () - case ops.LoadFStack(local_id=a_local_id), ops.StoreFStack( - local_id=b_local_id - ) if a_local_id == b_local_id: - return () - case ops.MemoryOp() as a_mem, ops.MemoryOp() as b_mem if is_stack_swap( - stack_before_a, a_mem - ) and is_stack_swap(stack_before_b, b_mem): - return ops.VirtualStackOp(a_mem), ops.VirtualStackOp(b_mem) - case ops.MemoryOp() as a_mem, ops.MemoryOp() as b_mem if is_redundant_rotate( - stack_before_a, a_mem, stack_before_b, b_mem - ): - return ops.VirtualStackOp(a_mem), ops.VirtualStackOp(b_mem) - case ops.MemoryOp() as a_mem, ops.IntrinsicOp(op_code=op_code) if is_stack_swap( - stack_before_a, a_mem - ) and op_code in COMMUTATIVE_OPS: + # this function has been optimized to reduce the number of isinstance checks, + # consider this when making any modifications + + if isinstance(b, ops.StoreVirtual) and b.local_id not in ctx.vla.get_live_out_variables(b): + # aka dead store removal + match a: + case ops.StoreLStack(copy=True) | ops.StoreXStack(copy=True) as cover: + # this should handle both x-stack and l-stack cases StoreLStack is used to: + # 1.) store a variable for retrieval later via a load + # 2.) store a copy at the bottom of the stack for use in a later op + # If it is a dead store, then the 1st scenario is no longer needed + # and instead just need to ensure the value is moved to the bottom of the stack + return (attrs.evolve(cover, copy=False),) + return a, ops.Pop(n=1, source_location=b.source_location) + + # optimization: cases after here are only applicable if "a" is a MemoryOp + if not isinstance(a, ops.MemoryOp): + return a, b + + if isinstance(b, ops.Pop) and b.n == 1 and isinstance(a, ops.LoadOp): + return ops.VirtualStackOp(a), ops.VirtualStackOp(b) + + if isinstance(b, ops.IntrinsicOp) and is_stack_swap(stack, a): + if b.op_code in COMMUTATIVE_OPS: if isinstance(a, ops.LoadLStack | ops.StoreLStack): return (b,) - return ops.VirtualStackOp(a_mem), b - case ops.MemoryOp() as a_mem, ops.IntrinsicOp( - op_code=("<" | "<=" | ">" | ">=" | "b<" | "b<=" | "b>" | "b>=") as op_code - ) as binary_op if is_stack_swap(stack_before_a, a_mem): - inverse_ordering_op = invert_ordered_binary_op(op_code) - new_b = attrs.evolve(binary_op, op_code=inverse_ordering_op) + return ops.VirtualStackOp(a), b + elif b.op_code in ORDERING_OPS: + inverse_ordering_op = invert_ordered_binary_op(b.op_code) + new_b = attrs.evolve(b, op_code=inverse_ordering_op) if isinstance(a, ops.LoadLStack | ops.StoreLStack): return (new_b,) - return ops.VirtualStackOp(a_mem), new_b - case ( - ops.LoadVirtual() as load, - ops.StoreVirtual() as store, - ) if load.local_id == store.local_id: - return () - case ( - ops.StoreParam(copy=False) as store_param, - ops.LoadParam() as load_param, - ) if load_param.local_id == store_param.local_id: - # if we have a store to param and then read from param, - # we can reduce the program size byte 1 byte by copying - # and then storing instead - # i.e. frame_bury -x; frame_dig -x - # => dup; frame_bury -x - store_with_copy = attrs.evolve(store_param, copy=True) - return (store_with_copy,) + return ops.VirtualStackOp(a), new_b + + # optimization: cases after here are only applicable if "b" is a MemoryOp + if not isinstance(b, ops.MemoryOp): + return a, b + + if is_redundant_rotate(stack, a, maybe_virtual, b): + match a, b: + case ( + ops.LoadLStack(copy=False, local_id=a_local_id), + ops.StoreLStack(copy=False, local_id=b_local_id), + ) if a_local_id == b_local_id: + # loading and storing to the same spot in the same stack can be removed entirely + # if the local_id does not change + return () + # otherwise keep around as virtual stack op + return ops.VirtualStackOp(a), ops.VirtualStackOp(b) + + if isinstance(a, ops.LoadOp) and isinstance(b, ops.StoreOp): + if a.local_id == b.local_id: + match a, b: + case ops.LoadLStack(copy=False) as load, ops.StoreLStack(copy=True): + return (attrs.evolve(load, copy=True),) + case ops.LoadXStack(), ops.StoreXStack(copy=False): + return () + case ops.LoadFStack(), ops.StoreFStack(): + return () + case ops.LoadVirtual(), ops.StoreVirtual(): + return () + else: + match a, b: + case ( + ops.StoreParam(copy=False, local_id=a_local_id) as store_param, + ops.LoadParam(local_id=b_local_id), + ) if a_local_id == b_local_id: + # if we have a store to param and then read from param, + # we can reduce the program size byte 1 byte by copying + # and then storing instead + # i.e. frame_bury -x; frame_dig -x + # => dup; frame_bury -x + store_with_copy = attrs.evolve(store_param, copy=True) + return (store_with_copy,) return a, b @@ -305,7 +309,10 @@ def _merge_virtual_ops(maybe_virtuals: Sequence[ops.BaseOp]) -> Sequence[ops.Bas virtuals = list[ops.VirtualStackOp]() # final None will trigger merging any remaining virtuals for op in [*maybe_virtuals, None]: - if isinstance(op, ops.VirtualStackOp): # collect virtual ops + # note: uses type instead of isinstance because this is a + # hotspot as determined by profiling. VirtualStackOp + # has been annotated with @typing.final so that this is equivalent here + if type(op) is ops.VirtualStackOp: # collect virtual ops virtuals.append(op) continue if virtuals: # merge any existing virtuals if non-virtual found @@ -323,21 +330,31 @@ def _merge_virtual_ops(maybe_virtuals: Sequence[ops.BaseOp]) -> Sequence[ops.Bas return result -def peephole_optimization(subroutine: ops.MemorySubroutine) -> None: +def peephole_optimization(ctx: SubroutineCodeGenContext) -> None: # replace sequences of stack manipulations with shorter ones - vla = VariableLifetimeAnalysis.analyze(subroutine) - for block in subroutine.body: - while peephole_optimization_single(subroutine, vla, block): - pass + vla_modified = False + for block in ctx.subroutine.body: + while (result := peephole_optimization_single(ctx, block)) and result.modified: + vla_modified = vla_modified or result.vla_modified + vla_modified = vla_modified or result.vla_modified + if vla_modified: + ctx.invalidate_vla() + + +@attrs.define(kw_only=True) +class PeepholeResult: + modified: bool + vla_modified: bool def peephole_optimization_single( - subroutine: ops.MemorySubroutine, vla: VariableLifetimeAnalysis, block: ops.MemoryBasicBlock -) -> bool: + ctx: SubroutineCodeGenContext, block: ops.MemoryBasicBlock +) -> PeepholeResult: result = list[ops.BaseOp]() op_iter = ManualIter(block.ops) b = op_iter.next() - stack = Stack.for_full_stack(subroutine, block) + stack = Stack.for_full_stack(ctx.subroutine, block) + vla_modified = False while b: a = b b = op_iter.next() @@ -357,14 +374,15 @@ def peephole_optimization_single( b = op_iter.next() # if b: - stack_before_a = stack - stack_before_b = _copy_and_apply_ops(stack_before_a, a, maybe_virtual) - ops_to_keep: Sequence[ops.BaseOp] = optimize_pair( - vla, stack_before_a, a, stack_before_b, b - ) + ops_to_keep: Sequence[ops.BaseOp] = optimize_pair(ctx, stack, a, maybe_virtual, b) else: ops_to_keep = (a,) - + if ( + not vla_modified + and (a not in ops_to_keep and isinstance(a, ops.StoreVirtual | ops.LoadVirtual)) + or (b not in ops_to_keep and isinstance(b, ops.StoreVirtual | ops.LoadVirtual)) + ): + vla_modified = True # based on peephole optimization result, insert virtual op if maybe_virtual is not None: if len(ops_to_keep) == 2: @@ -390,12 +408,11 @@ def peephole_optimization_single( before = block.ops block.ops = result - return before != result + return PeepholeResult(modified=before != result, vla_modified=vla_modified) -def koopmans(_context: ProgramCodeGenContext, subroutine: ops.MemorySubroutine) -> None: - peephole_optimization(subroutine) - for block in subroutine.body: +def koopmans(ctx: SubroutineCodeGenContext) -> None: + for block in ctx.subroutine.body: usage_pairs = find_usage_pairs(block) - copy_usage_pairs(subroutine, block, usage_pairs) - peephole_optimization(subroutine) + copy_usage_pairs(ctx.subroutine, block, usage_pairs) + peephole_optimization(ctx) diff --git a/src/puya/codegen/stack_simplify_teal.py b/src/puya/codegen/stack_simplify_teal.py index f88163722e..ef4fa1241f 100644 --- a/src/puya/codegen/stack_simplify_teal.py +++ b/src/puya/codegen/stack_simplify_teal.py @@ -9,10 +9,9 @@ import structlog from puya.codegen import ops, teal -from puya.codegen.context import ProgramCodeGenContext +from puya.codegen.context import SubroutineCodeGenContext from puya.codegen.stack import Stack from puya.codegen.stack_koopmans import peephole_optimization_single -from puya.codegen.vla import VariableLifetimeAnalysis from puya.errors import InternalError logger = structlog.get_logger(__name__) @@ -510,14 +509,15 @@ def try_simplify_rotation_ops( maybe_remove_rotations = [] -def simplify_teal_ops(context: ProgramCodeGenContext, subroutine: ops.MemorySubroutine) -> None: - vla = VariableLifetimeAnalysis.analyze(subroutine) +def simplify_teal_ops(ctx: SubroutineCodeGenContext) -> None: + subroutine = ctx.subroutine for block in subroutine.body: modified = True while modified: - modified = peephole_optimization_single(subroutine, vla, block) + result = peephole_optimization_single(ctx, block) + modified = result.modified modified = modified or try_simplify_repeated_ops(subroutine, block) modified = modified or try_simplify_pairwise_ops(subroutine, block) modified = modified or try_simplify_triple_ops(subroutine, block) - if context.options.optimization_level >= 2: + if ctx.options.optimization_level >= 2: try_simplify_rotation_ops(subroutine, block) diff --git a/src/puya/codegen/vla.py b/src/puya/codegen/vla.py index 107736d543..a7cb30eb82 100644 --- a/src/puya/codegen/vla.py +++ b/src/puya/codegen/vla.py @@ -14,7 +14,8 @@ class _OpLifetime: block: ops.MemoryBasicBlock used: StableSet[str] = attrs.field(on_setattr=attrs.setters.frozen) defined: StableSet[str] = attrs.field(on_setattr=attrs.setters.frozen) - successors: Sequence[ops.BaseOp] = attrs.field(on_setattr=attrs.setters.frozen) + successors: "Sequence[_OpLifetime]" = attrs.field(default=()) + predecessors: "list[_OpLifetime]" = attrs.field(factory=list) live_in: StableSet[str] = attrs.field(factory=StableSet) live_out: StableSet[str] = attrs.field(factory=StableSet) @@ -39,24 +40,28 @@ def _op_lifetimes_factory(self) -> dict[ops.BaseOp, _OpLifetime]: result = dict[ops.BaseOp, _OpLifetime]() block_map = {b.block_name: b.ops[0] for b in self.subroutine.body} for block in self.subroutine.all_blocks: - for op, next_op in itertools.zip_longest(block.ops, block.ops[1:]): + for op in block.ops: used = StableSet[str]() defined = StableSet[str]() if isinstance(op, ops.StoreVirtual): defined.add(op.local_id) elif isinstance(op, ops.LoadVirtual): used.add(op.local_id) - if next_op is None: - # for last op, add first op of each successor block - successors = [block_map[s] for s in block.successors] - else: - successors = [next_op] result[op] = _OpLifetime( block=block, used=used, defined=defined, - successors=successors, ) + for block in self.subroutine.all_blocks: + for op, next_op in itertools.zip_longest(block.ops, block.ops[1:]): + op_lifetime = result[op] + if next_op is None: + # for last op, add first op of each successor block + op_lifetime.successors = tuple(result[block_map[s]] for s in block.successors) + else: + op_lifetime.successors = (result[next_op],) + for s in op_lifetime.successors: + s.predecessors.append(op_lifetime) return result def get_live_out_variables(self, op: ops.BaseOp) -> Set[str]: @@ -78,22 +83,23 @@ def analyze(cls, subroutine: ops.MemorySubroutine) -> typing.Self: return analysis def _analyze(self) -> None: - changes = True - while changes: - changes = False - for n in self._op_lifetimes.values(): + changed = list(self._op_lifetimes.values()) + while changed: + orig_changed = changed + changed = [] + for n in orig_changed: # For OUT, find out the union of previous variables # in the IN set for each succeeding node of n. # out[n] = U s ∈ succ[n] in[s] live_out = StableSet[str]() for s in n.successors: - live_out |= self._op_lifetimes[s].live_in + live_out |= s.live_in # in[n] = use[n] U (out[n] - def [n]) live_in = n.used | (live_out - n.defined) - if not (live_in == n.live_in and live_out == n.live_out): + if live_out != n.live_out or live_in != n.live_in: n.live_in = live_in n.live_out = live_out - changes = True + changed.extend(n.predecessors) diff --git a/src/puya/utils.py b/src/puya/utils.py index 8991324518..cf41821ee7 100644 --- a/src/puya/utils.py +++ b/src/puya/utils.py @@ -69,29 +69,31 @@ def unique(items: Iterable[T]) -> list[T]: class StableSet(MutableSet[T]): + __slots__ = ("_data",) + def __init__(self, *items: T) -> None: self._data = dict.fromkeys(items) def __eq__(self, other: object) -> bool: if isinstance(other, StableSet): - return self._data == other._data + return self._data.__eq__(other._data) else: return self._data.keys() == other def __ne__(self, other: object) -> bool: if isinstance(other, StableSet): - return self._data != other._data + return self._data.__ne__(other._data) else: return self._data.keys() != other def __contains__(self, x: object) -> bool: - return x in self._data + return self._data.__contains__(x) def __len__(self) -> int: - return len(self._data) + return self._data.__len__() def __iter__(self) -> Iterator[T]: - yield from self._data.keys() + return self._data.__iter__() def add(self, value: T) -> None: self._data[value] = None @@ -100,24 +102,30 @@ def discard(self, value: T) -> None: self._data.pop(value, None) def __or__(self, other: Iterable[T]) -> "StableSet[T]": # type: ignore[override] + result = StableSet.__new__(StableSet) if isinstance(other, StableSet): - return StableSet(*self._data, *other._data) + other_data = other._data else: - return StableSet(*self._data, *other) + other_data = dict.fromkeys(other) + result._data = self._data | other_data + return result def __ior__(self, other: Iterable[T]) -> Self: # type: ignore[override] if isinstance(other, StableSet): other_data = other._data else: other_data = dict.fromkeys(other) - self._data.update(other_data) + self._data |= other_data return self def __sub__(self, other: Set[T]) -> "StableSet[T]": + result = StableSet.__new__(StableSet) if isinstance(other, StableSet): - return StableSet(*(self._data.keys() - other._data.keys())) - data = (k for k in self._data if k not in other) - return StableSet(*data) + data: Iterable[T] = self._data.keys() - other._data.keys() + else: + data = (k for k in self._data if k not in other) + result._data = dict.fromkeys(data) + return result def __repr__(self) -> str: return type(self).__name__ + "(" + ", ".join(map(repr, self._data)) + ")"