diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 5a9e66f4..b5362b46 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -1,5 +1,8 @@ import ast -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +import textwrap +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast if TYPE_CHECKING: from guppy.gtypes import GuppyType @@ -54,51 +57,161 @@ def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> T: raise NotImplementedError(f"visit_{node.__class__.__name__} is not implemented") -class NameVisitor(ast.NodeVisitor): - """Visitor to collect all `Name` nodes occurring in an AST.""" - - names: list[ast.Name] - - def __init__(self) -> None: - self.names = [] - - def visit_Name(self, node: ast.Name) -> None: - self.names.append(node) +class AstSearcher(ast.NodeVisitor): + """Visitor that searches for occurrences of specific nodes in an AST.""" + + matcher: Callable[[ast.AST], bool] + dont_recurse_into: set[type[ast.AST]] + found: list[ast.AST] + is_first_node: bool + + def __init__( + self, + matcher: Callable[[ast.AST], bool], + dont_recurse_into: set[type[ast.AST]] | None = None, + ) -> None: + self.matcher = matcher + self.dont_recurse_into = dont_recurse_into or set() + self.found = [] + self.is_first_node = True + + def generic_visit(self, node: ast.AST) -> None: + if self.matcher(node): + self.found.append(node) + if self.is_first_node or type(node) not in self.dont_recurse_into: + self.is_first_node = False + super().generic_visit(node) + + +def find_nodes( + matcher: Callable[[ast.AST], bool], + node: ast.AST, + dont_recurse_into: set[type[ast.AST]] | None = None, +) -> list[ast.AST]: + """Returns all nodes in the AST that satisfy the matcher.""" + v = AstSearcher(matcher, dont_recurse_into) + v.visit(node) + return v.found def name_nodes_in_ast(node: Any) -> list[ast.Name]: """Returns all `Name` nodes occurring in an AST.""" - v = NameVisitor() - v.visit(node) - return v.names - - -class ReturnVisitor(ast.NodeVisitor): - """Visitor to collect all `Return` nodes occurring in an AST.""" + found = find_nodes(lambda n: isinstance(n, ast.Name), node) + return cast(list[ast.Name], found) - returns: list[ast.Return] - inside_func_def: bool - def __init__(self) -> None: - self.returns = [] - self.inside_func_def = False - - def visit_Return(self, node: ast.Return) -> None: - self.returns.append(node) +def return_nodes_in_ast(node: Any) -> list[ast.Return]: + """Returns all `Return` nodes occurring in an AST.""" + found = find_nodes(lambda n: isinstance(n, ast.Return), node, {ast.FunctionDef}) + return cast(list[ast.Return], found) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - # Don't descend into nested function definitions - if not self.inside_func_def: - self.inside_func_def = True - for n in node.body: - self.visit(n) +def breaks_in_loop(node: Any) -> list[ast.Break]: + """Returns all `Break` nodes occurring in a loop. -def return_nodes_in_ast(node: Any) -> list[ast.Return]: - """Returns all `Return` nodes occurring in an AST.""" - v = ReturnVisitor() - v.visit(node) - return v.returns + Note that breaks in nested loops are excluded. + """ + found = find_nodes( + lambda n: isinstance(n, ast.Break), node, {ast.For, ast.While, ast.FunctionDef} + ) + return cast(list[ast.Break], found) + + +class ContextAdjuster(ast.NodeTransformer): + """Updates the `ast.Context` indicating if expressions occur on the LHS or RHS.""" + + ctx: ast.expr_context + + def __init__(self, ctx: ast.expr_context) -> None: + self.ctx = ctx + + def visit(self, node: ast.AST) -> ast.AST: + return cast(ast.AST, super().visit(node)) + + def visit_Name(self, node: ast.Name) -> ast.Name: + return with_loc(node, ast.Name(id=node.id, ctx=self.ctx)) + + def visit_Starred(self, node: ast.Starred) -> ast.Starred: + return with_loc(node, ast.Starred(value=self.visit(node.value), ctx=self.ctx)) + + def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple: + return with_loc( + node, ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx) + ) + + def visit_List(self, node: ast.List) -> ast.List: + return with_loc( + node, ast.List(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx) + ) + + def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: + # Don't adjust the slice! + return with_loc( + node, + ast.Subscript(value=self.visit(node.value), slice=node.slice, ctx=self.ctx), + ) + + def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: + return with_loc( + node, + ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx), + ) + + +@dataclass(frozen=True, eq=False) +class TemplateReplacer(ast.NodeTransformer): + """Replaces nodes in a template.""" + + replacements: Mapping[str, ast.AST | Sequence[ast.AST]] + default_loc: ast.AST + + def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]: + if x not in self.replacements: + msg = f"No replacement for `{x}` is given" + raise ValueError(msg) + return self.replacements[x] + + def visit_Name(self, node: ast.Name) -> ast.AST: + repl = self._get_replacement(node.id) + if not isinstance(repl, ast.expr): + msg = f"Replacement for `{node.id}` must be an expression" + raise TypeError(msg) + + # Update the context + adjuster = ContextAdjuster(node.ctx) + return with_loc(repl, adjuster.visit(repl)) + + def visit_Expr(self, node: ast.Expr) -> ast.AST | Sequence[ast.AST]: + if isinstance(node.value, ast.Name): + repl = self._get_replacement(node.value.id) + repls = [repl] if not isinstance(repl, Sequence) else repl + # Wrap expressions to turn them into statements + return [ + with_loc(r, ast.Expr(value=r)) if isinstance(r, ast.expr) else r + for r in repls + ] + return self.generic_visit(node) + + def generic_visit(self, node: ast.AST) -> ast.AST: + # Insert the default location + node = super().generic_visit(node) + return with_loc(self.default_loc, node) + + +def template_replace( + template: str, default_loc: ast.AST, **kwargs: ast.AST | Sequence[ast.AST] +) -> list[ast.stmt]: + """Turns a template into a proper AST by substituting all placeholders.""" + nodes = ast.parse(textwrap.dedent(template)).body + replacer = TemplateReplacer(kwargs, default_loc) + new_nodes = [] + for n in nodes: + new = replacer.visit(n) + if isinstance(new, list): + new_nodes.extend(new) + else: + new_nodes.append(new) + return new_nodes def line_col(node: ast.AST) -> tuple[int, int]: diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index a91965ca..d5c89af7 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -6,7 +6,7 @@ from typing_extensions import Self from guppy.ast_util import AstNode, name_nodes_in_ast -from guppy.nodes import NestedFunctionDef, PyExpr +from guppy.nodes import DesugaredListComp, NestedFunctionDef, PyExpr if TYPE_CHECKING: from guppy.cfg.cfg import BaseCFG @@ -119,6 +119,25 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node + def visit_DesugaredListComp(self, node: DesugaredListComp) -> None: + # Names bound in the comprehension are only available inside, so we shouldn't + # update `self.stats` with assignments + inner_visitor = VariableVisitor(self.bb) + inner_stats = inner_visitor.stats + + # The generators are evaluated left to right + for gen in node.generators: + inner_visitor.visit(gen.iter_assign) + inner_visitor.visit(gen.hasnext_assign) + inner_visitor.visit(gen.next_assign) + for cond in gen.ifs: + inner_visitor.visit(cond) + inner_visitor.visit(node.elt) + + self.stats.used |= { + x: n for x, n in inner_stats.used.items() if x not in self.stats.assigned + } + def visit_PyExpr(self, node: PyExpr) -> None: # Don't look into `py(...)` expressions pass diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 3296ec93..1310562e 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -3,13 +3,29 @@ from collections.abc import Iterator from typing import NamedTuple -from guppy.ast_util import AstVisitor, set_location_from, with_loc +from guppy.ast_util import ( + AstVisitor, + ContextAdjuster, + find_nodes, + set_location_from, + template_replace, + with_loc, +) from guppy.cfg.bb import BB, BBStatement from guppy.cfg.cfg import CFG from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError from guppy.gtypes import NoneType -from guppy.nodes import NestedFunctionDef, PyExpr +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + IterEnd, + IterHasNext, + IterNext, + MakeIter, + NestedFunctionDef, + PyExpr, +) # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -142,6 +158,35 @@ def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None: # its own jumps since the body is not guaranteed to execute return tail_bb + def visit_For(self, node: ast.For, bb: BB, jumps: Jumps) -> BB | None: + template = """ + it = make_iter + while True: + b, it = has_next + if b: + x, it = get_next + body + else: + break + end_iter # Consume iterator one last time + """ + + it = make_var(next(tmp_vars), node.iter) + b = make_var(next(tmp_vars), node.iter) + new_nodes = template_replace( + template, + node, + it=it, + b=b, + x=node.target, + make_iter=with_loc(node.iter, MakeIter(value=node.iter, origin_node=node)), + has_next=with_loc(node.iter, IterHasNext(value=it)), + get_next=with_loc(node.iter, IterNext(value=it)), + end_iter=with_loc(node.iter, IterEnd(value=it)), + body=node.body, + ) + return self.visit_stmts(new_nodes, bb, jumps) + def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None: if not jumps.continue_bb: raise InternalGuppyError("Continue BB not defined") @@ -211,20 +256,11 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]: builder = ExprBuilder(cfg, bb) return builder.visit(node), builder.bb - @classmethod - def _make_var(cls, name: str, loc: ast.expr | None = None) -> ast.Name: - """Creates an `ast.Name` node.""" - node = ast.Name(id=name, ctx=ast.Load) - if loc is not None: - set_location_from(node, loc) - return node - @classmethod def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None: """Adds a temporary variable assignment to a basic block.""" - node = ast.Assign(targets=[cls._make_var(tmp_name, value)], value=value) - set_location_from(node, value) - bb.statements.append(node) + lhs = make_var(tmp_name, value) + bb.statements.append(make_assign([lhs], value)) def visit_Name(self, node: ast.Name) -> ast.Name: return node @@ -256,7 +292,44 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name: self.bb = merge_bb # The final value is stored in the temporary variable - return self._make_var(tmp, node) + return make_var(tmp, node) + + def visit_ListComp(self, node: ast.ListComp) -> ast.AST: + # Check for illegal expressions + illegals = find_nodes(is_illegal_in_list_comp, node) + if illegals: + raise GuppyError( + "Expression is not supported inside a list comprehension", illegals[0] + ) + + # Desugar into statements that create the iterator, check for a next element, + # get the next element, and finalise the iterator. + gens = [] + for g in node.generators: + if g.is_async: + raise GuppyError("Async generators are not supported", g) + g.iter = self.visit(g.iter) + it = make_var(next(tmp_vars), g.iter) + hasnext = make_var(next(tmp_vars), g.iter) + desugared = DesugaredGenerator( + iter=it, + hasnext=hasnext, + iter_assign=make_assign( + [it], with_loc(it, MakeIter(value=g.iter, origin_node=node)) + ), + hasnext_assign=make_assign( + [hasnext, it], with_loc(it, IterHasNext(value=it)) + ), + next_assign=make_assign( + [g.target, it], with_loc(it, IterNext(value=it)) + ), + iterend=with_loc(it, IterEnd(value=it)), + ifs=g.ifs, + ) + gens.append(desugared) + + node.elt = self.visit(node.elt) + return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) def visit_Call(self, node: ast.Call) -> ast.AST: # Parse compile-time evaluated `py(...)` expression @@ -291,7 +364,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST: self._tmp_assign(tmp, false_const, false_bb) merge_bb = self.cfg.new_bb(true_bb, false_bb) self.bb = merge_bb - return self._make_var(tmp, node) + return make_var(tmp, node) # For all other expressions, just recurse deeper with the node transformer return super().generic_visit(node) @@ -414,3 +487,28 @@ def is_short_circuit_expr(node: ast.AST) -> bool: return isinstance(node, ast.BoolOp) or ( isinstance(node, ast.Compare) and len(node.comparators) > 1 ) + + +def is_illegal_in_list_comp(node: ast.AST) -> bool: + """Checks if an expression is illegal to use in a list comprehension.""" + return isinstance(node, ast.IfExp | ast.NamedExpr) or is_short_circuit_expr(node) + + +def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: + """Creates an `ast.Name` node.""" + node = ast.Name(id=name, ctx=ast.Load) + if loc is not None: + set_location_from(node, loc) + return node + + +def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign: + """Creates an `ast.Assign` node.""" + assert len(lhs) > 0 + adjuster = ContextAdjuster(ast.Store()) + lhs = [adjuster.visit(expr) for expr in lhs] + if len(lhs) == 1: + target = lhs[0] + else: + target = with_loc(value, ast.Tuple(elts=lhs, ctx=ast.Store())) + return with_loc(value, ast.Assign(targets=[target], value=value)) diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index c5cf8a64..ada6fd02 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -11,7 +11,7 @@ from guppy.ast_util import line_col from guppy.cfg.bb import BB from guppy.cfg.cfg import CFG, BaseCFG -from guppy.checker.core import Context, Globals, Variable +from guppy.checker.core import Context, Globals, Locals, Variable from guppy.checker.expr_checker import ExprSynthesizer, to_bool from guppy.checker.stmt_checker import StmtChecker from guppy.error import GuppyError @@ -127,7 +127,7 @@ def check_bb( raise GuppyError(f"Variable `{x}` is not defined", use) # Check the basic block - ctx = Context(globals, {v.name: v for v in inputs}) + ctx = Context(globals, Locals({v.name: v for v in inputs})) checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) # If we branch, we also have to check the branch predicate diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 69be5fb9..7a65083b 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -1,5 +1,8 @@ import ast +import copy +import itertools from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator from dataclasses import dataclass from typing import Any, NamedTuple @@ -110,8 +113,45 @@ def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034 return self -# Local variable mapping -Locals = dict[str, Variable] +@dataclass +class Locals: + """Scoped mapping from names to variables""" + + vars: dict[str, Variable] + parent_scope: "Locals | None" = None + + def __getitem__(self, item: str) -> Variable: + if item not in self.vars and self.parent_scope: + return self.parent_scope[item] + + return self.vars[item] + + def __setitem__(self, key: str, value: Variable) -> None: + self.vars[key] = value + + def __iter__(self) -> Iterator[str]: + parent_iter = iter(self.parent_scope) if self.parent_scope else iter(()) + return itertools.chain(iter(self.vars), parent_iter) + + def __contains__(self, item: str) -> bool: + return (item in self.vars) or ( + self.parent_scope is not None and item in self.parent_scope + ) + + def __copy__(self) -> "Locals": + # Make a copy of the var map so that mutating the copy doesn't + # mutate our variable mapping + return Locals(self.vars.copy(), copy.copy(self.parent_scope)) + + def keys(self) -> set[str]: + parent_keys = self.parent_scope.keys() if self.parent_scope else set() + return parent_keys | self.vars.keys() + + def items(self) -> Iterable[tuple[str, Variable]]: + parent_items = ( + iter(self.parent_scope.items()) if self.parent_scope else iter(()) + ) + return itertools.chain(self.vars.items(), parent_items) class Context(NamedTuple): diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 5ad876bf..776bfe9b 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -26,8 +26,17 @@ from contextlib import suppress from typing import Any, NoReturn, cast -from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type -from guppy.checker.core import CallableVariable, Context, DummyEvalDict, Globals +from guppy.ast_util import ( + AstNode, + AstVisitor, + breaks_in_loop, + get_type_opt, + name_nodes_in_ast, + return_nodes_in_ast, + with_loc, + with_type, +) +from guppy.checker.core import CallableVariable, Context, DummyEvalDict, Globals, Locals from guppy.error import ( GuppyError, GuppyTypeError, @@ -40,11 +49,26 @@ FunctionType, GuppyType, Inst, + LinstType, + ListType, + NoneType, Subst, TupleType, unify, ) -from guppy.nodes import GlobalName, LocalCall, LocalName, PyExpr, TypeApply +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + GlobalName, + IterEnd, + IterHasNext, + IterNext, + LocalCall, + LocalName, + MakeIter, + PyExpr, + TypeApply, +) # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { @@ -120,6 +144,14 @@ def check( a new desugared expression with type annotations and a substitution with the resolved type variables. """ + # If we already have a type for the expression, we just have to match it against + # the target + if actual := get_type_opt(expr): + subst, inst = check_type_against(actual, ty, expr, kind) + if inst: + expr = with_loc(expr, TypeApply(value=expr, tys=inst)) + return with_type(ty.substitute(subst), expr), subst + # When checking against a variable, we have to synthesize if isinstance(ty, ExistentialTypeVar): expr, syn_ty = self._synthesize(expr, allow_free_vars=False) @@ -147,6 +179,27 @@ def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> tuple[ast.expr, Subst]: subst |= s return node, subst + def visit_List(self, node: ast.List, ty: GuppyType) -> tuple[ast.expr, Subst]: + if not isinstance(ty, ListType | LinstType): + return self._fail(ty, node) + subst: Subst = {} + for i, el in enumerate(node.elts): + node.elts[i], s = self.check(el, ty.element_type.substitute(subst)) + subst |= s + return node, subst + + def visit_DesugaredListComp( + self, node: DesugaredListComp, ty: GuppyType + ) -> tuple[ast.expr, Subst]: + if not isinstance(ty, ListType | LinstType): + return self._fail(ty, node) + node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + subst = unify(ty.element_type, elt_ty, {}) + if subst is None: + actual = LinstType(elt_ty) if elt_ty.linear else ListType(elt_ty) + return self._fail(ty, actual, node) + return node, subst + def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: if len(node.keywords) > 0: raise GuppyError( @@ -217,7 +270,7 @@ def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, GuppyType]: raise GuppyError("Unsupported constant", node) return node, ty - def visit_Name(self, node: ast.Name) -> tuple[ast.expr, GuppyType]: + def visit_Name(self, node: ast.Name) -> tuple[ast.Name, GuppyType]: x = node.id if x in self.ctx.locals: var = self.ctx.locals[x] @@ -244,6 +297,22 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, GuppyType]: node.elts = [n for n, _ in elems] return node, TupleType([ty for _, ty in elems]) + def visit_List(self, node: ast.List) -> tuple[ast.expr, GuppyType]: + if len(node.elts) == 0: + raise GuppyTypeInferenceError( + "Cannot infer type variable in expression of type `list[?T]`", node + ) + node.elts[0], el_ty = self.synthesize(node.elts[0]) + node.elts[1:] = [self._check(el, el_ty)[0] for el in node.elts[1:]] + return node, LinstType(el_ty) if el_ty.linear else ListType(el_ty) + + def visit_DesugaredListComp( + self, node: DesugaredListComp + ) -> tuple[ast.expr, GuppyType]: + node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + result_ty = LinstType(elt_ty) if elt_ty.linear else ListType(elt_ty) + return node, result_ty + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: # We need to synthesise the argument type, so we can look up dunder methods node.operand, op_ty = self.synthesize(node.operand) @@ -293,6 +362,40 @@ def _synthesize_binary( node, ) + def _synthesize_instance_func( + self, + node: ast.expr, + args: list[ast.expr], + func_name: str, + err: str, + exp_sig: FunctionType | None = None, + give_reason: bool = False, + ) -> tuple[ast.expr, GuppyType]: + """Helper method for expressions that are implemented via instance methods. + + Raises a `GuppyTypeError` if the given instance method is not defined. The error + message can be customised by passing an `err` string and an optional error + reason can be printed. + + Optionally, the signature of the instance function can also be checked against a + given expected signature. + """ + node, ty = self.synthesize(node) + func = self.ctx.globals.get_instance_func(ty, func_name) + if func is None: + reason = f" since it does not implement the `{func_name}` method" + raise GuppyTypeError( + f"Expression of type `{ty}` is {err}{reason if give_reason else ''}", + node, + ) + if exp_sig and unify(exp_sig, func.ty.unquantified()[0], {}) is None: + raise GuppyError( + f"Method `{ty.name}.{func_name}` has signature `{func.ty}`, but " + f"expected `{exp_sig}`", + node, + ) + return func.synthesize_call([node, *args], node, self.ctx) + def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: return self._synthesize_binary(node.left, node.right, node.op, node) @@ -305,6 +408,16 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: left_expr, [op], [right_expr] = node.left, node.ops, node.comparators return self._synthesize_binary(left_expr, right_expr, op, node) + def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType( + [ty, ExistentialTypeVar.new("Key", False)], + ExistentialTypeVar.new("Val", False), + ) + return self._synthesize_instance_func( + node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig + ) + def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: if len(node.keywords) > 0: raise GuppyError("Keyword arguments are not supported", node.keywords[0]) @@ -326,6 +439,57 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: else: raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], ExistentialTypeVar.new("Iter", False)) + expr, ty = self._synthesize_instance_func( + node.value, [], "__iter__", "not iterable", exp_sig + ) + + # If the iterator was created by a `for` loop, we can add some extra checks to + # produce nicer errors for linearity violations. Namely, `break` and `return` + # are not allowed when looping over a linear iterator (`continue` is allowed) + if ty.linear and isinstance(node.origin_node, ast.For): + breaks = breaks_in_loop(node.origin_node) or return_nodes_in_ast( + node.origin_node + ) + if breaks: + raise GuppyTypeError( + f"Loop over iterator with linear type `{ty}` cannot be terminated " + f"prematurely", + breaks[0], + ) + return expr, ty + + def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], TupleType([BoolType(), ty])) + return self._synthesize_instance_func( + node.value, [], "__hasnext__", "not an iterator", exp_sig, True + ) + + def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType( + [ty], TupleType([ExistentialTypeVar.new("T", False), ty]) + ) + return self._synthesize_instance_func( + node.value, [], "__next__", "not an iterator", exp_sig, True + ) + + def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], NoneType()) + return self._synthesize_instance_func( + node.value, [], "__end__", "not an iterator", exp_sig, True + ) + + def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `ListComp`. Should have been removed during CFG" + f"construction: `{ast.unparse(node)}`" + ) + def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]: # The method we used for obtaining the Python variables in scope only works in # CPython (see `get_py_scope()`). @@ -639,6 +803,95 @@ def to_bool( return call, return_ty +def synthesize_comprehension( + node: DesugaredListComp, gens: list[DesugaredGenerator], ctx: Context +) -> tuple[DesugaredListComp, GuppyType]: + """Helper function to synthesise the element type of a list comprehension.""" + from guppy.checker.stmt_checker import StmtChecker + + def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None: + """Checks if an expression uses a linear variable from an outer scope. + + Since the expression is executed multiple times in the inner scope, this would + mean that the outer linear variable is used multiple times, which is not + allowed. + """ + for name in name_nodes_in_ast(expr): + x = name.id + if x in locals and x not in locals.vars: + var = locals[x] + if var.ty.linear: + raise GuppyTypeError( + f"Variable `{x}` with linear type `{var.ty}` would be used " + "multiple times when evaluating this comprehension", + name, + ) + + # If there are no more generators left, we can check the list element + if not gens: + node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt) + check_linear_use_from_outer_scope(node.elt, ctx.locals) + return node, elt_ty + + # Check the iterator in the outer context + gen, *gens = gens + gen.iter_assign = StmtChecker(ctx).visit_Assign(gen.iter_assign) + check_linear_use_from_outer_scope(gen.iter_assign.value, ctx.locals) + + # The rest is checked in a new nested context to ensure that variables don't escape + # their scope + inner_locals = Locals({}, parent_scope=ctx.locals) + inner_ctx = Context(ctx.globals, inner_locals) + expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx) + gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign) + gen.next_assign = stmt_chk.visit_Assign(gen.next_assign) + gen.hasnext, hasnext_ty = expr_sth.visit_Name(gen.hasnext) + gen.hasnext = with_type(hasnext_ty, gen.hasnext) + gen.iter, iter_ty = expr_sth.visit_Name(gen.iter) + gen.iter = with_type(iter_ty, gen.iter) + + # `if` guards are generally not allowed when we're iterating over linear variables. + # The only exception is if all linear variables are already consumed by the first + # guard + if gen.ifs: + gen.ifs[0], _ = expr_sth.synthesize(gen.ifs[0]) + + # Now, check if there are linear iteration variables that have not been used by + # the first guard + for target in name_nodes_in_ast(gen.next_assign.targets[0]): + var = inner_ctx.locals[target.id] + if var.ty.linear and not var.used and gen.ifs: + raise GuppyTypeError( + f"Variable `{var.name}` with linear type `{var.ty}` is not used on " + "all control-flow paths of the list comprehension", + target, + ) + + # Now, we can properly check all guards + for i in range(len(gen.ifs)): + gen.ifs[i], if_ty = expr_sth.synthesize(gen.ifs[i]) + gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx) + check_linear_use_from_outer_scope(gen.ifs[i], inner_locals) + + # Check remaining generators + node, elt_ty = synthesize_comprehension(node, gens, inner_ctx) + + # We have to make sure that all linear variables that were introduced in this scope + # have been used + for x, var in inner_ctx.locals.vars.items(): + if var.ty.linear and not var.used: + raise GuppyTypeError( + f"Variable `{x}` with linear type `{var.ty}` is not used", + var.defined_at, + ) + + # The iter finalizer is again checked in the outer context + ctx.locals[gen.iter.id].used = None + gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend) + gen.iterend = with_type(iterend_ty, gen.iterend) + return node, elt_ty + + def python_value_to_guppy_type( v: Any, node: ast.expr, globals: Globals ) -> GuppyType | None: diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index ea7f23b6..0847d609 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -49,14 +49,14 @@ def check_call( ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker args, subst, inst = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), subst + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty, inst = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), ty + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), ty @dataclass diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 0fa46608..85efc2b5 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -22,11 +22,13 @@ class StmtChecker(AstVisitor[BBStatement]): ctx: Context - bb: BB - return_ty: GuppyType + bb: BB | None + return_ty: GuppyType | None - def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: - assert not return_ty.unsolved_vars + def __init__( + self, ctx: Context, bb: BB | None = None, return_ty: GuppyType | None = None + ) -> None: + assert not return_ty or not return_ty.unsolved_vars self.ctx = ctx self.bb = bb self.return_ty = return_ty @@ -55,7 +57,7 @@ def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: f"Variable `{x}` with linear type `{var.ty}` is not used", var.defined_at, ) - self.ctx.locals[x] = Variable(x, ty, node, None) + self.ctx.locals[x] = Variable(x, ty, lhs, None) # The only other thing we support right now are tuples case ast.Tuple(elts=elts): @@ -76,7 +78,7 @@ def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: case _: raise GuppyError("Assignment pattern not supported", lhs) - def visit_Assign(self, node: ast.Assign) -> ast.stmt: + def visit_Assign(self, node: ast.Assign) -> ast.Assign: if len(node.targets) > 1: # This is the case for assignments like `a = b = 1` raise GuppyError("Multi assignment not supported", node) @@ -113,6 +115,9 @@ def visit_Expr(self, node: ast.Expr) -> ast.stmt: return node def visit_Return(self, node: ast.Return) -> ast.stmt: + if not self.return_ty: + raise InternalGuppyError("return_ty required to check return stmt!") + if node.value is not None: node.value, subst = self._check_expr( node.value, self.return_ty, "return value" @@ -127,6 +132,9 @@ def visit_Return(self, node: ast.Return) -> ast.stmt: def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: from guppy.checker.func_checker import check_nested_func_def + if not self.bb: + raise InternalGuppyError("BB required to check nested function def!") + func_def = check_nested_func_def(node, self.bb, self.ctx) self.ctx.locals[func_def.name] = Variable( func_def.name, func_def.ty, func_def, None diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index d2f1abb9..0b9c9a9a 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -56,7 +56,7 @@ def compile_bb( for (i, v) in enumerate(inputs) }, ) - dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, bb, dfg) + dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, dfg) # If we branch, we also have to compile the branch predicate if len(bb.successors) > 1: diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index cf01fc07..1d71556e 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -1,8 +1,16 @@ import ast +from collections.abc import Iterator +from contextlib import contextmanager from typing import Any -from guppy.ast_util import AstVisitor, get_type -from guppy.compiler.core import CompiledFunction, CompilerBase, DFContainer +from guppy.ast_util import AstVisitor, get_type, with_loc, with_type +from guppy.cfg.builder import tmp_vars +from guppy.compiler.core import ( + CompiledFunction, + CompilerBase, + DFContainer, + PortVariable, +) from guppy.error import GuppyError, InternalGuppyError from guppy.gtypes import ( BoolType, @@ -14,8 +22,16 @@ type_to_row, ) from guppy.hugr import ops, val -from guppy.hugr.hugr import OutPortV -from guppy.nodes import GlobalCall, GlobalName, LocalCall, LocalName, TypeApply +from guppy.hugr.hugr import DFContainingNode, OutPortV, VNode +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + GlobalCall, + GlobalName, + LocalCall, + LocalName, + TypeApply, +) class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): @@ -39,6 +55,85 @@ def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: """ return [self.compile(e, dfg) for e in expr_to_row(expr)] + @contextmanager + def _new_dfcontainer( + self, inputs: list[ast.Name], node: DFContainingNode + ) -> Iterator[None]: + """Context manager to build a graph inside a new `DFContainer`. + + Automatically updates `self.dfg` and makes the inputs available. + """ + old = self.dfg + inp = self.graph.add_input(parent=node) + # Check that the input names are unique + assert len({inp.id for inp in inputs}) == len(inputs), "Inputs are not unique" + new_locals = { + name.id: PortVariable(name.id, inp.add_out_port(get_type(name)), name, None) + for name in inputs + } + self.dfg = DFContainer(node, self.dfg.locals | new_locals) + with self.graph.parent(node): + yield + self.dfg = old + + @contextmanager + def _new_loop( + self, + loop_vars: list[ast.Name], + branch: ast.Name, + parent: DFContainingNode | None = None, + ) -> Iterator[None]: + """Context manager to build a graph inside a new `TailLoop` node. + + Automatically adds the `Output` node to the loop body once the context manager + exits. + """ + loop = self.graph.add_tail_loop( + [self.visit(name) for name in loop_vars], parent + ) + with self._new_dfcontainer(loop_vars, loop): + yield + # Output the branch predicate and the inputs for the next iteration + self.graph.add_output( + # Note that we have to do fresh calls to `self.visit` here since we're + # in a new context + [self.visit(branch), *(self.visit(name) for name in loop_vars)] + ) + # Update the DFG with the outputs from the loop + for name in loop_vars: + self.dfg[name.id].port = loop.add_out_port(get_type(name)) + + @contextmanager + def _new_case( + self, inputs: list[ast.Name], outputs: list[ast.Name], cond_node: VNode + ) -> Iterator[None]: + """Context manager to build a graph inside a new `Case` node. + + Automatically adds the `Output` node once the context manager exits. + """ + with self._new_dfcontainer(inputs, self.graph.add_case(cond_node)): + yield + self.graph.add_output([self.visit(name) for name in outputs]) + + @contextmanager + def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]: + """Context manager to build a graph inside the `true` case of a `Conditional` + + In the `false` case, the inputs are outputted as is. + """ + cond_node = self.graph.add_conditional( + self.visit(cond), [self.visit(inp) for inp in inputs] + ) + # If the condition is false, output the inputs as is + with self._new_case(inputs, inputs, cond_node): + pass + # If the condition is true, we enter the `with` block + with self._new_case(inputs, inputs, cond_node): + yield + # Update the DFG with the outputs from the Conditional node + for name in inputs: + self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) + def visit_Constant(self, node: ast.Constant) -> OutPortV: if value := python_value_to_hugr(node.value): const = self.graph.add_constant(value, get_type(node)).out_port(0) @@ -59,6 +154,12 @@ def visit_Tuple(self, node: ast.Tuple) -> OutPortV: inputs=[self.visit(e) for e in node.elts] ).out_port(0) + def visit_List(self, node: ast.List) -> OutPortV: + # Note that this is a list literal (i.e. `[e1, e2, ...]`), not a comprehension + return self.graph.add_node( + ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts] + ).add_out_port(get_type(node)) + def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: """Groups function return values into a tuple""" if len(returns) != 1: @@ -118,6 +219,61 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") + def visit_DesugaredListComp(self, node: DesugaredListComp) -> OutPortV: + from guppy.compiler.stmt_compiler import StmtCompiler + + compiler = StmtCompiler(self.graph, self.globals) + + # Make up a name for the list under construction and bind it to an empty list + list_ty = get_type(node) + list_name = with_type(list_ty, with_loc(node, LocalName(id=next(tmp_vars)))) + empty_list = self.graph.add_node(ops.DummyOp(name="MakeList")) + self.dfg[list_name.id] = PortVariable( + list_name.id, empty_list.add_out_port(list_ty), node, None + ) + + def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: + """Helper function to generate nested TailLoop nodes for generators""" + # If there are no more generators left, just append the element to the list + if not gens: + list_port, elt_port = self.visit(list_name), self.visit(elt) + push = self.graph.add_node( + ops.DummyOp(name="Push"), inputs=[list_port, elt_port] + ) + self.dfg[list_name.id].port = push.add_out_port(list_port.ty) + return + + # Otherwise, compile the first iterator and construct a TailLoop + gen, *gens = gens + compiler.compile_stmts([gen.iter_assign], self.dfg) + inputs = [gen.iter, list_name] + with self._new_loop(inputs, gen.hasnext): + # If there is a next element, compile it and continue with the next + # generator + compiler.compile_stmts([gen.hasnext_assign], self.dfg) + with self._if_true(gen.hasnext, inputs): + + def compile_ifs(ifs: list[ast.expr]) -> None: + """Helper function to compile a series of if-guards into nested + Conditional nodes.""" + if ifs: + if_expr, *ifs = ifs + # If the condition is true, continue with the next one + with self._if_true(if_expr, inputs): + compile_ifs(ifs) + else: + # If there are no guards left, compile the next generator + compile_generators(elt, gens) + + compiler.compile_stmts([gen.next_assign], self.dfg) + compile_ifs(gen.ifs) + + # After the loop is done, we have to finalize the iterator + self.visit(gen.iterend) + + compile_generators(node.elt, node.generators) + return self.visit(list_name) + def visit_BinOp(self, node: ast.BinOp) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py index 18147b22..a2ab4f6a 100644 --- a/guppy/compiler/stmt_compiler.py +++ b/guppy/compiler/stmt_compiler.py @@ -2,7 +2,6 @@ from collections.abc import Sequence from guppy.ast_util import AstVisitor -from guppy.checker.cfg_checker import CheckedBB from guppy.compiler.core import ( CompiledGlobals, CompilerBase, @@ -22,7 +21,6 @@ class StmtCompiler(CompilerBase, AstVisitor[None]): expr_compiler: ExprCompiler - bb: CheckedBB dfg: DFContainer def __init__(self, graph: Hugr, globals: CompiledGlobals): @@ -32,7 +30,6 @@ def __init__(self, graph: Hugr, globals: CompiledGlobals): def compile_stmts( self, stmts: Sequence[ast.stmt], - bb: CheckedBB, dfg: DFContainer, ) -> DFContainer: """Compiles a list of basic statements into a dataflow node. @@ -40,7 +37,6 @@ def compile_stmts( Note that the `dfg` is mutated in-place. After compilation, the DFG will also contain all variables that are assigned in the given list of statements. """ - self.bb = bb self.dfg = dfg for s in stmts: self.visit(s) diff --git a/guppy/declared.py b/guppy/declared.py index b868a374..df1c0f29 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -1,7 +1,7 @@ import ast from dataclasses import dataclass -from guppy.ast_util import AstNode, has_empty_body +from guppy.ast_util import AstNode, has_empty_body, with_loc from guppy.checker.core import Context, Globals from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature @@ -34,14 +34,14 @@ def check_call( ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker args, subst, inst = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), subst + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty, inst = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), ty + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), ty def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 8bc43b71..ab0e2b1a 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -386,6 +386,7 @@ def add_output( parent: Node | None = None, ) -> VNode: """Adds an `Output` node to the graph.""" + parent = parent or self._default_parent node = self.add_node(ops.Output(), input_tys, [], parent, inputs) if isinstance(parent, DFContainingNode): parent.output_child = node @@ -706,6 +707,25 @@ def insert_order_edges(self) -> "Hugr": elif isinstance(n.op, ops.LoadConstant): assert n.parent.input_child is not None self.add_order_edge(n.parent.input_child, n) + + # Also add order edges for non-local edges + for src, tgt in list(self.edges()): + # Exclude CF and constant edges + if isinstance(src, OutPortCF) or isinstance( + src.node.op, ops.FuncDecl | ops.FuncDefn | ops.Const + ): + continue + + if src.node.parent != tgt.node.parent: + # Walk up the hierarchy from the tgt until we hit a node at the same + # level as src + node = tgt.node + while node.parent != src.node.parent: + if node.parent is None: + raise ValueError("Invalid non-local edge!") + node = node.parent + # Add order edge to make sure that the src is executed first + self.add_order_edge(src.node, node) return self def to_raw(self) -> raw.RawHugr: diff --git a/tests/error/linear_errors/branch_use.err b/tests/error/linear_errors/branch_use.err index 742a0ca7..c096aa16 100644 --- a/tests/error/linear_errors/branch_use.err +++ b/tests/error/linear_errors/branch_use.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:23 21: @guppy(module) 22: def foo(b: bool) -> bool: 23: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/break_unused.err b/tests/error/linear_errors/break_unused.err index a91d243e..8ee9039f 100644 --- a/tests/error/linear_errors/break_unused.err +++ b/tests/error/linear_errors/break_unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:25 23: b = False 24: while True: 25: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/continue_unused.err b/tests/error/linear_errors/continue_unused.err index 6b91f2be..7ecd4c5a 100644 --- a/tests/error/linear_errors/continue_unused.err +++ b/tests/error/linear_errors/continue_unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:25 23: b = False 24: while i > 0: 25: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/if_both_unused.err b/tests/error/linear_errors/if_both_unused.err index db518100..f942875c 100644 --- a/tests/error/linear_errors/if_both_unused.err +++ b/tests/error/linear_errors/if_both_unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:19 17: def foo(b: bool) -> int: 18: if b: 19: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/if_both_unused_reassign.err b/tests/error/linear_errors/if_both_unused_reassign.err index cacd42ca..0b0d5dbe 100644 --- a/tests/error/linear_errors/if_both_unused_reassign.err +++ b/tests/error/linear_errors/if_both_unused_reassign.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:19 17: def foo(b: bool) -> Qubit: 18: if b: 19: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/unused.err b/tests/error/linear_errors/unused.err index 7d62fee3..764d3d95 100644 --- a/tests/error/linear_errors/unused.err +++ b/tests/error/linear_errors/unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:13 11: @guppy(module) 12: def foo(q: Qubit) -> int: 13: x = q - ^^^^^ + ^ GuppyError: Variable `x` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/unused_same_block.err b/tests/error/linear_errors/unused_same_block.err index 960f266c..835fff05 100644 --- a/tests/error/linear_errors/unused_same_block.err +++ b/tests/error/linear_errors/unused_same_block.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:13 11: @guppy(module) 12: def foo(q: Qubit) -> int: 13: x = q - ^^^^^ + ^ GuppyError: Variable `x` with linear type `Qubit` is not used