From 359956c9e216222388aa10a001dbbd92305bcb35 Mon Sep 17 00:00:00 2001 From: Daniel McGregor Date: Tue, 12 Nov 2024 14:15:37 +0800 Subject: [PATCH] fix: ensure expressions are only evaluated once --- src/puya/ir/builder/main.py | 67 ++++++++++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/src/puya/ir/builder/main.py b/src/puya/ir/builder/main.py index 31a1dd60eb..20f5d409ff 100644 --- a/src/puya/ir/builder/main.py +++ b/src/puya/ir/builder/main.py @@ -12,6 +12,7 @@ wtypes, ) from puya.awst.nodes import BigUIntBinaryOperator, UInt64BinaryOperator +from puya.awst.to_code_visitor import ToCodeVisitor from puya.awst.txn_fields import TxnField from puya.awst.wtypes import WInnerTransaction, WInnerTransactionFields from puya.errors import CodeError, InternalError @@ -92,6 +93,7 @@ def __init__( self.context = context.for_function(function, subroutine, self) self._itxn = InnerTransactionBuilder(self.context) self._single_eval_cache = dict[awst_nodes.SingleEvaluation, TExpression]() + self._visited_exprs = dict[_IdentityEquality, TExpression]() @classmethod def build_body( @@ -1022,7 +1024,7 @@ def visit_loop_continue(self, statement: awst_nodes.LoopContinue) -> TStatement: def visit_expression_statement(self, statement: awst_nodes.ExpressionStatement) -> TStatement: # NOTE: popping of ignored return values should happen at code gen time - result = statement.expr.accept(self) + result = self._visit_and_check_for_double_eval(statement.expr) if result is None: wtype = statement.expr.wtype match wtype: @@ -1195,18 +1197,61 @@ def visit_and_materialise_single( def visit_and_materialise( self, expr: awst_nodes.Expression, temp_description: str | Sequence[str] = "tmp" ) -> Sequence[Value]: - value_provider = self.visit_expr(expr) - return self.materialise_value_provider(value_provider, description=temp_description) + value_seq_or_provider = self._visit_and_check_for_double_eval(expr, temp_description) + if value_seq_or_provider is None: + raise InternalError( + "No value produced by expression IR conversion", expr.source_location + ) + return self.materialise_value_provider(value_seq_or_provider, description=temp_description) def visit_expr(self, expr: awst_nodes.Expression) -> ValueProvider: """Visit the expression and ensure result is not None""" - value_seq_or_provider = expr.accept(self) + value_seq_or_provider = self._visit_and_check_for_double_eval(expr) if value_seq_or_provider is None: raise InternalError( "No value produced by expression IR conversion", expr.source_location ) return value_seq_or_provider + def _visit_and_check_for_double_eval( + self, expr: awst_nodes.Expression, desc: str | Sequence[str] | None = None + ) -> ValueProvider | None: + # explicit SingleEvaluation nodes already handle this + if isinstance(expr, awst_nodes.SingleEvaluation): + return expr.accept(self) + # we need a wrapper object to provide identity equality for the expression, and + # to ensure the lifetime of the expression is as long as the cache. + # Temporary nodes may end up with the same id if nothing is referencing them + # e.g. such as used in _update_implicit_out_var + expr_id = _IdentityEquality(expr) + try: + result = self._visited_exprs[expr_id] + except KeyError: + pass + else: + if isinstance(result, ValueProvider) and not isinstance(result, ValueTuple | Value): + raise InternalError( + "double evaluation of expression without materialization", expr.source_location + ) + expr_str = expr.accept(ToCodeVisitor()) + logger.debug( + f"encountered already materialized expression ({expr_str})," + f" reusing result: {result!s}", + location=expr.source_location, + ) + return result + source = expr.accept(self) + if desc is None or source is None or not source.types or isinstance(source, Value): + result = source + else: + values = self.materialise_value_provider(source, description=desc) + if len(values) == 1: + (result,) = values + else: + result = ValueTuple(values=values, source_location=expr.source_location) + self._visited_exprs[expr_id] = result + return result + def materialise_value_provider( self, provider: ValueProvider, description: str | Sequence[str] ) -> list[Value]: @@ -1337,3 +1382,17 @@ def get_comparison_op_for_wtype( raise InternalError( f"unsupported operation of {numeric_comparison_equivalent} on type of {wtype}" ) + + +class _IdentityEquality: + + def __init__(self, obj: object): + self.obj = obj + + def __hash__(self) -> int: + return id(self.obj) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _IdentityEquality): + return NotImplemented + return self.obj is other.obj