Skip to content

Commit

Permalink
fix: fix semantic issues with tuple comparisons of different length /…
Browse files Browse the repository at this point in the history
… types
  • Loading branch information
achidlow committed Jun 25, 2024
1 parent b938faf commit 840118a
Show file tree
Hide file tree
Showing 22 changed files with 739 additions and 192 deletions.
2 changes: 1 addition & 1 deletion examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
tictactoe/TicTacToe 803 682 121 670 12
too_many_permutations 108 106 2 106 0
transaction/Transaction 991 932 59 932 0
tuple_support/TupleComparisons 56 44 12 44 0
tuple_support/TupleComparisons 130 75 55 75 0
tuple_support/TupleSupport 389 292 97 292 0
typed_abi_call/Greeter 2818 2446 372 2446 0
typed_abi_call/Logger 876 762 114 762 0
Expand Down
60 changes: 38 additions & 22 deletions src/puya/awst_build/eb/tuple.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import typing
from collections.abc import Callable, Sequence

Expand Down Expand Up @@ -75,9 +76,9 @@ def call(

def _init(args: Sequence[NodeBuilder], location: SourceLocation) -> InstanceBuilder:
if not args:
raise CodeError("empty tuples are not supported", location)
return TupleLiteralBuilder(items=[], location=location)
if len(args) != 1:
raise CodeError("tuple constructor takes a single argument")
raise CodeError(f"tuple expected at most 1 argument, got {len(args)}", location)
(arg,) = args
arg = require_instance_builder(arg)
# TODO: generalise statically-iterable expressions at InstanceBuilder level
Expand All @@ -98,8 +99,6 @@ def _init(args: Sequence[NodeBuilder], location: SourceLocation) -> InstanceBuil
fixed_size = len(str_lit)
case _:
raise CodeError("unhandled argument type", arg.source_location)
if fixed_size == 0:
raise CodeError("empty tuples are not supported", location)
indexer = _make_maybe_tuple_indexer(arg, location)
return TupleLiteralBuilder(
items=[indexer(idx) for idx in range(fixed_size)], location=location
Expand Down Expand Up @@ -373,42 +372,59 @@ def _compare(
match op:
case BuilderComparisonOp.eq:
chain_op = "&&"
result_if_types_differ = False
inverse = BuilderComparisonOp.ne
result_if_both_empty = True
case BuilderComparisonOp.ne:
chain_op = "||"
result_if_types_differ = True
inverse = BuilderComparisonOp.eq
result_if_both_empty = False
case _:
raise CodeError(
f"the {op.value!r} operator is not currently supported with tuples", location
)

if len(lhs.pytype.items) != len(rhs.pytype.items):
# TODO: semantic compatibility issue
return bool_eval_to_constant(value=result_if_types_differ, location=location)

lhs_indexer = _make_maybe_tuple_indexer(lhs, location)
rhs_indexer = _make_maybe_tuple_indexer(rhs, location)

def compare_at_index(idx: int) -> Expression:
left = lhs_indexer(idx)
right = rhs_indexer(idx)
cmp_builder = left.compare(right, op=op, location=location)
lhs_items = [lhs_indexer(idx) for idx, _ in enumerate(lhs.pytype.items)]
rhs_items = [rhs_indexer(idx) for idx, _ in enumerate(rhs.pytype.items)]

result_exprs = []
for idx, (lhs_item, rhs_item) in enumerate(itertools.zip_longest(lhs_items, rhs_items)):
if lhs_item is None:
# if lhs is shorter than rhs, use the rhs item to compare against itself,
# so that rhs is still fully evaluated to maintain semantic compatibility,
# and we don't create typing errors trying to invent a different eval.
# make sure it's evaluated once though, because it only appears on one side.
# also, to make the comparison correct, we need to invert the overall op,
# if we're doing tup1 == tup2, then things get &&'d together, and we need to use !=
# if we're doing tup1 != tup2, then things get ||'d together, and we need to use ==
lhs_item = rhs_item = rhs_item.single_eval()
op_at = inverse
elif rhs_item is None:
rhs_item = lhs_item = lhs_item.single_eval()
op_at = inverse
else:
op_at = op
cmp_builder = lhs_item.compare(rhs_item, op=op_at, location=location)
if cmp_builder is NotImplemented:
cmp_builder = right.compare(left, op=op.reversed(), location=location)
cmp_builder = rhs_item.compare(lhs_item, op=op_at.reversed(), location=location)
if cmp_builder is NotImplemented:
raise CodeError(
f"items at index {idx} do not support comparison with operator {op.value!r}",
f"items at index {idx} do not support comparison with operator {op_at.value!r}",
location,
)
return cmp_builder.resolve()
result_exprs.append(cmp_builder.resolve())

if not result_exprs:
return bool_eval_to_constant(value=result_if_both_empty, location=location)

result = compare_at_index(0)
for i in range(1, len(lhs.pytype.items)):
result = result_exprs[0]
for result_expr in result_exprs[1:]:
result = IntrinsicCall(
op_code=chain_op,
stack_args=[result, compare_at_index(i)],
source_location=location,
stack_args=[result, result_expr],
wtype=wtypes.bool_wtype,
source_location=location,
)
return BoolExpressionBuilder(result)

Expand Down
4 changes: 0 additions & 4 deletions src/puya/awst_build/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,13 +1081,9 @@ def visit_tuple_expr(self, mypy_expr: mypy.nodes.TupleExpr) -> NodeBuilder:
from puya.awst_build.eb.tuple import TupleLiteralBuilder

location = self._location(mypy_expr)
if not mypy_expr.items:
raise CodeError("empty tuples are not supported", location)

item_builders = [
require_instance_builder(mypy_item.accept(self)) for mypy_item in mypy_expr.items
]

return TupleLiteralBuilder(item_builders, location)

def visit_assignment_expr(self, expr: mypy.nodes.AssignmentExpr) -> NodeBuilder:
Expand Down
Loading

0 comments on commit 840118a

Please sign in to comment.