Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Compile lists and comprehensions #70

Merged
merged 28 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
359c2da
feat: Build for loops and comprehensions
mark-koch Dec 20, 2023
36396cc
feat: Add list and comprehension type checking
mark-koch Dec 20, 2023
000e146
Fix error golden files
mark-koch Dec 20, 2023
34d5415
feat: Compile lists and comprehensions
mark-koch Dec 20, 2023
f51f0bf
Improve instance method helper function
mark-koch Jan 8, 2024
01319cc
Don't override usage stats
mark-koch Jan 8, 2024
2359ab0
Include parent scope in Locals.__iter__
mark-koch Jan 8, 2024
81dcddb
Don't require linear return for __next__
mark-koch Jan 8, 2024
34e9ce2
Improve error msg
mark-koch Jan 9, 2024
f8e6768
Add comment to visit_List
mark-koch Jan 11, 2024
b02dc3d
Check uniqueness of input names
mark-koch Jan 11, 2024
34ba5e1
Rename _new_loop inputs to variants
mark-koch Jan 11, 2024
17aca09
Fix spelling and clarify docstring
mark-koch Jan 11, 2024
7639089
Fix _new_case inputs for outputs
mark-koch Jan 11, 2024
8cb0df0
Add comment explaining passing `inputs` twice
mark-koch Jan 11, 2024
20b2d99
Use inputs comprehension
mark-koch Jan 11, 2024
a14ae45
Fix comments
mark-koch Jan 11, 2024
8f561e2
Add _if_true context manager
mark-koch Jan 11, 2024
1fb02a0
Rename variants to loop_vars
mark-koch Jan 15, 2024
13eac52
Add comment explaining fresh visit calls
mark-koch Jan 15, 2024
633421d
Merge branch 'feat/lists' into lists/build
mark-koch Jan 15, 2024
7630f7b
Add make_assign helper
mark-koch Jan 15, 2024
70c5d07
Turn TemplateReplacer into dataclass
mark-koch Jan 16, 2024
fd8f56c
Add missing location
mark-koch Jan 16, 2024
6790161
Adjust context in make_assign
mark-koch Jan 16, 2024
f5e4ea9
Merge remote-tracking branch 'origin/lists/build' into lists/check
mark-koch Jan 16, 2024
74102c3
Merge remote-tracking branch 'origin/lists/check' into lists/compile
mark-koch Jan 16, 2024
81d1722
Fix formatting
mark-koch Jan 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 153 additions & 37 deletions guppy/ast_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ast
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import textwrap
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast

if TYPE_CHECKING:
from guppy.gtypes import GuppyType
Expand Down Expand Up @@ -54,51 +56,165 @@ def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> T:
raise NotImplementedError(f"visit_{node.__class__.__name__} is not implemented")


class NameVisitor(ast.NodeVisitor):
"""Visitor to collect all `Name` nodes occurring in an AST."""

names: list[ast.Name]

def __init__(self) -> None:
self.names = []

def visit_Name(self, node: ast.Name) -> None:
self.names.append(node)
class AstSearcher(ast.NodeVisitor):
"""Visitor that searches for occurrences of specific nodes in an AST."""

matcher: Callable[[ast.AST], bool]
dont_recurse_into: set[type[ast.AST]]
found: list[ast.AST]
is_first_node: bool

def __init__(
self,
matcher: Callable[[ast.AST], bool],
dont_recurse_into: set[type[ast.AST]] | None = None,
) -> None:
self.matcher = matcher
self.dont_recurse_into = dont_recurse_into or set()
self.found = []
self.is_first_node = True

def generic_visit(self, node: ast.AST) -> None:
if self.matcher(node):
self.found.append(node)
if self.is_first_node or type(node) not in self.dont_recurse_into:
self.is_first_node = False
super().generic_visit(node)


def find_nodes(
    matcher: Callable[[ast.AST], bool],
    node: ast.AST,
    dont_recurse_into: set[type[ast.AST]] | None = None,
) -> list[ast.AST]:
    """Returns all nodes in the AST that satisfy the matcher.

    The search starts at `node`. The optional `dont_recurse_into` set is handed to
    the `AstSearcher` unchanged.
    """
    searcher = AstSearcher(matcher, dont_recurse_into)
    searcher.visit(node)
    return searcher.found


