Skip to content

Commit

Permalink
fix: support negative indexes on indexable types
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-makerx authored and achidlow committed Jun 25, 2024
1 parent 11aeff8 commit 9213996
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 39 deletions.
24 changes: 23 additions & 1 deletion src/puya/awst_build/eb/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,36 @@
Expression,
ReinterpretCast,
)
from puya.awst_build.eb.interface import BuilderComparisonOp, InstanceBuilder
from puya.awst_build.eb.interface import (
BuilderBinaryOp,
BuilderComparisonOp,
BuilderUnaryOp,
InstanceBuilder,
LiteralBuilder,
)
from puya.awst_build.eb.uint64 import UInt64TypeBuilder

if typing.TYPE_CHECKING:
from puya.parse import SourceLocation

logger = log.get_logger(__name__)


def resolve_negative_literal_index(
index: InstanceBuilder, length: InstanceBuilder, location: SourceLocation
) -> InstanceBuilder:
match index:
case LiteralBuilder(value=int(int_index)) if int_index < 0:
return length.binary_op(
index.unary_op(BuilderUnaryOp.negative, location),
BuilderBinaryOp.sub,
location,
reverse=False,
)
case _:
return index.resolve_literal(UInt64TypeBuilder(index.source_location))


def bool_eval_to_constant(
*, value: bool, location: SourceLocation, negate: bool = False
) -> InstanceBuilder:
Expand Down
29 changes: 11 additions & 18 deletions src/puya/awst_build/eb/arc4/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
SingleEvaluation,
Statement,
TupleExpression,
UInt64BinaryOperation,
UInt64BinaryOperator,
UInt64Constant,
)
from puya.awst_build import intrinsic_factory, pytypes
Expand All @@ -33,7 +31,12 @@
BytesBackedInstanceExpressionBuilder,
BytesBackedTypeBuilder,
)
from puya.awst_build.eb._utils import bool_eval_to_constant, compare_bytes, compare_expr_bytes
from puya.awst_build.eb._utils import (
bool_eval_to_constant,
compare_bytes,
compare_expr_bytes,
resolve_negative_literal_index,
)
from puya.awst_build.eb.arc4.base import CopyBuilder, arc4_bool_bytes
from puya.awst_build.eb.factories import builder_for_instance
from puya.awst_build.eb.interface import (
Expand All @@ -48,7 +51,6 @@
from puya.awst_build.eb.uint64 import UInt64ExpressionBuilder
from puya.awst_build.eb.void import VoidExpressionBuilder
from puya.awst_build.utils import (
expect_operand_type,
require_instance_builder,
require_instance_builder_of_type,
)
Expand Down Expand Up @@ -254,22 +256,13 @@ def iterate(self) -> Iteration:

@typing.override
def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBuilder:
if isinstance(index, LiteralBuilder) and isinstance(index.value, int) and index.value < 0:
index_expr: Expression = UInt64BinaryOperation(
left=require_instance_builder(
self.member_access("length", index.source_location)
).resolve(),
op=UInt64BinaryOperator.sub,
right=UInt64Constant(
value=abs(index.value), source_location=index.source_location
),
source_location=index.source_location,
)
else:
index_expr = expect_operand_type(index, pytypes.UInt64Type).resolve()
array_length = require_instance_builder(
self.member_access("length", index.source_location)
)
index = resolve_negative_literal_index(index, array_length, location)
result_expr = IndexExpression(
base=self.resolve(),
index=index_expr,
index=index.resolve(),
wtype=self.pytype.items.wtype,
source_location=location,
)
Expand Down
11 changes: 6 additions & 5 deletions src/puya/awst_build/eb/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
InstanceExpressionBuilder,
TypeBuilder,
)
from puya.awst_build.eb._utils import compare_bytes
from puya.awst_build.eb._utils import compare_bytes, resolve_negative_literal_index
from puya.awst_build.eb.bool import BoolExpressionBuilder
from puya.awst_build.eb.interface import (
BuilderBinaryOp,
Expand Down Expand Up @@ -175,20 +175,21 @@ def to_bytes(self, location: SourceLocation) -> Expression:
return self.resolve()

@typing.override
def member_access(self, name: str, location: SourceLocation) -> NodeBuilder:
def member_access(self, name: str, location: SourceLocation) -> InstanceBuilder:
match name:
case "length":
len_call = intrinsic_factory.bytes_len(expr=self.resolve(), loc=location)
return UInt64ExpressionBuilder(len_call)
return super().member_access(name, location)
raise CodeError(f"unrecognised member of {self.pytype}: {name}", location)

@typing.override
def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBuilder:
index_expr = expect_operand_type(index, pytypes.UInt64Type).resolve()
length = self.member_access("length", location)
index = resolve_negative_literal_index(index, length, location)
expr = IndexExpression(
source_location=location,
base=self.resolve(),
index=index_expr,
index=index.resolve(),
wtype=self.pytype.wtype,
)
return BytesExpressionBuilder(expr)
Expand Down
19 changes: 4 additions & 15 deletions src/puya/awst_build/eb/storage/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from puya.awst_build import intrinsic_factory, pytypes
from puya.awst_build.contract_data import AppStorageDeclaration
from puya.awst_build.eb._utils import resolve_negative_literal_index
from puya.awst_build.eb.bytes import BytesExpressionBuilder
from puya.awst_build.eb.interface import (
BuilderBinaryOp,
Expand All @@ -38,26 +39,14 @@ def index_box_bytes(
index: InstanceBuilder,
location: SourceLocation,
) -> InstanceBuilder:

if isinstance(index, InstanceBuilder):
# no negatives
begin_index_expr = index.resolve()
elif not isinstance(index.value, int):
raise CodeError("Invalid literal index type", index.source_location)
elif index.value >= 0:
begin_index_expr = UInt64Constant(value=index.value, source_location=index.source_location)
else:
box_length = box_length_unchecked(box, location)
box_length_builder = UInt64ExpressionBuilder(box_length)
begin_index_expr = box_length_builder.binary_op(
index, BuilderBinaryOp.sub, location, reverse=False
).resolve()
box_length = UInt64ExpressionBuilder(box_length_unchecked(box, location))
begin_index = resolve_negative_literal_index(index, box_length, location)
return BytesExpressionBuilder(
IntrinsicCall(
op_code="box_extract",
stack_args=[
box.key,
begin_index_expr,
begin_index.resolve(),
UInt64Constant(value=1, source_location=location),
],
source_location=location,
Expand Down

0 comments on commit 9213996

Please sign in to comment.