Skip to content

Commit

Permalink
feat: pop/popn collapse optimisation
Browse files Browse the repository at this point in the history
refactor: make checking for ABI router only calls (in the context of implicit return elisions) more robust

refactor: other non-functional refactorings of the implicit return feature

test: ensure there's a multi-valued explicit return combined with implicit returns being tested
  • Loading branch information
achidlow authored and tristanmenzel committed Mar 8, 2024
1 parent 7a8f949 commit 0b90505
Show file tree
Hide file tree
Showing 60 changed files with 1,748 additions and 1,594 deletions.
2 changes: 1 addition & 1 deletion examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ arc4_types/Arc4Arrays 588 376 376
arc4_types/Arc4BoolEval 569 20 20
arc4_types/Arc4BoolType 329 57 57
arc4_types/Arc4DynamicStringArray 230 112 112
arc4_types/Arc4MutableParams 344 213 211
arc4_types/Arc4MutableParams 362 222 220
arc4_types/Arc4Mutation 2803 1452 1451
arc4_types/Arc4NumericTypes 345 8 8
arc4_types/Arc4RefTypes 47 43 43
Expand Down
2 changes: 1 addition & 1 deletion examples/voting/out/VotingRoundApp.destructured.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions examples/voting/out/VotingRoundApp.ssa.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/voting/out/VotingRoundApp.ssa.opt_pass_1.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/voting/out/VotingRoundApp.ssa.opt_pass_2.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/voting/out_O2/VotingRoundApp.destructured.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions examples/voting/puya.log

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/puya/awst_build/validation/arc4_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def visit_assignment_statement(self, statement: awst_nodes.AssignmentStatement)

def _check_for_arc4_copy(self, expr: awst_nodes.Expression) -> None:
match expr.wtype:
case wtypes.ARC4Array(immutable=False) | wtypes.ARC4Struct(immutable=False):
case wtypes.ARC4Type(immutable=False):
match expr:
case (
awst_nodes.ARC4ArrayEncode()
Expand Down
8 changes: 5 additions & 3 deletions src/puya/ir/builder/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ControlOp,
Goto,
Op,
Parameter,
Register,
SubroutineReturn,
)
Expand Down Expand Up @@ -129,13 +130,14 @@ def goto_and_activate(self, block: BasicBlock) -> None:
self.goto(block)
self.activate_block(block)