def name_nodes_in_ast(node: Any) -> list[ast.Name]:
"""Returns all `Name` nodes occurring in an AST."""
v = NameVisitor()
v.visit(node)
return v.names


class ReturnVisitor(ast.NodeVisitor):
"""Visitor to collect all `Return` nodes occurring in an AST."""
found = find_nodes(lambda n: isinstance(n, ast.Name), node)
return cast(list[ast.Name], found)

returns: list[ast.Return]
inside_func_def: bool

def __init__(self) -> None:
self.returns = []
self.inside_func_def = False

def visit_Return(self, node: ast.Return) -> None:
self.returns.append(node)
def return_nodes_in_ast(node: Any) -> list[ast.Return]:
"""Returns all `Return` nodes occurring in an AST."""
found = find_nodes(lambda n: isinstance(n, ast.Return), node, {ast.FunctionDef})
return cast(list[ast.Return], found)

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
# Don't descend into nested function definitions
if not self.inside_func_def:
self.inside_func_def = True
for n in node.body:
self.visit(n)

def breaks_in_loop(node: Any) -> list[ast.Break]:
"""Returns all `Break` nodes occurring in a loop.

def return_nodes_in_ast(node: Any) -> list[ast.Return]:
"""Returns all `Return` nodes occurring in an AST."""
v = ReturnVisitor()
v.visit(node)
return v.returns
Note that breaks in nested loops are excluded.
"""
found = find_nodes(
lambda n: isinstance(n, ast.Break), node, {ast.For, ast.While, ast.FunctionDef}
)
return cast(list[ast.Break], found)


class ContextAdjuster(ast.NodeTransformer):
    """Updates the `ast.Context` indicating if expressions occur on the LHS or RHS.

    Each handled node kind (`Name`, `Starred`, `Tuple`, `List`, `Subscript` and
    `Attribute`) is rebuilt as a fresh node carrying `self.ctx`; the input node is
    not mutated. Every rebuilt node is passed through `with_loc` together with the
    node it replaces (presumably this copies the source location over; confirm
    against the definition of `with_loc`).

    NOTE(review): the context is also pushed into the `value` of `Subscript` and
    `Attribute` nodes. In ASTs produced by `ast.parse`, the `value` of a
    store-context subscript/attribute keeps a load context. Confirm that the
    deviation is intended.
    """

    ctx: ast.expr_context

    def __init__(self, ctx: ast.expr_context) -> None:
        self.ctx = ctx

    def visit(self, node: ast.AST) -> ast.AST:
        # Narrow the `Any` returned by the base class for the type checker
        return cast(ast.AST, super().visit(node))

    def visit_Name(self, node: ast.Name) -> ast.Name:
        return with_loc(node, ast.Name(id=node.id, ctx=self.ctx))

    def visit_Starred(self, node: ast.Starred) -> ast.Starred:
        return with_loc(node, ast.Starred(value=self.visit(node.value), ctx=self.ctx))

    def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple:
        return with_loc(
            node, ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx)
        )

    def visit_List(self, node: ast.List) -> ast.List:
        return with_loc(
            node, ast.List(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx)
        )

    def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
        # Don't adjust the slice!
        return with_loc(
            node,
            ast.Subscript(value=self.visit(node.value), slice=node.slice, ctx=self.ctx),
        )

    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        # Wrapped in `with_loc` like every other visitor of this class so that the
        # rebuilt attribute node is not returned without a location
        return with_loc(
            node,
            ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx),
        )


class TemplateReplacer(ast.NodeTransformer):
"""Replaces nodes in a template."""

replacements: Mapping[str, ast.AST | Sequence[ast.AST]]
default_loc: ast.AST

def __init__(
self,
replacements: Mapping[str, ast.AST | Sequence[ast.AST]],
default_loc: ast.AST,
) -> None:
self.replacements = replacements
self.default_loc = default_loc

def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]:
if x not in self.replacements:
msg = f"No replacement for `{x}` is given"
raise ValueError(msg)
return self.replacements[x]

def visit_Name(self, node: ast.Name) -> ast.AST:
repl = self._get_replacement(node.id)
if not isinstance(repl, ast.expr):
msg = f"Replacement for `{node.id}` must be an expression"
raise TypeError(msg)

