From 27f6f048f0bd6fd340e7010a96517340491ec119 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 31 May 2024 14:27:17 -0600 Subject: [PATCH] [red-knot] initial (very incomplete) flow graph (#11624) ## Summary Introduces the skeleton of the flow graph. So far it doesn't actually handle any non-linear control flow :) But it does show how we can go from an expression that references a symbol, backward through the flow graph, to find reachable definitions of that symbol. Adding non-linear control flow will mean adding flow nodes with multiple predecessors, which will introduce more complexity into `ReachableDefinitionsIterator.next()`. But one step at a time. ## Test Plan Added a (very basic) test. --- crates/red_knot/src/ast_ids.rs | 11 ++- crates/red_knot/src/symbols.rs | 138 +++++++++++++++++++++++++++++++-- 2 files changed, 137 insertions(+), 12 deletions(-) diff --git a/crates/red_knot/src/ast_ids.rs b/crates/red_knot/src/ast_ids.rs index 784e44b22d727..0229fb748eea2 100644 --- a/crates/red_knot/src/ast_ids.rs +++ b/crates/red_knot/src/ast_ids.rs @@ -275,10 +275,7 @@ pub struct TypedNodeKey { impl TypedNodeKey { pub fn from_node(node: &N) -> Self { - let inner = NodeKey { - kind: node.as_any_node_ref().kind(), - range: node.range(), - }; + let inner = NodeKey::from_node(node.as_any_node_ref()); Self { inner, _marker: PhantomData, @@ -352,6 +349,12 @@ pub struct NodeKey { } impl NodeKey { + pub fn from_node(node: AnyNodeRef) -> Self { + NodeKey { + kind: node.kind(), + range: node.range(), + } + } pub fn resolve<'a>(&self, root: AnyNodeRef<'a>) -> Option> { // We need to do a binary search here. 
Only traverse into a node if the range is withint the node let mut visitor = FindNodeKeyVisitor { diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 475c52e3516dd..3b4db31469b88 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -236,6 +236,7 @@ pub struct SymbolTable { scopes_by_node: FxHashMap, /// dependencies of this module dependencies: Vec, + flow_graph: FlowGraph, } impl SymbolTable { @@ -245,6 +246,7 @@ impl SymbolTable { table: SymbolTable::new(), scopes: vec![root_scope_id], current_definition: None, + current_flow_node: FlowGraph::start(), }; builder.visit_body(&module.body); builder.table @@ -257,6 +259,7 @@ impl SymbolTable { defs: FxHashMap::default(), scopes_by_node: FxHashMap::default(), dependencies: Vec::new(), + flow_graph: FlowGraph::new(), }; table.scopes_by_id.push(Scope { name: Name::new(""), @@ -274,6 +277,23 @@ impl SymbolTable { &self.dependencies } + /// Return an iterator over all definitions of `symbol_id` reachable from `use_expr`. The value + /// of `symbol_id` in `use_expr` must originate from one of the iterated definitions (or from + /// an external reassignment of the name outside of this scope). 
+ pub(crate) fn reachable_definitions( + &self, + symbol_id: SymbolId, + use_expr: &ast::Expr, + ) -> ReachableDefinitionsIterator { + let node_key = NodeKey::from_node(use_expr.into()); + let flow_node_id = self.flow_graph.ast_to_flow[&node_key]; + ReachableDefinitionsIterator { + table: self, + flow_node_id, + symbol_id, + } + } + pub(crate) const fn root_scope_id() -> ScopeId { ScopeId::from_usize(0) } @@ -523,11 +543,72 @@ where } } +pub(crate) struct ReachableDefinitionsIterator<'a> { + table: &'a SymbolTable, + flow_node_id: FlowNodeId, + symbol_id: SymbolId, +} + +impl<'a> Iterator for ReachableDefinitionsIterator<'a> { + type Item = Definition; + + fn next(&mut self) -> Option<Self::Item> { + loop { + match &self.table.flow_graph.flow_nodes_by_id[self.flow_node_id] { + FlowNode::Start => return None, + FlowNode::Definition(def_node) => { + self.flow_node_id = def_node.predecessor; + if def_node.symbol_id == self.symbol_id { + return Some(def_node.definition.clone()); + } + } + } + } + } +} + +impl<'a> FusedIterator for ReachableDefinitionsIterator<'a> {} + +#[newtype_index] +struct FlowNodeId; + +#[derive(Debug)] +enum FlowNode { + Start, + Definition(DefinitionFlowNode), +} + +#[derive(Debug)] +struct DefinitionFlowNode { + symbol_id: SymbolId, + definition: Definition, + predecessor: FlowNodeId, +} + +#[derive(Debug, Default)] +struct FlowGraph { + flow_nodes_by_id: IndexVec<FlowNodeId, FlowNode>, + ast_to_flow: FxHashMap<NodeKey, FlowNodeId>, +} + +impl FlowGraph { + fn new() -> Self { + let mut graph = FlowGraph::default(); + graph.flow_nodes_by_id.push(FlowNode::Start); + graph + } + + fn start() -> FlowNodeId { + FlowNodeId::from_usize(0) + } +} + struct SymbolTableBuilder { table: SymbolTable, scopes: Vec<ScopeId>, /// the definition whose target(s) we are currently walking current_definition: Option<Definition>, + current_flow_node: FlowNodeId, } impl SymbolTableBuilder { @@ -546,7 +627,16 @@ impl SymbolTableBuilder { .defs .entry(symbol_id) .or_default() - .push(definition); + .push(definition.clone()); + 
self.current_flow_node = self + .table + .flow_graph + .flow_nodes_by_id + .push(FlowNode::Definition(DefinitionFlowNode { + definition, + symbol_id, + predecessor: self.current_flow_node, + })); symbol_id } @@ -561,6 +651,7 @@ impl SymbolTableBuilder { self.table .add_child_scope(self.cur_scope(), name, kind, definition, defining_symbol); self.scopes.push(scope_id); + self.current_flow_node = FlowGraph::start(); scope_id } @@ -624,6 +715,10 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { } } } + self.table + .flow_graph + .ast_to_flow + .insert(NodeKey::from_node(expr.into()), self.current_flow_node); ast::visitor::preorder::walk_expr(self, expr); } @@ -766,15 +861,13 @@ impl DerefMut for SymbolTablesStorage { #[cfg(test)] mod tests { - use textwrap::dedent; - - use crate::parse::Parsed; - use crate::symbols::ScopeKind; - - use super::{SymbolFlags, SymbolId, SymbolIterator, SymbolTable}; + use crate::symbols::{ScopeKind, SymbolFlags, SymbolTable}; mod from_ast { - use super::*; + use crate::parse::Parsed; + use crate::symbols::{Definition, ScopeKind, SymbolId, SymbolIterator, SymbolTable}; + use ruff_python_ast as ast; + use textwrap::dedent; fn parse(code: &str) -> Parsed { Parsed::from_text(&dedent(code)) @@ -1021,6 +1114,35 @@ mod tests { assert_eq!(func_scope.name(), "C"); assert_eq!(names(table.symbols_for_scope(func_scope_id)), vec!["x"]); } + + #[test] + fn reachability_trivial() { + let parsed = parse("x = 1; x"); + let ast = parsed.ast(); + let table = SymbolTable::from_ast(ast); + let x_sym = table + .root_symbol_id_by_name("x") + .expect("x symbol should exist"); + let ast::Stmt::Expr(ast::StmtExpr { value: x_use, .. 
}) = &ast.body[1] else { + panic!("should be an expr") + }; + let x_defs: Vec<_> = table.reachable_definitions(x_sym, x_use).collect(); + assert_eq!(x_defs.len(), 1); + let Definition::Assignment(node_key) = &x_defs[0] else { + panic!("def should be an assignment") + }; + let Some(def_node) = node_key.resolve(ast.into()) else { + panic!("node key should resolve") + }; + let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + value: ast::Number::Int(num), + .. + }) = &*def_node.value + else { + panic!("should be a number literal") + }; + assert_eq!(*num, 1); + } } #[test]