Skip to content

Commit

Permalink
fix: ensure expressions are only evaluated once
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-makerx committed Nov 14, 2024
1 parent dbf6e2e commit 359956c
Showing 1 changed file with 63 additions and 4 deletions.
67 changes: 63 additions & 4 deletions src/puya/ir/builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
wtypes,
)
from puya.awst.nodes import BigUIntBinaryOperator, UInt64BinaryOperator
from puya.awst.to_code_visitor import ToCodeVisitor
from puya.awst.txn_fields import TxnField
from puya.awst.wtypes import WInnerTransaction, WInnerTransactionFields
from puya.errors import CodeError, InternalError
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
self.context = context.for_function(function, subroutine, self)
self._itxn = InnerTransactionBuilder(self.context)
self._single_eval_cache = dict[awst_nodes.SingleEvaluation, TExpression]()
self._visited_exprs = dict[_IdentityEquality, TExpression]()

@classmethod
def build_body(
Expand Down Expand Up @@ -1022,7 +1024,7 @@ def visit_loop_continue(self, statement: awst_nodes.LoopContinue) -> TStatement:

def visit_expression_statement(self, statement: awst_nodes.ExpressionStatement) -> TStatement:
# NOTE: popping of ignored return values should happen at code gen time
result = statement.expr.accept(self)
result = self._visit_and_check_for_double_eval(statement.expr)
if result is None:
wtype = statement.expr.wtype
match wtype:
Expand Down Expand Up @@ -1195,18 +1197,61 @@ def visit_and_materialise_single(
def visit_and_materialise(
self, expr: awst_nodes.Expression, temp_description: str | Sequence[str] = "tmp"
) -> Sequence[Value]:
value_provider = self.visit_expr(expr)
return self.materialise_value_provider(value_provider, description=temp_description)
value_seq_or_provider = self._visit_and_check_for_double_eval(expr, temp_description)
if value_seq_or_provider is None:
raise InternalError(
"No value produced by expression IR conversion", expr.source_location
)
return self.materialise_value_provider(value_seq_or_provider, description=temp_description)

def visit_expr(self, expr: awst_nodes.Expression) -> ValueProvider:
"""Visit the expression and ensure result is not None"""
value_seq_or_provider = expr.accept(self)
value_seq_or_provider = self._visit_and_check_for_double_eval(expr)
if value_seq_or_provider is None:
raise InternalError(
"No value produced by expression IR conversion", expr.source_location
)
return value_seq_or_provider

def _visit_and_check_for_double_eval(
self, expr: awst_nodes.Expression, desc: str | Sequence[str] | None = None
) -> ValueProvider | None:
# explicit SingleEvaluation nodes already handle this
if isinstance(expr, awst_nodes.SingleEvaluation):
return expr.accept(self)
# we need a wrapper object to provide identity equality for the expression, and
# to ensure the lifetime of the expression is as long as the cache.
# Temporary nodes may end up with the same id if nothing is referencing them
# e.g. such as used in _update_implicit_out_var
expr_id = _IdentityEquality(expr)
try:
result = self._visited_exprs[expr_id]
except KeyError:
pass
else:
if isinstance(result, ValueProvider) and not isinstance(result, ValueTuple | Value):
raise InternalError(
"double evaluation of expression without materialization", expr.source_location
)
expr_str = expr.accept(ToCodeVisitor())
logger.debug(
f"encountered already materialized expression ({expr_str}),"
f" reusing result: {result!s}",
location=expr.source_location,
)
return result
source = expr.accept(self)
if desc is None or source is None or not source.types or isinstance(source, Value):
result = source
else:
values = self.materialise_value_provider(source, description=desc)
if len(values) == 1:
(result,) = values
else:
result = ValueTuple(values=values, source_location=expr.source_location)
self._visited_exprs[expr_id] = result
return result

def materialise_value_provider(
self, provider: ValueProvider, description: str | Sequence[str]
) -> list[Value]:
Expand Down Expand Up @@ -1337,3 +1382,17 @@ def get_comparison_op_for_wtype(
raise InternalError(
f"unsupported operation of {numeric_comparison_equivalent} on type of {wtype}"
)


class _IdentityEquality:

def __init__(self, obj: object):
self.obj = obj

def __hash__(self) -> int:
return id(self.obj)

def __eq__(self, other: object) -> bool:
if not isinstance(other, _IdentityEquality):
return NotImplemented
return self.obj is other.obj

0 comments on commit 359956c

Please sign in to comment.