# Update the context
adjuster = ContextAdjuster(node.ctx)
return with_loc(repl, adjuster.visit(repl))

def visit_Expr(self, node: ast.Expr) -> ast.AST | Sequence[ast.AST]:
if isinstance(node.value, ast.Name):
repl = self._get_replacement(node.value.id)
repls = [repl] if not isinstance(repl, Sequence) else repl
# Wrap expressions to turn them into statements
return [
with_loc(r, ast.Expr(value=r)) if isinstance(r, ast.expr) else r
for r in repls
]
return self.generic_visit(node)

def generic_visit(self, node: ast.AST) -> ast.AST:
# Insert the default location
node = super().generic_visit(node)
return with_loc(self.default_loc, node)


def template_replace(
    template: str, default_loc: ast.AST, **kwargs: ast.AST | Sequence[ast.AST]
) -> list[ast.stmt]:
    """Turns a template into a proper AST by substituting all placeholders.

    The template text is dedented and parsed, then each top-level statement is run
    through a `TemplateReplacer` that uses the keyword arguments as replacements.
    A statement that is replaced by a list contributes all elements of that list.
    """
    parsed = ast.parse(textwrap.dedent(template)).body
    replacer = TemplateReplacer(kwargs, default_loc)
    result: list[ast.stmt] = []
    for stmt in parsed:
        replaced = replacer.visit(stmt)
        if isinstance(replaced, list):
            result += replaced
        else:
            result.append(replaced)
    return result


def line_col(node: ast.AST) -> tuple[int, int]:
Expand Down
33 changes: 26 additions & 7 deletions guppy/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self

from guppy.ast_util import AstNode, name_nodes_in_ast
from guppy.nodes import NestedFunctionDef
from guppy.nodes import DesugaredListComp, NestedFunctionDef

if TYPE_CHECKING:
from guppy.cfg.cfg import BaseCFG
Expand Down Expand Up @@ -99,24 +99,46 @@ def __init__(self, bb: BB):
self.bb = bb
self.stats = VariableStats()

def visit_Name(self, node: ast.Name) -> None:
    """Forwards a `Name` node to `update_used` of this visitor's stats."""
    stats = self.stats
    stats.update_used(node)

def visit_Assign(self, node: ast.Assign) -> None:
self.stats.update_used(node.value)
self.visit(node.value)
for t in node.targets:
for name in name_nodes_in_ast(t):
self.stats.assigned[name.id] = node

def visit_AugAssign(self, node: ast.AugAssign) -> None:
self.stats.update_used(node.value)
self.visit(node.value)
self.stats.update_used(node.target) # The target is also used
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
if node.value:
self.stats.update_used(node.value)
self.visit(node.value)
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:
    """Records the variables that a desugared list comprehension uses.

    The comprehension is analysed with a fresh `VariableVisitor` for the same BB.
    Afterwards, the variables that this inner visitor recorded as used are merged
    into `self.stats.used`, except for names that are already in
    `self.stats.assigned`.
    """
    # Names bound in the comprehension are only available inside, so we shouldn't
    # update `self.stats` with assignments
    inner_visitor = VariableVisitor(self.bb)
    inner_stats = inner_visitor.stats

    # The generators are evaluated left to right
    for gen in node.generators:
        inner_visitor.visit(gen.iter_assign)
        inner_visitor.visit(gen.hasnext_assign)
        inner_visitor.visit(gen.next_assign)
        for cond in gen.ifs:
            inner_visitor.visit(cond)
    inner_visitor.visit(node.elt)

    # Merge into the existing usage stats instead of replacing them: assigning a
    # new dict here would drop every use recorded before this comprehension.
    # `setdefault` leaves an entry that is already present untouched.
    for name, use in inner_stats.used.items():
        if name not in self.stats.assigned:
            self.stats.used.setdefault(name, use)

def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:
# In order to compute the used external variables in a nested function
# definition, we have to run live variable analysis first
Expand All @@ -139,6 +161,3 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:

# The name of the function is now assigned
self.stats.assigned[node.name] = node

def generic_visit(self, node: ast.AST) -> None:
self.stats.update_used(node)
Loading