def maybe_add_implicit_subroutine_return(self, implicit_returns: Sequence[Register]) -> None:
def maybe_add_implicit_subroutine_return(self, params: Sequence[Parameter]) -> None:
if not self._blocks[-1].terminated:
self.terminate(
SubroutineReturn(
result=[
self.ssa.read_variable(r.name, r.atype, self._blocks[-1])
for r in implicit_returns
self.ssa.read_variable(p.name, p.atype, self._blocks[-1])
for p in params
if p.implicit_return
],
source_location=None,
)
Expand Down
95 changes: 95 additions & 0 deletions src/puya/ir/builder/callsub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from collections.abc import Sequence

import attrs

from puya.awst import (
nodes as awst_nodes,
wtypes,
)
from puya.ir.builder._utils import reassign
from puya.ir.context import IRFunctionBuildContext
from puya.ir.models import InvokeSubroutine, Register, Value, ValueProvider, ValueTuple
from puya.ir.utils import format_tuple_index


def visit_subroutine_call_expression(
context: IRFunctionBuildContext, expr: awst_nodes.SubroutineCallExpression
) -> ValueProvider | None:
sref = context.resolve_function_reference(expr.target, expr.source_location)
target = context.subroutines[sref]

arg_lookup = _build_arg_lookup(context, expr.args)

resolved_args = []
implicit_args = []
for idx, param in enumerate(target.parameters):
arg_val = arg_lookup.get(index=idx, name=param.name)
resolved_args.append(arg_val)
if param.implicit_return:
implicit_args.append(arg_val)

invoke_expr = InvokeSubroutine(
source_location=expr.source_location, args=resolved_args, target=target
)
if not implicit_args:
return invoke_expr

return_values = list(
context.visitor.materialise_value_provider(invoke_expr, target.method_name)
)
while implicit_args:
in_arg = implicit_args.pop()
out_register = return_values.pop()
if isinstance(in_arg, Register):
reassign(
context,
source=out_register,
reg=in_arg,
source_location=expr.source_location,
)

return (
ValueTuple(values=return_values, source_location=expr.source_location)
if return_values
else None
)


@attrs.define
class _ArgLookup:
_positional_args: dict[int, Value] = attrs.field(factory=dict, init=False)
_named_args: dict[str, Value] = attrs.field(factory=dict, init=False)
_arg_idx: int = attrs.field(default=0, init=False)

def add(self, name: str | None, value: Value) -> None:
if name is None:
self._positional_args[self._arg_idx] = value
else:
self._named_args[name] = value
self._arg_idx += 1

def get(self, index: int, name: str | None) -> Value:
if name is not None:
by_name = self._named_args.get(name)
if by_name is not None:
return by_name
return self._positional_args[index]


def _build_arg_lookup(
context: IRFunctionBuildContext, args: Sequence[awst_nodes.CallArg]
) -> _ArgLookup:
lookup = _ArgLookup()
for expr_arg in args:
if not isinstance(expr_arg.value.wtype, wtypes.WTuple):
value = context.visitor.visit_and_materialise_single(expr_arg.value)
lookup.add(name=expr_arg.name, value=value)
else:
values = context.visitor.visit_and_materialise(expr_arg.value)
for tup_idx, tup_value in enumerate(values):
if expr_arg.name is None:
tup_item_name = None
else:
tup_item_name = format_tuple_index(expr_arg.name, tup_idx)
lookup.add(name=tup_item_name, value=tup_value)
return lookup
77 changes: 12 additions & 65 deletions src/puya/ir/builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
assert_value,
assign,
mkblocks,
reassign,
)
from puya.ir.builder.assignment import handle_assignment, handle_assignment_expr
from puya.ir.builder.callsub import visit_subroutine_call_expression
from puya.ir.builder.iteration import handle_for_in_loop
from puya.ir.builder.itxn import InnerTransactionBuilder
from puya.ir.context import IRBuildContext, IRFunctionBuildContext
Expand All @@ -37,7 +37,6 @@
MethodConstant,
Op,
ProgramExit,
Register,
Subroutine,
SubroutineReturn,
UInt64Constant,
Expand Down Expand Up @@ -86,9 +85,7 @@ def build_body(
insert_on_create_call(func_ctx, to=on_create)
function.body.accept(builder)
if function.return_type == wtypes.void_wtype:
func_ctx.block_builder.maybe_add_implicit_subroutine_return(
subroutine.implicit_returns
)
func_ctx.block_builder.maybe_add_implicit_subroutine_return(subroutine.parameters)
func_ctx.ssa.verify_complete()
func_ctx.block_builder.validate_block_predecessors()
result = list(func_ctx.block_builder.blocks)
Expand All @@ -108,7 +105,7 @@ def visit_copy(self, expr: puya.awst.nodes.Copy) -> TExpression:
# will effectively be a copy. We assign the copy to a new register in case it is
# mutated.
match expr.value.wtype:
case wtypes.ARC4Array() | wtypes.ARC4Struct():
case wtypes.ARC4Type(immutable=False):
# Arc4 encoded types are value types
original_value = self.visit_and_materialise_single(expr.value)
(copy,) = assign(
Expand Down Expand Up @@ -540,58 +537,7 @@ def visit_bytes_comparison_expression(
def visit_subroutine_call_expression(
self, expr: awst_nodes.SubroutineCallExpression
) -> TExpression:
sref = self.context.resolve_function_reference(expr.target, expr.source_location)
target = self.context.subroutines[sref]
# TODO: what if args are multi-valued?
args_expanded = list[tuple[str | None, Value]]()
for expr_arg in expr.args:
if not isinstance(expr_arg.value.wtype, wtypes.WTuple):
arg = self.visit_and_materialise_single(expr_arg.value)
args_expanded.append((expr_arg.name, arg))
else:
tup_args = self.visit_and_materialise(expr_arg.value)
for tup_idx, tup_arg in enumerate(tup_args):
if expr_arg.name is None:
tup_name: str | None = None
else:
tup_name = format_tuple_index(expr_arg.name, tup_idx)
args_expanded.append((tup_name, tup_arg))
target_name_to_index = {par.name: idx for idx, par in enumerate(target.parameters)}
resolved_args = [val for name, val in args_expanded]
for name, val in args_expanded:
if name is not None:
name_idx = target_name_to_index[name]
resolved_args[name_idx] = val
invoke_expr = InvokeSubroutine(
source_location=expr.source_location, args=resolved_args, target=target
)
if not target.implicit_returns:
return invoke_expr

return_values = self.materialise_value_provider(invoke_expr, "r_tmp")

for value, register in zip(
return_values[-len(target.implicit_returns) :], target.implicit_returns, strict=True
):
arg_index = target_name_to_index[register.name]
arg_value = resolved_args[arg_index]
if isinstance(arg_value, Register):
reassign(
self.context,
source_location=expr.source_location,
source=value,
reg=arg_value,
)

explicit_return_values = list(return_values[0 : -len(target.implicit_returns)])
return (
ValueTuple(
values=explicit_return_values,
source_location=expr.source_location,
)
if explicit_return_values
else None
)
return visit_subroutine_call_expression(self.context, expr)

def visit_bytes_binary_operation(self, expr: awst_nodes.BytesBinaryOperation) -> TExpression:
left = self.visit_and_materialise_single(expr.left)
Expand Down Expand Up @@ -743,14 +689,15 @@ def visit_return_statement(self, statement: awst_nodes.ReturnStatement) -> TStat
else:
result = []

for implicit_return in self.context.subroutine.implicit_returns:
result.append(
self.context.ssa.read_variable(
implicit_return.name,
implicit_return.atype,
self.context.block_builder.active_block,
for param in self.context.subroutine.parameters:
if param.implicit_return:
result.append(
self.context.ssa.read_variable(
param.name,
param.atype,
self.context.block_builder.active_block,
)
)
)
return_types = [r.atype for r in result]
if not (
len(return_types) == len(self.context.subroutine.returns)
Expand Down
1 change: 0 additions & 1 deletion src/puya/ir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
class IRBuildContext(CompileContext):
module_awsts: Mapping[str, awst_nodes.Module]
subroutines: dict[awst_nodes.Function, Subroutine]
function_call_sites: dict[awst_nodes.Function, list[awst_nodes.Function]]
embedded_funcs: Sequence[awst_nodes.Function] = attrs.field()
contract: awst_nodes.ContractFragment | None = None

Expand Down
Loading

0 comments on commit 0b90505

Please sign in to comment.