diff --git a/src/puya/awst_build/eb/_utils.py b/src/puya/awst_build/eb/_utils.py index 67fcc79072..62dcfe2beb 100644 --- a/src/puya/awst_build/eb/_utils.py +++ b/src/puya/awst_build/eb/_utils.py @@ -10,7 +10,14 @@ Expression, ReinterpretCast, ) -from puya.awst_build.eb.interface import BuilderComparisonOp, InstanceBuilder +from puya.awst_build.eb.interface import ( + BuilderBinaryOp, + BuilderComparisonOp, + BuilderUnaryOp, + InstanceBuilder, + LiteralBuilder, +) +from puya.awst_build.eb.uint64 import UInt64TypeBuilder if typing.TYPE_CHECKING: from puya.parse import SourceLocation @@ -18,6 +25,21 @@ logger = log.get_logger(__name__) +def resolve_negative_literal_index( + index: InstanceBuilder, length: InstanceBuilder, location: SourceLocation +) -> InstanceBuilder: + match index: + case LiteralBuilder(value=int(int_index)) if int_index < 0: + return length.binary_op( + index.unary_op(BuilderUnaryOp.negative, location), + BuilderBinaryOp.sub, + location, + reverse=False, + ) + case _: + return index.resolve_literal(UInt64TypeBuilder(index.source_location)) + + def bool_eval_to_constant( *, value: bool, location: SourceLocation, negate: bool = False ) -> InstanceBuilder: diff --git a/src/puya/awst_build/eb/arc4/arrays.py b/src/puya/awst_build/eb/arc4/arrays.py index 58ec88d22d..cbefb032ac 100644 --- a/src/puya/awst_build/eb/arc4/arrays.py +++ b/src/puya/awst_build/eb/arc4/arrays.py @@ -23,8 +23,6 @@ SingleEvaluation, Statement, TupleExpression, - UInt64BinaryOperation, - UInt64BinaryOperator, UInt64Constant, ) from puya.awst_build import intrinsic_factory, pytypes @@ -33,7 +31,12 @@ BytesBackedInstanceExpressionBuilder, BytesBackedTypeBuilder, ) -from puya.awst_build.eb._utils import bool_eval_to_constant, compare_bytes, compare_expr_bytes +from puya.awst_build.eb._utils import ( + bool_eval_to_constant, + compare_bytes, + compare_expr_bytes, + resolve_negative_literal_index, +) from puya.awst_build.eb.arc4.base import CopyBuilder, arc4_bool_bytes from puya.awst_build.eb.factories import builder_for_instance from puya.awst_build.eb.interface import ( @@ -48,7 +51,6 @@ from puya.awst_build.eb.uint64 import UInt64ExpressionBuilder from puya.awst_build.eb.void import VoidExpressionBuilder from puya.awst_build.utils import ( - expect_operand_type, require_instance_builder, require_instance_builder_of_type, ) @@ -254,22 +256,13 @@ def iterate(self) -> Iteration: @typing.override def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBuilder: - if isinstance(index, LiteralBuilder) and isinstance(index.value, int) and index.value < 0: - index_expr: Expression = UInt64BinaryOperation( - left=require_instance_builder( - self.member_access("length", index.source_location) - ).resolve(), - op=UInt64BinaryOperator.sub, - right=UInt64Constant( - value=abs(index.value), source_location=index.source_location - ), - source_location=index.source_location, - ) - else: - index_expr = expect_operand_type(index, pytypes.UInt64Type).resolve() + array_length = require_instance_builder( + self.member_access("length", index.source_location) + ) + index = resolve_negative_literal_index(index, array_length, location) result_expr = IndexExpression( base=self.resolve(), - index=index_expr, + index=index.resolve(), wtype=self.pytype.items.wtype, source_location=location, ) diff --git a/src/puya/awst_build/eb/bytes.py b/src/puya/awst_build/eb/bytes.py index ecee65a72f..8304673494 100644 --- a/src/puya/awst_build/eb/bytes.py +++ b/src/puya/awst_build/eb/bytes.py @@ -30,7 +30,7 @@ InstanceExpressionBuilder, TypeBuilder, ) -from puya.awst_build.eb._utils import compare_bytes +from puya.awst_build.eb._utils import compare_bytes, resolve_negative_literal_index from puya.awst_build.eb.bool import BoolExpressionBuilder from puya.awst_build.eb.interface import ( BuilderBinaryOp, @@ -175,20 +175,21 @@ def to_bytes(self, location: SourceLocation) -> Expression: return self.resolve() @typing.override - def member_access(self, name: str, location: SourceLocation) -> NodeBuilder: + def member_access(self, name: str, location: SourceLocation) -> InstanceBuilder: match name: case "length": len_call = intrinsic_factory.bytes_len(expr=self.resolve(), loc=location) return UInt64ExpressionBuilder(len_call) - return super().member_access(name, location) + raise CodeError(f"unrecognised member of {self.pytype}: {name}", location) @typing.override def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBuilder: - index_expr = expect_operand_type(index, pytypes.UInt64Type).resolve() + length = self.member_access("length", location) + index = resolve_negative_literal_index(index, length, location) expr = IndexExpression( source_location=location, base=self.resolve(), - index=index_expr, + index=index.resolve(), wtype=self.pytype.wtype, ) return BytesExpressionBuilder(expr) diff --git a/src/puya/awst_build/eb/storage/_util.py b/src/puya/awst_build/eb/storage/_util.py index 486d11dfb1..9a13194600 100644 --- a/src/puya/awst_build/eb/storage/_util.py +++ b/src/puya/awst_build/eb/storage/_util.py @@ -19,6 +19,7 @@ ) from puya.awst_build import intrinsic_factory, pytypes from puya.awst_build.contract_data import AppStorageDeclaration +from puya.awst_build.eb._utils import resolve_negative_literal_index from puya.awst_build.eb.bytes import BytesExpressionBuilder from puya.awst_build.eb.interface import ( BuilderBinaryOp, @@ -38,26 +39,14 @@ def index_box_bytes( index: InstanceBuilder, location: SourceLocation, ) -> InstanceBuilder: - - if isinstance(index, InstanceBuilder): - # no negatives - begin_index_expr = index.resolve() - elif not isinstance(index.value, int): - raise CodeError("Invalid literal index type", index.source_location) - elif index.value >= 0: - begin_index_expr = UInt64Constant(value=index.value, source_location=index.source_location) - else: - box_length = box_length_unchecked(box, location) - box_length_builder = UInt64ExpressionBuilder(box_length) - begin_index_expr = box_length_builder.binary_op( - index, BuilderBinaryOp.sub, location, reverse=False - ).resolve() + box_length = UInt64ExpressionBuilder(box_length_unchecked(box, location)) + begin_index = resolve_negative_literal_index(index, box_length, location) return BytesExpressionBuilder( IntrinsicCall( op_code="box_extract", stack_args=[ box.key, - begin_index_expr, + begin_index.resolve(), UInt64Constant(value=1, source_location=location), ], source_location=location,