From 0c8a745311be5fb779edcba2929f3e9911f265ad Mon Sep 17 00:00:00 2001 From: Adam Chidlow Date: Fri, 14 Jun 2024 13:39:00 +0800 Subject: [PATCH] feat: support tuple equality comparisons with literal elements, support tuple repetition & concatenation, and support indexing/slicing literals that support it --- src/puya/awst/wtypes.py | 4 +- src/puya/awst_build/eb/_literals.py | 2 +- src/puya/awst_build/eb/arc4/abi_call.py | 12 +- src/puya/awst_build/eb/interface.py | 4 + src/puya/awst_build/eb/tuple.py | 318 +++++++++++++++--- src/puya/awst_build/pytypes.py | 19 +- src/puya/awst_build/subroutine.py | 48 +-- .../test_expected_output/expected_errors.test | 4 +- 8 files changed, 310 insertions(+), 101 deletions(-) diff --git a/src/puya/awst/wtypes.py b/src/puya/awst/wtypes.py index c6fcfa624c..e486795113 100644 --- a/src/puya/awst/wtypes.py +++ b/src/puya/awst/wtypes.py @@ -186,7 +186,7 @@ class WTuple(WType): def __init__(self, types: Iterable[WType], source_location: SourceLocation | None): types = tuple(types) if not types: - raise CodeError("tuple needs types", source_location) + raise CodeError("empty tuples are not supported", source_location) if void_wtype in types: raise CodeError("tuple should not contain void types", source_location) name = f"tuple<{','.join([t.name for t in types])}>" @@ -285,7 +285,7 @@ class ARC4Tuple(ARC4Type): def __init__(self, types: Iterable[WType], source_location: SourceLocation | None): types = tuple(types) if not types: - raise CodeError("ARC4 Tuple cannot be empty", source_location) + raise CodeError("empty tuples are not supported", source_location) immutable = True arc4_types = [] for typ_idx, typ in enumerate(types): diff --git a/src/puya/awst_build/eb/_literals.py b/src/puya/awst_build/eb/_literals.py index 23942aa69a..4bc0a523c7 100644 --- a/src/puya/awst_build/eb/_literals.py +++ b/src/puya/awst_build/eb/_literals.py @@ -58,7 +58,7 @@ def pytype(self) -> pytypes.PyType: def resolve(self) -> Expression: if isinstance(self.value, bool): return BoolConstant(value=self.value, source_location=self.source_location) - raise CodeError("A Python literal is not valid at this location", self.source_location) + raise CodeError("a Python literal is not valid at this location", self.source_location) @typing.override def resolve_literal(self, converter: LiteralConverter) -> InstanceBuilder: diff --git a/src/puya/awst_build/eb/arc4/abi_call.py b/src/puya/awst_build/eb/arc4/abi_call.py index 67834f5e0d..f27aab5f1c 100644 --- a/src/puya/awst_build/eb/arc4/abi_call.py +++ b/src/puya/awst_build/eb/arc4/abi_call.py @@ -39,12 +39,13 @@ implicit_operand_conversion, ) from puya.awst_build.eb.arc4.base import ARC4FromLogBuilder +from puya.awst_build.eb.factories import builder_for_instance from puya.awst_build.eb.interface import InstanceBuilder, LiteralBuilder, NodeBuilder from puya.awst_build.eb.subroutine import BaseClassSubroutineInvokerExpressionBuilder from puya.awst_build.eb.transaction import InnerTransactionExpressionBuilder from puya.awst_build.eb.transaction.fields import get_field_python_name from puya.awst_build.eb.transaction.inner_params import get_field_expr -from puya.awst_build.eb.tuple import TupleExpressionBuilder +from puya.awst_build.eb.tuple import TupleLiteralBuilder from puya.awst_build.utils import ( get_decorators_by_fullname, require_instance_builder, @@ -398,12 +399,9 @@ def append_ref_arg(ref_list: list[Expression], arg: InstanceBuilder) -> None: else: raise InternalError("Return type does not match signature type", location) - result_pytype = pytypes.GenericTupleType.parameterise( - [declared_result_pytype, itxn_result_pytype], location - ) - tuple_expr = TupleExpression.from_items((abi_result, itxn_tmp), location) - assert tuple_expr.wtype == result_pytype.wtype # TODO: fixme - return TupleExpressionBuilder(tuple_expr, result_pytype) + abi_result_builder = builder_for_instance(declared_result_pytype, abi_result) + itxn_tmp_builder = builder_for_instance(itxn_result_pytype, itxn_tmp) + return TupleLiteralBuilder((abi_result_builder, itxn_tmp_builder), location) def _add_array_exprs( diff --git a/src/puya/awst_build/eb/interface.py b/src/puya/awst_build/eb/interface.py index 3d5f10c48d..d34b9c8a8d 100644 --- a/src/puya/awst_build/eb/interface.py +++ b/src/puya/awst_build/eb/interface.py @@ -12,6 +12,7 @@ from puya.awst_build.contract_data import AppStorageDeclaration from puya.errors import CodeError from puya.parse import SourceLocation +from puya.utils import invert_ordered_binary_op if typing.TYPE_CHECKING: from collections.abc import Collection, Sequence @@ -28,6 +29,9 @@ class BuilderComparisonOp(enum.StrEnum): gt = ">" gte = ">=" + def reversed(self) -> BuilderComparisonOp: + return BuilderComparisonOp(invert_ordered_binary_op(self.value)) + @enum.unique class BuilderUnaryOp(enum.StrEnum): diff --git a/src/puya/awst_build/eb/tuple.py b/src/puya/awst_build/eb/tuple.py index 703bd4a8aa..c137f7b90b 100644 --- a/src/puya/awst_build/eb/tuple.py +++ b/src/puya/awst_build/eb/tuple.py @@ -1,5 +1,5 @@ import typing -from collections.abc import Sequence +from collections.abc import Callable, Sequence import mypy.nodes @@ -11,25 +11,27 @@ Contains, Expression, IntegerConstant, + Lvalue, SliceExpression, + Statement, TupleExpression, TupleItemExpression, UInt64Constant, ) from puya.awst_build import pytypes -from puya.awst_build.eb._base import ( - GenericTypeBuilder, - InstanceExpressionBuilder, - TypeBuilder, -) +from puya.awst_build.eb._base import GenericTypeBuilder, InstanceExpressionBuilder, TypeBuilder +from puya.awst_build.eb._literals import LiteralBuilderImpl from puya.awst_build.eb._utils import bool_eval_to_constant from puya.awst_build.eb.bool import BoolExpressionBuilder from puya.awst_build.eb.factories import builder_for_instance from puya.awst_build.eb.interface import ( + BuilderBinaryOp, BuilderComparisonOp, + BuilderUnaryOp, InstanceBuilder, Iteration, LiteralBuilder, + LiteralConverter, NodeBuilder, ) from puya.awst_build.utils import require_instance_builder @@ -49,6 +51,10 @@ def call( arg_names: list[str | None], location: SourceLocation, ) -> InstanceBuilder: + if not args: + raise CodeError("empty tuples are not supported", location) + if len(args) != 1: + raise CodeError("tuple constructor takes a single argument") inst_args = [require_instance_builder(a) for a in args] typ = pytypes.GenericTupleType.parameterise([ia.pytype for ia in inst_args], location) tuple_expr = TupleExpression.from_items([ia.resolve() for ia in inst_args], location) @@ -72,6 +78,9 @@ def call( arg_names: list[str | None], location: SourceLocation, ) -> InstanceBuilder: + pytype = self.produces() + if len(args) != len(pytype.items): + raise CodeError("") tuple_expr = TupleExpression( items=[require_instance_builder(a).resolve() for a in args], @@ -81,6 +90,121 @@ def call( return TupleExpressionBuilder(tuple_expr, self.produces()) +class TupleLiteralBuilder(InstanceBuilder[pytypes.TupleType]): + def __init__(self, items: Sequence[InstanceBuilder], location: SourceLocation): + super().__init__(location) + self._items = tuple(items) + self._pytype = pytypes.GenericTupleType.parameterise([i.pytype for i in items], location) + + @typing.override + @property + def pytype(self) -> pytypes.TupleType: + return self._pytype + + @property + def items(self) -> Sequence[InstanceBuilder]: + return self._items + + @typing.override + def member_access(self, name: str, location: SourceLocation) -> typing.Never: + if name in dir(tuple()): # noqa: C408 + raise CodeError("method is not currently supported", location) + raise CodeError("unrecognised member access", location) + + @typing.override + def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> InstanceBuilder: + # TODO: semantic compatibility issue, here and potentially elsewhere: ignores evaluation + return bool_eval_to_constant(value=bool(self._items), location=location, negate=negate) + + @typing.override + def to_bytes(self, location: SourceLocation) -> Expression: + raise CodeError(f"cannot serialize {self.pytype}", location) + + @typing.override + def resolve(self) -> TupleExpression: + item_exprs = [i.resolve() for i in self.items] + return TupleExpression.from_items(item_exprs, self.source_location) + + @typing.override + def resolve_lvalue(self) -> Lvalue: + return self.resolve() + + @typing.override + def resolve_literal(self, converter: LiteralConverter) -> InstanceBuilder: + # even though this may contain literals, it's not homogenous, so we can't really + # resolve with a single converter currently...? + return self + + @typing.override + def delete(self, location: SourceLocation) -> Statement: + raise CodeError("cannot delete tuple literal", location) + + @typing.override + def unary_op(self, op: BuilderUnaryOp, location: SourceLocation) -> InstanceBuilder: + raise CodeError(f"bad operand type for unary {op.value}: 'tuple'", location) + + @typing.override + def compare( + self, other: InstanceBuilder, op: BuilderComparisonOp, location: SourceLocation + ) -> InstanceBuilder: + return _compare(self, other, op, location) + + @typing.override + def binary_op( + self, + other: InstanceBuilder, + op: BuilderBinaryOp, + location: SourceLocation, + *, + reverse: bool, + ) -> InstanceBuilder: + match op: + case BuilderBinaryOp.add: + return _concat(self, other, location, reverse=reverse) + case BuilderBinaryOp.mult: + match other: + # can't handle non-simple literals here + case LiteralBuilder(value=int(mult_literal)): + return TupleLiteralBuilder(self._items * mult_literal, location) + case _: + raise CodeError("can't multiple sequence by non-int-literal", location) + case _: + return NotImplemented + + @typing.override + def augmented_assignment( + self, op: BuilderBinaryOp, rhs: InstanceBuilder, location: SourceLocation + ) -> Statement: + raise CodeError("'tuple' is an illegal expression for augmented assignment", location) + + @typing.override + def iterate(self) -> Iteration: + return self.resolve() + + def _expr_builder(self) -> InstanceBuilder: + # used to maintain semantic compatibility, we must resolve this so all elements + # get evaluated, we can't handle literal indexing or literal containment specially + return TupleExpressionBuilder(self.resolve(), self.pytype) + + @typing.override + def contains(self, item: InstanceBuilder, location: SourceLocation) -> InstanceBuilder: + return self._expr_builder().contains(item, location) + + @typing.override + def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBuilder: + return self._expr_builder().index(index, location) + + @typing.override + def slice_index( + self, + begin_index: InstanceBuilder | None, + end_index: InstanceBuilder | None, + stride: InstanceBuilder | None, + location: SourceLocation, + ) -> InstanceBuilder: + return self._expr_builder().slice_index(begin_index, end_index, stride, location) + + class TupleExpressionBuilder(InstanceExpressionBuilder[pytypes.TupleType]): def __init__(self, expr: Expression, typ: pytypes.PyType): assert isinstance(typ, pytypes.TupleType) @@ -90,6 +214,36 @@ def __init__(self, expr: Expression, typ: pytypes.PyType): def to_bytes(self, location: SourceLocation) -> Expression: raise CodeError(f"cannot serialize {self.pytype}", location) + @typing.override + def member_access(self, name: str, location: SourceLocation) -> typing.Never: + if name in dir(tuple()): # noqa: C408 + raise CodeError("method is not currently supported", location) + raise CodeError("unrecognised member access", location) + + @typing.override + def binary_op( + self, + other: InstanceBuilder, + op: BuilderBinaryOp, + location: SourceLocation, + *, + reverse: bool, + ) -> InstanceBuilder: + match op: + case BuilderBinaryOp.add: + return _concat(self, other, location, reverse=reverse) + case BuilderBinaryOp.mult: + match other: + # can't handle non-simple literals here + case LiteralBuilder(value=int(mult_literal)): + indexer = _make_tuple_indexer(self, location) + items = [indexer(idx) for idx in range(len(self.pytype.items))] + return TupleLiteralBuilder(items * mult_literal, location) + case _: + raise CodeError("can't multiple sequence by non-int-literal", location) + case _: + return NotImplemented + @typing.override def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBuilder: # special handling of tuples, they can be indexed by int literal only, @@ -98,18 +252,16 @@ def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBui index_expr_or_literal = index match index_expr_or_literal: case LiteralBuilder(value=int(index_value)): - return self._index(index_value, location) + pass case _: raise CodeError( "tuples can only be indexed by int constants", index_expr_or_literal.source_location, ) - - def _index(self, index_value: int, location: SourceLocation) -> InstanceBuilder: try: item_typ = self.pytype.items[index_value] except IndexError as ex: - raise CodeError("Tuple index out of bounds", location) from ex + raise CodeError("tuple index out of range", location) from ex item_expr = TupleItemExpression( base=self.resolve(), index=index_value, @@ -126,13 +278,13 @@ def slice_index( location: SourceLocation, ) -> InstanceBuilder: if stride is not None: - raise CodeError("Stride is not supported", location=stride.source_location) + raise CodeError("stride is not supported", location=stride.source_location) - start_expr, start_idx = self._convert_index(begin_index) - end_expr, end_idx = self._convert_index(end_index) + start_expr, start_idx = self._clamp_slice_index(begin_index) + end_expr, end_idx = self._clamp_slice_index(end_index) slice_types = self.pytype.items[start_idx:end_idx] if not slice_types: - raise CodeError("Empty slices are not supported", location) + raise CodeError("empty slices are not supported", location) updated_type = pytypes.GenericTupleType.parameterise(slice_types, location) updated_wtype = updated_type.wtype @@ -147,7 +299,7 @@ def slice_index( updated_type, ) - def _convert_index( + def _clamp_slice_index( self, index: NodeBuilder | None ) -> tuple[IntegerConstant | None, int | None]: match index: @@ -160,7 +312,7 @@ def _convert_index( expr = UInt64Constant(value=positive_idx_clamped, source_location=start_loc) case _: raise CodeError( - "Tuples can only be indexed with literal values", index.source_location + "tuples can only be indexed with literal values", index.source_location ) return expr, idx @@ -170,12 +322,11 @@ def iterate(self) -> Iteration: @typing.override def contains(self, item: InstanceBuilder, location: SourceLocation) -> InstanceBuilder: - if isinstance(item, LiteralBuilder): - raise CodeError( - "Cannot use in/not in check with a Python literal against a tuple", location - ) - item_expr = item.resolve() - contains_expr = Contains(source_location=location, item=item_expr, sequence=self.resolve()) + contains_expr = Contains( + sequence=self.resolve(), + item=item.resolve(), + source_location=location, + ) return BoolExpressionBuilder(contains_expr) @typing.override @@ -186,32 +337,99 @@ def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> Instan def compare( self, other: InstanceBuilder, op: BuilderComparisonOp, location: SourceLocation ) -> InstanceBuilder: - match op: - case BuilderComparisonOp.eq: - chain_op = BinaryBooleanOperator.and_ - result_if_types_differ = False - case BuilderComparisonOp.ne: - chain_op = BinaryBooleanOperator.or_ - result_if_types_differ = True - case _: - raise CodeError(f"The {op} operator on the tuple type is not supported", location) - - if not isinstance(other, TupleExpressionBuilder): - return NotImplemented - if self.pytype.items != other.pytype.items: - return bool_eval_to_constant(value=result_if_types_differ, location=location) - - def compare_at_index(idx: int) -> Expression: - left = self._index(idx, location) - right = other._index(idx, location) # noqa: SLF001 - return left.compare(right, op=op, location=location).resolve() - - result = compare_at_index(0) - for i in range(1, len(self.pytype.items)): - result = BooleanBinaryOperation( - left=result, - right=compare_at_index(i), - op=chain_op, - source_location=location, + return _compare(self, other, op, location) + + +def _compare( + lhs: InstanceBuilder[pytypes.TupleType], + rhs: InstanceBuilder, + op: BuilderComparisonOp, + location: SourceLocation, +) -> InstanceBuilder: + if not isinstance(rhs.pytype, pytypes.TupleType): + return NotImplemented + + match op: + case BuilderComparisonOp.eq: + chain_op = BinaryBooleanOperator.and_ + result_if_types_differ = False + case BuilderComparisonOp.ne: + chain_op = BinaryBooleanOperator.or_ + result_if_types_differ = True + case _: + raise CodeError( + f"the {op.value!r} operator is not currently supported with tuples", location + ) + + if len(lhs.pytype.items) != len(rhs.pytype.items): + # TODO: semantic compatibility issue + return bool_eval_to_constant(value=result_if_types_differ, location=location) + + lhs_indexer = _make_tuple_indexer(lhs, location) + rhs_indexer = _make_tuple_indexer(rhs, location) + + def compare_at_index(idx: int) -> Expression: + left = lhs_indexer(idx) + right = rhs_indexer(idx) + cmp_builder = left.compare(right, op=op, location=location) + if cmp_builder is NotImplemented: + cmp_builder = right.compare(left, op=op.reversed(), location=location) + if cmp_builder is NotImplemented: + raise CodeError( + f"items at index {idx} do not support comparison with operator {op.value!r}", + location, ) - return BoolExpressionBuilder(result) + return cmp_builder.resolve() + + result = compare_at_index(0) + for i in range(1, len(lhs.pytype.items)): + result = BooleanBinaryOperation( + left=result, + right=compare_at_index(i), + op=chain_op, + source_location=location, + ) + return BoolExpressionBuilder(result) + + +def _concat( + this: InstanceBuilder[pytypes.TupleType], + other: InstanceBuilder, + location: SourceLocation, + *, + reverse: bool, +) -> InstanceBuilder: + if not isinstance(other.pytype, pytypes.TupleType): + raise CodeError("can only concatenate tuple with other tuples", location) + other = typing.cast(InstanceBuilder[pytypes.TupleType], other) + if not reverse: + lhs, rhs = this, other + else: + lhs, rhs = other, this + + lhs_indexer = _make_tuple_indexer(lhs, location) + rhs_indexer = _make_tuple_indexer(rhs, location) + + items = [ + *(lhs_indexer(idx) for idx in range(len(lhs.pytype.items))), + *(rhs_indexer(idx) for idx in range(len(rhs.pytype.items))), + ] + return TupleLiteralBuilder(items, location) + + +def _make_tuple_indexer( + builder: InstanceBuilder, location: SourceLocation +) -> Callable[[int], InstanceBuilder]: + """this function should ONLY be used if ALL tuple elements are going to be visited""" + if isinstance(builder, TupleLiteralBuilder): + # this is why this function exists, going through .index() would evaluate to + # an expression in the general case, but this way we can support comparisons with + # literal items in tuples naturally + captured = builder + return lambda idx: captured.items[idx] + + def indexer(idx: int) -> InstanceBuilder: + index_lit = LiteralBuilderImpl(value=idx, source_location=location) + return builder.index(index_lit, location) + + return indexer diff --git a/src/puya/awst_build/pytypes.py b/src/puya/awst_build/pytypes.py index 5dcec19c97..671a3414ee 100644 --- a/src/puya/awst_build/pytypes.py +++ b/src/puya/awst_build/pytypes.py @@ -2,7 +2,7 @@ import abc import typing -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from functools import cached_property import attrs @@ -207,11 +207,16 @@ def wtype(self) -> typing.Never: @typing.final -@attrs.frozen +@attrs.frozen(kw_only=True) class TupleType(PyType): generic: _GenericType[TupleType] - items: tuple[PyType, ...] = attrs.field(validator=attrs.validators.min_len(1)) - wtype: wtypes.WType + items: tuple[PyType, ...] + _wtype_cls: type[wtypes.WTuple | wtypes.ARC4Tuple] + source_location: SourceLocation | None + + @property + def wtype(self) -> wtypes.WTuple | wtypes.ARC4Tuple: + return self._wtype_cls((i.wtype for i in self.items), self.source_location) @typing.final @@ -530,18 +535,18 @@ def parameterise( def _make_tuple_parameterise( - typ: Callable[[Iterable[wtypes.WType], SourceLocation | None], wtypes.WType] + wtype_cls: type[wtypes.WTuple | wtypes.ARC4Tuple], ) -> _Parameterise[TupleType]: def parameterise( self: _GenericType[TupleType], args: _TypeArgs, source_location: SourceLocation | None ) -> TupleType: - item_wtypes = [arg.wtype for arg in args] name = f"{self.name}[{', '.join(pyt.name for pyt in args)}]" return TupleType( generic=self, name=name, items=tuple(args), - wtype=typ(item_wtypes, source_location), + wtype_cls=wtype_cls, + source_location=source_location, ) return parameterise diff --git a/src/puya/awst_build/subroutine.py b/src/puya/awst_build/subroutine.py index 4b63c124d5..8b8ef398d9 100644 --- a/src/puya/awst_build/subroutine.py +++ b/src/puya/awst_build/subroutine.py @@ -9,7 +9,6 @@ import mypy.types from puya import log -from puya.awst import wtypes from puya.awst.nodes import ( AppStateExpression, AssertStatement, @@ -39,7 +38,6 @@ Subroutine, SubroutineArgument, Switch, - TupleExpression, UInt64Constant, VarExpression, WhileLoop, @@ -89,7 +87,7 @@ from puya.errors import CodeError, InternalError from puya.models import ARC4MethodConfig from puya.parse import SourceLocation -from puya.utils import invert_ordered_binary_op, lazy_setdefault +from puya.utils import lazy_setdefault logger = log.get_logger(__name__) @@ -970,25 +968,21 @@ def visit_index_expr(self, expr: mypy.nodes.IndexExpr) -> NodeBuilder: case _: typing.assert_never(expr.analyzed) - base_expr = require_instance_builder(expr.base.accept(self)) - if isinstance(base_expr, LiteralBuilder): - raise CodeError( # TODO: yeet me - "Python literals cannot be indexed or sliced", base_expr.source_location - ) - + base_builder = require_instance_builder(expr.base.accept(self)) match expr.index: # special case handling of SliceExpr, so we don't need to handle slice Literal's # or some such everywhere + # TODO: SliceBuilder? case mypy.nodes.SliceExpr(begin_index=begin, end_index=end, stride=stride): - return base_expr.slice_index( + return base_builder.slice_index( begin_index=(require_instance_builder(begin.accept(self)) if begin else None), end_index=(require_instance_builder(end.accept(self)) if end else None), stride=(require_instance_builder(stride.accept(self)) if stride else None), location=expr_location, ) - index_expr_or_literal = require_instance_builder(expr.index.accept(self)) - return base_expr.index(index=index_expr_or_literal, location=expr_location) + index_builder = require_instance_builder(expr.index.accept(self)) + return base_builder.index(index=index_builder, location=expr_location) def visit_conditional_expr(self, expr: mypy.nodes.ConditionalExpr) -> NodeBuilder: expr_loc = self._location(expr) @@ -1071,12 +1065,11 @@ def _build_compare( return contains_builder.resolve() result: InstanceBuilder = NotImplemented + op = BuilderComparisonOp(operator) if isinstance(lhs, NodeBuilder): - op = BuilderComparisonOp(operator) result = lhs.compare(other=rhs, op=op, location=cmp_loc) if result is NotImplemented and isinstance(rhs, NodeBuilder): - op = BuilderComparisonOp(invert_ordered_binary_op(operator)) - result = rhs.compare(other=lhs, op=op, location=cmp_loc) + result = rhs.compare(other=lhs, op=op.reversed(), location=cmp_loc) if result is NotImplemented: raise CodeError(f"Unsupported comparison {operator!r} between types", cmp_loc) return result.resolve() @@ -1092,26 +1085,17 @@ def visit_bytes_expr(self, expr: mypy.nodes.BytesExpr) -> LiteralBuilder: return LiteralBuilderImpl(value=bytes_const, source_location=self._location(expr)) def visit_tuple_expr(self, mypy_expr: mypy.nodes.TupleExpr) -> NodeBuilder: - from puya.awst_build.eb.tuple import TupleExpressionBuilder + from puya.awst_build.eb.tuple import TupleLiteralBuilder - items = [ - require_instance_builder(mypy_item.accept(self)).resolve() - for mypy_item in mypy_expr.items - ] - # TODO: grab item types from EB items? - typ = self.context.mypy_expr_node_type(mypy_expr) location = self._location(mypy_expr) - if not items: - raise CodeError("Empty tuples are not supported", location) - wtype = typ.wtype - assert isinstance(wtype, wtypes.WTuple) - tuple_expr = TupleExpression( - source_location=location, - wtype=wtype, - items=items, - ) + if not mypy_expr.items: + raise CodeError("empty tuples are not supported", location) + + item_builders = [ + require_instance_builder(mypy_item.accept(self)) for mypy_item in mypy_expr.items + ] - return TupleExpressionBuilder(tuple_expr, typ) + return TupleLiteralBuilder(item_builders, location) def visit_assignment_expr(self, expr: mypy.nodes.AssignmentExpr) -> NodeBuilder: expr_loc = self._location(expr) diff --git a/tests/test_expected_output/expected_errors.test b/tests/test_expected_output/expected_errors.test index 490597266e..60481754a1 100644 --- a/tests/test_expected_output/expected_errors.test +++ b/tests/test_expected_output/expected_errors.test @@ -222,7 +222,7 @@ def invalid_constructor_args() -> None: g = Application("g") # type: ignore[arg-type] ## E: Invalid/unhandled arguments # the next check is just to make sure we didn't throw on the line above - h = 1 ## E: A Python literal is not valid at this location + h = 1 ## E: a Python literal is not valid at this location ## case: unsupported_type_comparisons @@ -241,7 +241,7 @@ def float_div() -> None: assert UInt64(10) / 2 # type: ignore[operator] ## E: To maintain semantic compatibility with Python, only the truncating division operator (//) is supported assert BigUInt(5) / 2 # type: ignore[operator] ## E: To maintain semantic compatibility with Python, only the truncating division operator (//) is supported - x = 1 ## E: A Python literal is not valid at this location + x = 1 ## E: a Python literal is not valid at this location ## case: unsupported_math_operators