diff --git a/Cargo.lock b/Cargo.lock index d9033c14f217c..c6130245e1a37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1904,6 +1904,8 @@ dependencies = [ "ruff_text_size", "rustc-hash 2.0.0", "salsa", + "smallvec", + "static_assertions", "tempfile", "tracing", "walkdir", diff --git a/crates/red_knot_python_semantic/Cargo.toml b/crates/red_knot_python_semantic/Cargo.toml index 1019ce943469c..d07978271b3a9 100644 --- a/crates/red_knot_python_semantic/Cargo.toml +++ b/crates/red_knot_python_semantic/Cargo.toml @@ -29,6 +29,8 @@ salsa = { workspace = true } tracing = { workspace = true } rustc-hash = { workspace = true } hashbrown = { workspace = true } +smallvec = { workspace = true } +static_assertions = { workspace = true } [build-dependencies] path-slash = { workspace = true } diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index fef72fe74ca80..56c7e31d85c88 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -16,10 +16,9 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolTable, }; +use crate::semantic_index::use_def::UseDefMap; use crate::Db; -pub(crate) use self::use_def::UseDefMap; - pub mod ast_ids; mod builder; pub mod definition; @@ -27,6 +26,8 @@ pub mod expression; pub mod symbol; mod use_def; +pub(crate) use self::use_def::{DefinitionWithConstraints, DefinitionWithConstraintsIterator}; + type SymbolMap = hashbrown::HashMap; /// Returns the semantic index for `file`. 
@@ -310,12 +311,29 @@ mod tests { use ruff_text_size::{Ranged, TextRange}; use crate::db::tests::TestDb; - use crate::semantic_index::ast_ids::HasScopedUseId; - use crate::semantic_index::definition::DefinitionKind; - use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable}; + use crate::semantic_index::ast_ids::{HasScopedUseId, ScopedUseId}; + use crate::semantic_index::definition::{Definition, DefinitionKind}; + use crate::semantic_index::symbol::{ + FileScopeId, Scope, ScopeKind, ScopedSymbolId, SymbolTable, + }; + use crate::semantic_index::use_def::UseDefMap; use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map}; use crate::Db; + impl UseDefMap<'_> { + fn first_public_definition(&self, symbol: ScopedSymbolId) -> Option> { + self.public_definitions(symbol) + .next() + .map(|constrained_definition| constrained_definition.definition) + } + + fn first_use_definition(&self, use_id: ScopedUseId) -> Option> { + self.use_definitions(use_id) + .next() + .map(|constrained_definition| constrained_definition.definition) + } + } + struct TestCase { db: TestDb, file: File, @@ -374,9 +392,7 @@ mod tests { let foo = global_table.symbol_id_by_name("foo").unwrap(); let use_def = use_def_map(&db, scope); - let [definition] = use_def.public_definitions(foo) else { - panic!("expected one definition"); - }; + let definition = use_def.first_public_definition(foo).unwrap(); assert!(matches!(definition.node(&db), DefinitionKind::Import(_))); } @@ -411,13 +427,13 @@ mod tests { ); let use_def = use_def_map(&db, scope); - let [definition] = use_def.public_definitions( - global_table - .symbol_id_by_name("foo") - .expect("symbol to exist"), - ) else { - panic!("expected one definition"); - }; + let definition = use_def + .first_public_definition( + global_table + .symbol_id_by_name("foo") + .expect("symbol to exist"), + ) + .unwrap(); assert!(matches!( definition.node(&db), DefinitionKind::ImportFrom(_) @@ -438,11 +454,9 @@ mod tests 
{ "a symbol used but not defined in a scope should have only the used flag" ); let use_def = use_def_map(&db, scope); - let [definition] = - use_def.public_definitions(global_table.symbol_id_by_name("x").expect("symbol exists")) - else { - panic!("expected one definition"); - }; + let definition = use_def + .first_public_definition(global_table.symbol_id_by_name("x").expect("symbol exists")) + .unwrap(); assert!(matches!( definition.node(&db), DefinitionKind::Assignment(_) @@ -477,11 +491,9 @@ y = 2 assert_eq!(names(&class_table), vec!["x"]); let use_def = index.use_def_map(class_scope_id); - let [definition] = - use_def.public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists")) - else { - panic!("expected one definition"); - }; + let definition = use_def + .first_public_definition(class_table.symbol_id_by_name("x").expect("symbol exists")) + .unwrap(); assert!(matches!( definition.node(&db), DefinitionKind::Assignment(_) @@ -515,13 +527,13 @@ y = 2 assert_eq!(names(&function_table), vec!["x"]); let use_def = index.use_def_map(function_scope_id); - let [definition] = use_def.public_definitions( - function_table - .symbol_id_by_name("x") - .expect("symbol exists"), - ) else { - panic!("expected one definition"); - }; + let definition = use_def + .first_public_definition( + function_table + .symbol_id_by_name("x") + .expect("symbol exists"), + ) + .unwrap(); assert!(matches!( definition.node(&db), DefinitionKind::Assignment(_) @@ -557,26 +569,26 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): let use_def = index.use_def_map(function_scope_id); for name in ["a", "b", "c", "d"] { - let [definition] = use_def.public_definitions( - function_table - .symbol_id_by_name(name) - .expect("symbol exists"), - ) else { - panic!("Expected parameter definition for {name}"); - }; + let definition = use_def + .first_public_definition( + function_table + .symbol_id_by_name(name) + .expect("symbol exists"), + ) + .unwrap(); assert!(matches!( 
definition.node(&db), DefinitionKind::ParameterWithDefault(_) )); } for name in ["args", "kwargs"] { - let [definition] = use_def.public_definitions( - function_table - .symbol_id_by_name(name) - .expect("symbol exists"), - ) else { - panic!("Expected parameter definition for {name}"); - }; + let definition = use_def + .first_public_definition( + function_table + .symbol_id_by_name(name) + .expect("symbol exists"), + ) + .unwrap(); assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_))); } } @@ -605,22 +617,22 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): let use_def = index.use_def_map(lambda_scope_id); for name in ["a", "b", "c", "d"] { - let [definition] = use_def - .public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists")) - else { - panic!("Expected parameter definition for {name}"); - }; + let definition = use_def + .first_public_definition( + lambda_table.symbol_id_by_name(name).expect("symbol exists"), + ) + .unwrap(); assert!(matches!( definition.node(&db), DefinitionKind::ParameterWithDefault(_) )); } for name in ["args", "kwargs"] { - let [definition] = use_def - .public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists")) - else { - panic!("Expected parameter definition for {name}"); - }; + let definition = use_def + .first_public_definition( + lambda_table.symbol_id_by_name(name).expect("symbol exists"), + ) + .unwrap(); assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_))); } } @@ -691,9 +703,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): let element_use_id = element.scoped_use_id(&db, comprehension_scope_id.to_scope_id(&db, file)); - let [definition] = use_def.use_definitions(element_use_id) else { - panic!("expected one definition") - }; + let definition = use_def.first_use_definition(element_use_id).unwrap(); let DefinitionKind::Comprehension(comprehension) = definition.node(&db) else { panic!("expected generator definition") 
}; @@ -790,13 +800,13 @@ def func(): assert_eq!(names(&func2_table), vec!["y"]); let use_def = index.use_def_map(FileScopeId::global()); - let [definition] = use_def.public_definitions( - global_table - .symbol_id_by_name("func") - .expect("symbol exists"), - ) else { - panic!("expected one definition"); - }; + let definition = use_def + .first_public_definition( + global_table + .symbol_id_by_name("func") + .expect("symbol exists"), + ) + .unwrap(); assert!(matches!(definition.node(&db), DefinitionKind::Function(_))); } @@ -897,9 +907,7 @@ class C[T]: }; let x_use_id = x_use_expr_name.scoped_use_id(&db, scope); let use_def = use_def_map(&db, scope); - let [definition] = use_def.use_definitions(x_use_id) else { - panic!("expected one definition"); - }; + let definition = use_def.first_use_definition(x_use_id).unwrap(); let DefinitionKind::Assignment(assignment) = definition.node(&db) else { panic!("should be an assignment definition") }; diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 7fa6fe1639d0c..246c810216ae4 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -155,7 +155,7 @@ impl<'db> SemanticIndexBuilder<'db> { self.current_use_def_map_mut().restore(state); } - fn flow_merge(&mut self, state: &FlowSnapshot) { + fn flow_merge(&mut self, state: FlowSnapshot) { self.current_use_def_map_mut().merge(state); } @@ -195,9 +195,16 @@ impl<'db> SemanticIndexBuilder<'db> { definition } + fn add_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> { + let expression = self.add_standalone_expression(constraint_node); + self.current_use_def_map_mut().record_constraint(expression); + + expression + } + /// Record an expression that needs to be a Salsa ingredient, because we need to infer its type /// standalone (type narrowing tests, RHS of an assignment.) 
- fn add_standalone_expression(&mut self, expression_node: &ast::Expr) { + fn add_standalone_expression(&mut self, expression_node: &ast::Expr) -> Expression<'db> { let expression = Expression::new( self.db, self.file, @@ -210,6 +217,7 @@ impl<'db> SemanticIndexBuilder<'db> { ); self.expressions_by_node .insert(expression_node.into(), expression); + expression } fn with_type_params( @@ -476,6 +484,7 @@ where ast::Stmt::If(node) => { self.visit_expr(&node.test); let pre_if = self.flow_snapshot(); + self.add_constraint(&node.test); self.visit_body(&node.body); let mut post_clauses: Vec = vec![]; for clause in &node.elif_else_clauses { @@ -488,7 +497,7 @@ where self.visit_elif_else_clause(clause); } for post_clause_state in post_clauses { - self.flow_merge(&post_clause_state); + self.flow_merge(post_clause_state); } let has_else = node .elif_else_clauses @@ -497,7 +506,7 @@ where if !has_else { // if there's no else clause, then it's possible we took none of the branches, // and the pre_if state can reach here - self.flow_merge(&pre_if); + self.flow_merge(pre_if); } } ast::Stmt::While(node) => { @@ -515,13 +524,13 @@ where // We may execute the `else` clause without ever executing the body, so merge in // the pre-loop state before visiting `else`. - self.flow_merge(&pre_loop); + self.flow_merge(pre_loop); self.visit_body(&node.orelse); // Breaking out of a while loop bypasses the `else` clause, so merge in the break // states after visiting `else`. 
for break_state in break_states { - self.flow_merge(&break_state); + self.flow_merge(break_state); } } ast::Stmt::Break(_) => { @@ -631,7 +640,7 @@ where let post_body = self.flow_snapshot(); self.flow_restore(pre_if); self.visit_expr(orelse); - self.flow_merge(&post_body); + self.flow_merge(post_body); } ast::Expr::ListComp( list_comprehension @ ast::ExprListComp { diff --git a/crates/red_knot_python_semantic/src/semantic_index/expression.rs b/crates/red_knot_python_semantic/src/semantic_index/expression.rs index 8dcbc44e28667..4a7582bc32243 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/expression.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/expression.rs @@ -21,7 +21,7 @@ pub(crate) struct Expression<'db> { /// The expression node. #[no_eq] #[return_ref] - pub(crate) node: AstNodeRef, + pub(crate) node_ref: AstNodeRef, #[no_eq] count: countme::Count>, diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index f3e1afe98273e..96fe0fd56d9af 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -1,4 +1,5 @@ -//! Build a map from each use of a symbol to the definitions visible from that use. +//! Build a map from each use of a symbol to the definitions visible from that use, and the +//! type-narrowing constraints that apply to each definition. //! //! Let's take this code sample: //! @@ -6,7 +7,7 @@ //! x = 1 //! x = 2 //! y = x -//! if flag: +//! if y is not None: //! x = 3 //! else: //! x = 4 @@ -34,8 +35,8 @@ //! [`AstIds`](crate::semantic_index::ast_ids::AstIds) we number all uses (that means a `Name` node //! with `Load` context) so we have a `ScopedUseId` to efficiently represent each use. //! -//! The other case we need to handle is when a symbol is referenced from a different scope (the -//! most obvious example of this is an import). 
We call this "public" use of a symbol. So the other +//! Another case we need to handle is when a symbol is referenced from a different scope (the most +//! obvious example of this is an import). We call this "public" use of a symbol. So the other //! question we need to be able to answer is, what are the publicly-visible definitions of each //! symbol? //! @@ -53,42 +54,55 @@ //! start.) //! //! So this means that the publicly-visible definitions of a symbol are the definitions still -//! visible at the end of the scope. +//! visible at the end of the scope; effectively we have an implicit "use" of every symbol at the +//! end of the scope. //! -//! The data structure we build to answer these two questions is the `UseDefMap`. It has a +//! We also need to know, for a given definition of a symbol, what type-narrowing constraints apply +//! to it. For instance, in this code sample: +//! +//! ```python +//! x = 1 if flag else None +//! if x is not None: +//! y = x +//! ``` +//! +//! At the use of `x` in `y = x`, the visible definition of `x` is `1 if flag else None`, which +//! would infer as the type `Literal[1] | None`. But the constraint `x is not None` dominates this +//! use, which means we can rule out the possibility that `x` is `None` here, which should give us +//! the type `Literal[1]` for this use. +//! +//! The data structure we build to answer these questions is the `UseDefMap`. It has a //! `definitions_by_use` vector indexed by [`ScopedUseId`] and a `public_definitions` vector //! indexed by [`ScopedSymbolId`]. The values in each of these vectors are (in principle) a list of -//! visible definitions at that use, or at the end of the scope for that symbol. +//! visible definitions at that use, or at the end of the scope for that symbol, with a list of the +//! dominating constraints for each of those definitions. //! -//! In order to avoid vectors-of-vectors and all the allocations that would entail, we don't -//! 
actually store these "list of visible definitions" as a vector of [`Definition`] IDs. Instead, -//! the values in `definitions_by_use` and `public_definitions` are a [`Definitions`] struct that -//! keeps a [`Range`] into a third vector of [`Definition`] IDs, `all_definitions`. The trick with -//! this representation is that it requires that the definitions visible at any given use of a -//! symbol are stored sequentially in `all_definitions`. +//! In order to avoid vectors-of-vectors-of-vectors and all the allocations that would entail, we +//! don't actually store these "list of visible definitions" as a vector of [`Definition`]. +//! Instead, the values in `definitions_by_use` and `public_definitions` are a [`SymbolState`] +//! struct which uses bit-sets to track definitions and constraints in terms of +//! [`ScopedDefinitionId`] and [`ScopedConstraintId`], which are indices into the `all_definitions` +//! and `all_constraints` indexvecs in the [`UseDefMap`]. //! -//! There is another special kind of possible "definition" for a symbol: it might be unbound in the -//! scope. (This isn't equivalent to "zero visible definitions", since we may go through an `if` -//! that has a definition for the symbol, leaving us with one visible definition, but still also -//! the "unbound" possibility, since we might not have taken the `if` branch.) +//! There is another special kind of possible "definition" for a symbol: there might be a path from +//! the scope entry to a given use in which the symbol is never bound. //! //! The simplest way to model "unbound" would be as an actual [`Definition`] itself: the initial //! visible [`Definition`] for each symbol in a scope. But actually modeling it this way would -//! dramatically increase the number of [`Definition`] that Salsa must track. Since "unbound" is a +//! unnecessarily increase the number of [`Definition`] that Salsa must track. Since "unbound" is a //! 
special definition in that all symbols share it, and it doesn't have any additional per-symbol -//! state, we can represent it more efficiently: we use the `may_be_unbound` boolean on the -//! [`Definitions`] struct. If this flag is `true`, it means the symbol/use really has one -//! additional visible "definition", which is the unbound state. If this flag is `false`, it means -//! we've eliminated the possibility of unbound: every path we've followed includes a definition -//! for this symbol. +//! state, and constraints are irrelevant to it, we can represent it more efficiently: we use the +//! `may_be_unbound` boolean on the [`SymbolState`] struct. If this flag is `true`, it means the +//! symbol/use really has one additional visible "definition", which is the unbound state. If this +//! flag is `false`, it means we've eliminated the possibility of unbound: every path we've +//! followed includes a definition for this symbol. //! -//! To build a [`UseDefMap`], the [`UseDefMapBuilder`] is notified of each new use and definition -//! as they are encountered by the +//! To build a [`UseDefMap`], the [`UseDefMapBuilder`] is notified of each new use, definition, and +//! constraint as they are encountered by the //! [`SemanticIndexBuilder`](crate::semantic_index::builder::SemanticIndexBuilder) AST visit. For -//! each symbol, the builder tracks the currently-visible definitions for that symbol. When we hit -//! a use of a symbol, it records the currently-visible definitions for that symbol as the visible -//! definitions for that use. When we reach the end of the scope, it records the currently-visible -//! definitions for each symbol as the public definitions of that symbol. +//! each symbol, the builder tracks the `SymbolState` for that symbol. When we hit a use of a +//! symbol, it records the current state for that symbol for that use. When we reach the end of the +//! scope, it records the state for each symbol as the public definitions of that symbol. //! //! 
Let's walk through the above example. Initially we record for `x` that it has no visible //! definitions, and may be unbound. When we see `x = 1`, we record that as the sole visible @@ -98,10 +112,11 @@ //! //! Then we hit the `if` branch. We visit the `test` node (`flag` in this case), since that will //! happen regardless. Then we take a pre-branch snapshot of the currently visible definitions for -//! all symbols, which we'll need later. Then we go ahead and visit the `if` body. When we see `x = -//! 3`, it replaces `x = 2` as the sole visible definition of `x`. At the end of the `if` body, we -//! take another snapshot of the currently-visible definitions; we'll call this the post-if-body -//! snapshot. +//! all symbols, which we'll need later. Then we record `flag` as a possible constraint on the +//! currently visible definition (`x = 2`), and go ahead and visit the `if` body. When we see `x = +//! 3`, it replaces `x = 2` (constrained by `flag`) as the sole visible definition of `x`. At the +//! end of the `if` body, we take another snapshot of the currently-visible definitions; we'll call +//! this the post-if-body snapshot. //! //! Now we need to visit the `else` clause. The conditions when entering the `else` clause should //! be the pre-if conditions; if we are entering the `else` clause, we know that the `if` test @@ -125,98 +140,142 @@ //! (In the future we may have some other questions we want to answer as well, such as "is this //! definition used?", which will require tracking a bit more info in our map, e.g. a "used" bit //! for each [`Definition`] which is flipped to true when we record that definition for a use.) 
+use self::symbol_state::{ + ConstraintIdIterator, DefinitionIdWithConstraintsIterator, ScopedConstraintId, + ScopedDefinitionId, SymbolState, +}; use crate::semantic_index::ast_ids::ScopedUseId; use crate::semantic_index::definition::Definition; +use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::ScopedSymbolId; use ruff_index::IndexVec; -use std::ops::Range; -/// All definitions that can reach a given use of a name. +mod bitset; +mod symbol_state; + +/// Applicable definitions and constraints for every use of a name. #[derive(Debug, PartialEq, Eq)] pub(crate) struct UseDefMap<'db> { - // TODO store constraints with definitions for type narrowing - /// Definition IDs array for `definitions_by_use` and `public_definitions` to slice into. - all_definitions: Vec>, + /// Array of [`Definition`] in this scope. + all_definitions: IndexVec>, + + /// Array of constraints (as [`Expression`]) in this scope. + all_constraints: IndexVec>, - /// Definitions that can reach a [`ScopedUseId`]. - definitions_by_use: IndexVec, + /// [`SymbolState`] visible at a [`ScopedUseId`]. + definitions_by_use: IndexVec, - /// Definitions of each symbol visible at end of scope. - public_definitions: IndexVec, + /// [`SymbolState`] visible at end of scope for each symbol. 
+ public_definitions: IndexVec, } impl<'db> UseDefMap<'db> { - pub(crate) fn use_definitions(&self, use_id: ScopedUseId) -> &[Definition<'db>] { - &self.all_definitions[self.definitions_by_use[use_id].definitions_range.clone()] + pub(crate) fn use_definitions( + &self, + use_id: ScopedUseId, + ) -> DefinitionWithConstraintsIterator<'_, 'db> { + DefinitionWithConstraintsIterator { + all_definitions: &self.all_definitions, + all_constraints: &self.all_constraints, + inner: self.definitions_by_use[use_id].visible_definitions(), + } } pub(crate) fn use_may_be_unbound(&self, use_id: ScopedUseId) -> bool { - self.definitions_by_use[use_id].may_be_unbound + self.definitions_by_use[use_id].may_be_unbound() } - pub(crate) fn public_definitions(&self, symbol: ScopedSymbolId) -> &[Definition<'db>] { - &self.all_definitions[self.public_definitions[symbol].definitions_range.clone()] + pub(crate) fn public_definitions( + &self, + symbol: ScopedSymbolId, + ) -> DefinitionWithConstraintsIterator<'_, 'db> { + DefinitionWithConstraintsIterator { + all_definitions: &self.all_definitions, + all_constraints: &self.all_constraints, + inner: self.public_definitions[symbol].visible_definitions(), + } } pub(crate) fn public_may_be_unbound(&self, symbol: ScopedSymbolId) -> bool { - self.public_definitions[symbol].may_be_unbound + self.public_definitions[symbol].may_be_unbound() } } -/// Definitions visible for a symbol at a particular use (or end-of-scope). -#[derive(Clone, Debug, PartialEq, Eq)] -struct Definitions { - /// [`Range`] in `all_definitions` of the visible definition IDs. - definitions_range: Range, - /// Is the symbol possibly unbound at this point? 
- may_be_unbound: bool, +#[derive(Debug)] +pub(crate) struct DefinitionWithConstraintsIterator<'map, 'db> { + all_definitions: &'map IndexVec>, + all_constraints: &'map IndexVec>, + inner: DefinitionIdWithConstraintsIterator<'map>, } -impl Definitions { - /// The default state of a symbol is "no definitions, may be unbound", aka definitely-unbound. - fn unbound() -> Self { - Self { - definitions_range: Range::default(), - may_be_unbound: true, - } +impl<'map, 'db> Iterator for DefinitionWithConstraintsIterator<'map, 'db> { + type Item = DefinitionWithConstraints<'map, 'db>; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|def_id_with_constraints| DefinitionWithConstraints { + definition: self.all_definitions[def_id_with_constraints.definition], + constraints: ConstraintsIterator { + all_constraints: self.all_constraints, + constraint_ids: def_id_with_constraints.constraint_ids, + }, + }) } } -impl Default for Definitions { - fn default() -> Self { - Definitions::unbound() +impl std::iter::FusedIterator for DefinitionWithConstraintsIterator<'_, '_> {} + +pub(crate) struct DefinitionWithConstraints<'map, 'db> { + pub(crate) definition: Definition<'db>, + pub(crate) constraints: ConstraintsIterator<'map, 'db>, +} + +pub(crate) struct ConstraintsIterator<'map, 'db> { + all_constraints: &'map IndexVec>, + constraint_ids: ConstraintIdIterator<'map>, +} + +impl<'map, 'db> Iterator for ConstraintsIterator<'map, 'db> { + type Item = Expression<'db>; + + fn next(&mut self) -> Option { + self.constraint_ids + .next() + .map(|constraint_id| self.all_constraints[constraint_id]) } } -/// A snapshot of the visible definitions for each symbol at a particular point in control flow. +impl std::iter::FusedIterator for ConstraintsIterator<'_, '_> {} + +/// A snapshot of the definitions and constraints state at a particular point in control flow. 
#[derive(Clone, Debug)] pub(super) struct FlowSnapshot { - definitions_by_symbol: IndexVec, + definitions_by_symbol: IndexVec, } -#[derive(Debug)] +#[derive(Debug, Default)] pub(super) struct UseDefMapBuilder<'db> { - /// Definition IDs array for `definitions_by_use` and `definitions_by_symbol` to slice into. - all_definitions: Vec>, + /// Append-only array of [`Definition`]. + all_definitions: IndexVec>, + + /// Append-only array of constraints (as [`Expression`]). + all_constraints: IndexVec>, /// Visible definitions at each so-far-recorded use. - definitions_by_use: IndexVec, + definitions_by_use: IndexVec, /// Currently visible definitions for each symbol. - definitions_by_symbol: IndexVec, + definitions_by_symbol: IndexVec, } impl<'db> UseDefMapBuilder<'db> { pub(super) fn new() -> Self { - Self { - all_definitions: Vec::new(), - definitions_by_use: IndexVec::new(), - definitions_by_symbol: IndexVec::new(), - } + Self::default() } pub(super) fn add_symbol(&mut self, symbol: ScopedSymbolId) { - let new_symbol = self.definitions_by_symbol.push(Definitions::unbound()); + let new_symbol = self.definitions_by_symbol.push(SymbolState::unbound()); debug_assert_eq!(symbol, new_symbol); } @@ -227,13 +286,15 @@ impl<'db> UseDefMapBuilder<'db> { ) { // We have a new definition of a symbol; this replaces any previous definitions in this // path.
- let def_idx = self.all_definitions.len(); - self.all_definitions.push(definition); - self.definitions_by_symbol[symbol] = Definitions { - #[allow(clippy::range_plus_one)] - definitions_range: def_idx..(def_idx + 1), - may_be_unbound: false, - }; + let def_id = self.all_definitions.push(definition); + self.definitions_by_symbol[symbol] = SymbolState::with(def_id); + } + + pub(super) fn record_constraint(&mut self, constraint: Expression<'db>) { + let constraint_id = self.all_constraints.push(constraint); + for definitions in &mut self.definitions_by_symbol { + definitions.add_constraint(constraint_id); + } } pub(super) fn record_use(&mut self, symbol: ScopedSymbolId, use_id: ScopedUseId) { @@ -265,15 +326,15 @@ impl<'db> UseDefMapBuilder<'db> { // If the snapshot we are restoring is missing some symbols we've recorded since, we need // to fill them in so the symbol IDs continue to line up. Since they don't exist in the - // snapshot, the correct state to fill them in with is "unbound", the default. + // snapshot, the correct state to fill them in with is "unbound". self.definitions_by_symbol - .resize(num_symbols, Definitions::unbound()); + .resize(num_symbols, SymbolState::unbound()); } /// Merge the given snapshot into the current state, reflecting that we might have taken either /// path to get here. The new visible-definitions state for each symbol should include /// definitions from both the prior state and the snapshot. - pub(super) fn merge(&mut self, snapshot: &FlowSnapshot) { + pub(super) fn merge(&mut self, snapshot: FlowSnapshot) { // The tricky thing about merging two Ranges pointing into `all_definitions` is that if the // two Ranges aren't already adjacent in `all_definitions`, we will have to copy at least // one or the other of the ranges to the end of `all_definitions` so as to make them @@ -287,66 +348,26 @@ impl<'db> UseDefMapBuilder<'db> { // greater than the number of known symbols in a previously-taken snapshot. 
debug_assert!(self.definitions_by_symbol.len() >= snapshot.definitions_by_symbol.len()); - for (symbol_id, current) in self.definitions_by_symbol.iter_mut_enumerated() { - let Some(snapshot) = snapshot.definitions_by_symbol.get(symbol_id) else { - // Symbol not present in snapshot, so it's unbound from that path. - current.may_be_unbound = true; - continue; - }; - - // If the symbol can be unbound in either predecessor, it can be unbound post-merge. - current.may_be_unbound |= snapshot.may_be_unbound; - - // Merge the definition ranges. - let current = &mut current.definitions_range; - let snapshot = &snapshot.definitions_range; - - // We never create reversed ranges. - debug_assert!(current.end >= current.start); - debug_assert!(snapshot.end >= snapshot.start); - - if current == snapshot { - // Ranges already identical, nothing to do. - } else if snapshot.is_empty() { - // Merging from an empty range; nothing to do. - } else if (*current).is_empty() { - // Merging to an empty range; just use the incoming range. - *current = snapshot.clone(); - } else if snapshot.end >= current.start && snapshot.start <= current.end { - // Ranges are adjacent or overlapping, merge them in-place. - *current = current.start.min(snapshot.start)..current.end.max(snapshot.end); - } else if current.end == self.all_definitions.len() { - // Ranges are not adjacent or overlapping, `current` is at the end of - // `all_definitions`, we need to copy `snapshot` to the end so they are adjacent - // and can be merged into one range. - self.all_definitions.extend_from_within(snapshot.clone()); - current.end = self.all_definitions.len(); - } else if snapshot.end == self.all_definitions.len() { - // Ranges are not adjacent or overlapping, `snapshot` is at the end of - // `all_definitions`, we need to copy `current` to the end so they are adjacent and - // can be merged into one range. 
- self.all_definitions.extend_from_within(current.clone()); - current.start = snapshot.start; - current.end = self.all_definitions.len(); + let mut snapshot_definitions_iter = snapshot.definitions_by_symbol.into_iter(); + for current in &mut self.definitions_by_symbol { + if let Some(snapshot) = snapshot_definitions_iter.next() { + current.merge(snapshot); } else { - // Ranges are not adjacent and neither one is at the end of `all_definitions`, we - // have to copy both to the end so they are adjacent and we can merge them. - let start = self.all_definitions.len(); - self.all_definitions.extend_from_within(current.clone()); - self.all_definitions.extend_from_within(snapshot.clone()); - current.start = start; - current.end = self.all_definitions.len(); + // Symbol not present in snapshot, so it's unbound from that path. + current.add_unbound(); } } } pub(super) fn finish(mut self) -> UseDefMap<'db> { self.all_definitions.shrink_to_fit(); + self.all_constraints.shrink_to_fit(); self.definitions_by_symbol.shrink_to_fit(); self.definitions_by_use.shrink_to_fit(); UseDefMap { all_definitions: self.all_definitions, + all_constraints: self.all_constraints, definitions_by_use: self.definitions_by_use, public_definitions: self.definitions_by_symbol, } diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs new file mode 100644 index 0000000000000..ac8ce65398e1b --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs @@ -0,0 +1,228 @@ +/// Ordered set of `u32`. +/// +/// Uses an inline bit-set for small values (up to 64 * B), falls back to heap allocated vector of +/// blocks for larger values. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) enum BitSet { + /// Bit-set (in 64-bit blocks) for the first 64 * B entries. + Inline([u64; B]), + + /// Overflow beyond 64 * B. 
+ Heap(Vec), +} + +impl Default for BitSet { + fn default() -> Self { + // B * 64 must fit in a u32, or else we have unusable bits; this assertion makes the + // truncating casts to u32 below safe. This would be better as a const assertion, but + // that's not possible on stable with const generic params. (B should never really be + // anywhere close to this large.) + assert!(B * 64 < (u32::MAX as usize)); + // This implementation requires usize >= 32 bits. + static_assertions::const_assert!(usize::BITS >= 32); + Self::Inline([0; B]) + } +} + +impl BitSet { + /// Create and return a new [`BitSet`] with a single `value` inserted. + pub(super) fn with(value: u32) -> Self { + let mut bitset = Self::default(); + bitset.insert(value); + bitset + } + + /// Convert from Inline to Heap, if needed, and resize the Heap vector, if needed. + fn resize(&mut self, value: u32) { + let num_blocks_needed = (value / 64) + 1; + match self { + Self::Inline(blocks) => { + let mut vec = blocks.to_vec(); + vec.resize(num_blocks_needed as usize, 0); + *self = Self::Heap(vec); + } + Self::Heap(vec) => { + vec.resize(num_blocks_needed as usize, 0); + } + } + } + + fn blocks_mut(&mut self) -> &mut [u64] { + match self { + Self::Inline(blocks) => blocks.as_mut_slice(), + Self::Heap(blocks) => blocks.as_mut_slice(), + } + } + + fn blocks(&self) -> &[u64] { + match self { + Self::Inline(blocks) => blocks.as_slice(), + Self::Heap(blocks) => blocks.as_slice(), + } + } + + /// Insert a value into the [`BitSet`]. + /// + /// Return true if the value was newly inserted, false if already present. + pub(super) fn insert(&mut self, value: u32) -> bool { + let value_usize = value as usize; + let (block, index) = (value_usize / 64, value_usize % 64); + if block >= self.blocks().len() { + self.resize(value); + } + let blocks = self.blocks_mut(); + let missing = blocks[block] & (1 << index) == 0; + blocks[block] |= 1 << index; + missing + } + + /// Intersect in-place with another [`BitSet`]. 
+ pub(super) fn intersect(&mut self, other: &BitSet) { + let my_blocks = self.blocks_mut(); + let other_blocks = other.blocks(); + let min_len = my_blocks.len().min(other_blocks.len()); + for i in 0..min_len { + my_blocks[i] &= other_blocks[i]; + } + for block in my_blocks.iter_mut().skip(min_len) { + *block = 0; + } + } + + /// Return an iterator over the values (in ascending order) in this [`BitSet`]. + pub(super) fn iter(&self) -> BitSetIterator<'_, B> { + let blocks = self.blocks(); + BitSetIterator { + blocks, + current_block_index: 0, + current_block: blocks[0], + } + } +} + +/// Iterator over values in a [`BitSet`]. +#[derive(Debug)] +pub(super) struct BitSetIterator<'a, const B: usize> { + /// The blocks we are iterating over. + blocks: &'a [u64], + + /// The index of the block we are currently iterating through. + current_block_index: usize, + + /// The block we are currently iterating through (and zeroing as we go.) + current_block: u64, +} + +impl Iterator for BitSetIterator<'_, B> { + type Item = u32; + + fn next(&mut self) -> Option { + while self.current_block == 0 { + if self.current_block_index + 1 >= self.blocks.len() { + return None; + } + self.current_block_index += 1; + self.current_block = self.blocks[self.current_block_index]; + } + let lowest_bit_set = self.current_block.trailing_zeros(); + // reset the lowest set bit, without a data dependency on `lowest_bit_set` + self.current_block &= self.current_block.wrapping_sub(1); + // SAFETY: `lowest_bit_set` cannot be more than 64, `current_block_index` cannot be more + // than `B - 1`, and we check above that `B * 64 < u32::MAX`. So both `64 * + // current_block_index` and the final value here must fit in u32. 
+ #[allow(clippy::cast_possible_truncation)] + Some(lowest_bit_set + (64 * self.current_block_index) as u32) + } +} + +impl std::iter::FusedIterator for BitSetIterator<'_, B> {} + +#[cfg(test)] +mod tests { + use super::BitSet; + + fn assert_bitset(bitset: &BitSet, contents: &[u32]) { + assert_eq!(bitset.iter().collect::>(), contents); + } + + #[test] + fn iter() { + let mut b = BitSet::<1>::with(3); + b.insert(27); + b.insert(6); + assert!(matches!(b, BitSet::Inline(_))); + assert_bitset(&b, &[3, 6, 27]); + } + + #[test] + fn iter_overflow() { + let mut b = BitSet::<1>::with(140); + b.insert(100); + b.insert(129); + assert!(matches!(b, BitSet::Heap(_))); + assert_bitset(&b, &[100, 129, 140]); + } + + #[test] + fn intersect() { + let mut b1 = BitSet::<1>::with(4); + let mut b2 = BitSet::<1>::with(4); + b1.insert(23); + b2.insert(5); + + b1.intersect(&b2); + assert_bitset(&b1, &[4]); + } + + #[test] + fn intersect_mixed_1() { + let mut b1 = BitSet::<1>::with(4); + let mut b2 = BitSet::<1>::with(4); + b1.insert(89); + b2.insert(5); + + b1.intersect(&b2); + assert_bitset(&b1, &[4]); + } + + #[test] + fn intersect_mixed_2() { + let mut b1 = BitSet::<1>::with(4); + let mut b2 = BitSet::<1>::with(4); + b1.insert(23); + b2.insert(89); + + b1.intersect(&b2); + assert_bitset(&b1, &[4]); + } + + #[test] + fn intersect_heap() { + let mut b1 = BitSet::<1>::with(4); + let mut b2 = BitSet::<1>::with(4); + b1.insert(89); + b2.insert(90); + + b1.intersect(&b2); + assert_bitset(&b1, &[4]); + } + + #[test] + fn intersect_heap_2() { + let mut b1 = BitSet::<1>::with(89); + let mut b2 = BitSet::<1>::with(89); + b1.insert(91); + b2.insert(90); + + b1.intersect(&b2); + assert_bitset(&b1, &[89]); + } + + #[test] + fn multiple_blocks() { + let mut b = BitSet::<2>::with(120); + b.insert(45); + assert!(matches!(b, BitSet::Inline(_))); + assert_bitset(&b, &[45, 120]); + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs 
b/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs new file mode 100644 index 0000000000000..c465bbe320b1f --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs @@ -0,0 +1,374 @@ +//! Track visible definitions of a symbol, and applicable constraints per definition. +//! +//! These data structures operate entirely on scope-local newtype-indices for definitions and +//! constraints, referring to their location in the `all_definitions` and `all_constraints` +//! indexvecs in [`super::UseDefMapBuilder`]. +//! +//! We need to track arbitrary associations between definitions and constraints, not just a single +//! set of currently dominating constraints (where "dominating" means "control flow must have +//! passed through it to reach this point"), because we can have dominating constraints that apply +//! to some definitions but not others, as in this code: +//! +//! ```python +//! x = 1 if flag else None +//! if x is not None: +//! if flag2: +//! x = 2 if flag else None +//! x +//! ``` +//! +//! The `x is not None` constraint dominates the final use of `x`, but it applies only to the first +//! definition of `x`, not the second, so `None` is a possible value for `x`. +//! +//! And we can't just track, for each definition, an index into a list of dominating constraints, +//! either, because we can have definitions which are still visible, but subject to constraints +//! that are no longer dominating, as in this code: +//! +//! ```python +//! x = 0 +//! if flag1: +//! x = 1 if flag2 else None +//! assert x is not None +//! x +//! ``` +//! +//! From the point of view of the final use of `x`, the `x is not None` constraint no longer +//! dominates, but it does dominate the `x = 1 if flag2 else None` definition, so we have to keep +//! track of that. +//! +//! The data structures used here ([`BitSet`] and [`smallvec::SmallVec`]) optimize for keeping all +//! 
data inline (avoiding lots of scattered allocations) in small-to-medium cases, and falling back +//! to heap allocation to be able to scale to arbitrary numbers of definitions and constraints when +//! needed. +use super::bitset::{BitSet, BitSetIterator}; +use ruff_index::newtype_index; +use smallvec::SmallVec; + +/// A newtype-index for a definition in a particular scope. +#[newtype_index] +pub(super) struct ScopedDefinitionId; + +/// A newtype-index for a constraint expression in a particular scope. +#[newtype_index] +pub(super) struct ScopedConstraintId; + +/// Can reference this * 64 total definitions inline; more will fall back to the heap. +const INLINE_DEFINITION_BLOCKS: usize = 3; + +/// A [`BitSet`] of [`ScopedDefinitionId`], representing visible definitions of a symbol in a scope. +type Definitions = BitSet; +type DefinitionsIterator<'a> = BitSetIterator<'a, INLINE_DEFINITION_BLOCKS>; + +/// Can reference this * 64 total constraints inline; more will fall back to the heap. +const INLINE_CONSTRAINT_BLOCKS: usize = 2; + +/// Can keep inline this many visible definitions per symbol at a given time; more will go to heap. +const INLINE_VISIBLE_DEFINITIONS_PER_SYMBOL: usize = 4; + +/// One [`BitSet`] of applicable [`ScopedConstraintId`] per visible definition. +type InlineConstraintArray = + [BitSet; INLINE_VISIBLE_DEFINITIONS_PER_SYMBOL]; +type Constraints = SmallVec; +type ConstraintsIterator<'a> = std::slice::Iter<'a, BitSet>; +type ConstraintsIntoIterator = smallvec::IntoIter; + +/// Visible definitions and narrowing constraints for a single symbol at some point in control flow. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(super) struct SymbolState { + /// [`BitSet`]: which [`ScopedDefinitionId`] are visible for this symbol? + visible_definitions: Definitions, + + /// For each definition, which [`ScopedConstraintId`] apply? 
+ /// + /// This is a [`smallvec::SmallVec`] which should always have one [`BitSet`] of constraints per + /// definition in `visible_definitions`. + constraints: Constraints, + + /// Could the symbol be unbound at this point? + may_be_unbound: bool, +} + +/// A single [`ScopedDefinitionId`] with an iterator of its applicable [`ScopedConstraintId`]. +#[derive(Debug)] +pub(super) struct DefinitionIdWithConstraints<'a> { + pub(super) definition: ScopedDefinitionId, + pub(super) constraint_ids: ConstraintIdIterator<'a>, +} + +impl SymbolState { + /// Return a new [`SymbolState`] representing an unbound symbol. + pub(super) fn unbound() -> Self { + Self { + visible_definitions: Definitions::default(), + constraints: Constraints::default(), + may_be_unbound: true, + } + } + + /// Return a new [`SymbolState`] representing a symbol with a single visible definition. + pub(super) fn with(definition_id: ScopedDefinitionId) -> Self { + let mut constraints = Constraints::with_capacity(1); + constraints.push(BitSet::default()); + Self { + visible_definitions: Definitions::with(definition_id.into()), + constraints, + may_be_unbound: false, + } + } + + /// Add Unbound as a possibility for this symbol. + pub(super) fn add_unbound(&mut self) { + self.may_be_unbound = true; + } + + /// Add given constraint to all currently-visible definitions. + pub(super) fn add_constraint(&mut self, constraint_id: ScopedConstraintId) { + for bitset in &mut self.constraints { + bitset.insert(constraint_id.into()); + } + } + + /// Merge another [`SymbolState`] into this one. 
+ pub(super) fn merge(&mut self, b: SymbolState) { + let mut a = Self { + visible_definitions: Definitions::default(), + constraints: Constraints::default(), + may_be_unbound: self.may_be_unbound || b.may_be_unbound, + }; + std::mem::swap(&mut a, self); + let mut a_defs_iter = a.visible_definitions.iter(); + let mut b_defs_iter = b.visible_definitions.iter(); + let mut a_constraints_iter = a.constraints.into_iter(); + let mut b_constraints_iter = b.constraints.into_iter(); + + let mut opt_a_def: Option = a_defs_iter.next(); + let mut opt_b_def: Option = b_defs_iter.next(); + + // Iterate through the definitions from `a` and `b`, always processing the lower definition + // ID first, and pushing each definition onto the merged `SymbolState` with its + // constraints. If a definition is found in both `a` and `b`, push it with the intersection + // of the constraints from the two paths; a constraint that applies from only one possible + // path is irrelevant. + + // Helper to push `def`, with constraints in `constraints_iter`, onto `self`. + let push = |def, constraints_iter: &mut ConstraintsIntoIterator, merged: &mut Self| { + merged.visible_definitions.insert(def); + // SAFETY: we only ever create SymbolState with either no definitions and no constraint + // bitsets (`::unbound`) or one definition and one constraint bitset (`::with`), and + // `::merge` always pushes one definition and one constraint bitset together (just + // below), so the number of definitions and the number of constraint bitsets can never + // get out of sync. + let constraints = constraints_iter + .next() + .expect("definitions and constraints length mismatch"); + merged.constraints.push(constraints); + }; + + loop { + match (opt_a_def, opt_b_def) { + (Some(a_def), Some(b_def)) => match a_def.cmp(&b_def) { + std::cmp::Ordering::Less => { + // Next definition ID is only in `a`, push it to `self` and advance `a`. 
+ push(a_def, &mut a_constraints_iter, self); + opt_a_def = a_defs_iter.next(); + } + std::cmp::Ordering::Greater => { + // Next definition ID is only in `b`, push it to `self` and advance `b`. + push(b_def, &mut b_constraints_iter, self); + opt_b_def = b_defs_iter.next(); + } + std::cmp::Ordering::Equal => { + // Next definition is in both; push to `self` and intersect constraints. + push(a_def, &mut b_constraints_iter, self); + // SAFETY: we only ever create SymbolState with either no definitions and + // no constraint bitsets (`::unbound`) or one definition and one constraint + // bitset (`::with`), and `::merge` always pushes one definition and one + // constraint bitset together (just below), so the number of definitions + // and the number of constraint bitsets can never get out of sync. + let a_constraints = a_constraints_iter + .next() + .expect("definitions and constraints length mismatch"); + // If the same definition is visible through both paths, any constraint + // that applies on only one path is irrelevant to the resulting type from + // unioning the two paths, so we intersect the constraints. + self.constraints + .last_mut() + .unwrap() + .intersect(&a_constraints); + opt_a_def = a_defs_iter.next(); + opt_b_def = b_defs_iter.next(); + } + }, + (Some(a_def), None) => { + // We've exhausted `b`, just push the def from `a` and move on to the next. + push(a_def, &mut a_constraints_iter, self); + opt_a_def = a_defs_iter.next(); + } + (None, Some(b_def)) => { + // We've exhausted `a`, just push the def from `b` and move on to the next. + push(b_def, &mut b_constraints_iter, self); + opt_b_def = b_defs_iter.next(); + } + (None, None) => break, + } + } + } + + /// Get iterator over visible definitions with constraints. 
+ pub(super) fn visible_definitions(&self) -> DefinitionIdWithConstraintsIterator { + DefinitionIdWithConstraintsIterator { + definitions: self.visible_definitions.iter(), + constraints: self.constraints.iter(), + } + } + + /// Could the symbol be unbound? + pub(super) fn may_be_unbound(&self) -> bool { + self.may_be_unbound + } +} + +/// The default state of a symbol (if we've seen no definitions of it) is unbound. +impl Default for SymbolState { + fn default() -> Self { + SymbolState::unbound() + } +} + +#[derive(Debug)] +pub(super) struct DefinitionIdWithConstraintsIterator<'a> { + definitions: DefinitionsIterator<'a>, + constraints: ConstraintsIterator<'a>, +} + +impl<'a> Iterator for DefinitionIdWithConstraintsIterator<'a> { + type Item = DefinitionIdWithConstraints<'a>; + + fn next(&mut self) -> Option { + match (self.definitions.next(), self.constraints.next()) { + (None, None) => None, + (Some(def), Some(constraints)) => Some(DefinitionIdWithConstraints { + definition: ScopedDefinitionId::from_u32(def), + constraint_ids: ConstraintIdIterator { + wrapped: constraints.iter(), + }, + }), + // SAFETY: see above. 
+ _ => unreachable!("definitions and constraints length mismatch"), + } + } +} + +impl std::iter::FusedIterator for DefinitionIdWithConstraintsIterator<'_> {} + +#[derive(Debug)] +pub(super) struct ConstraintIdIterator<'a> { + wrapped: BitSetIterator<'a, INLINE_CONSTRAINT_BLOCKS>, +} + +impl Iterator for ConstraintIdIterator<'_> { + type Item = ScopedConstraintId; + + fn next(&mut self) -> Option { + self.wrapped.next().map(ScopedConstraintId::from_u32) + } +} + +impl std::iter::FusedIterator for ConstraintIdIterator<'_> {} + +#[cfg(test)] +mod tests { + use super::{ScopedConstraintId, ScopedDefinitionId, SymbolState}; + + impl SymbolState { + pub(crate) fn assert(&self, may_be_unbound: bool, expected: &[&str]) { + assert_eq!(self.may_be_unbound(), may_be_unbound); + let actual = self + .visible_definitions() + .map(|def_id_with_constraints| { + format!( + "{}<{}>", + def_id_with_constraints.definition.as_u32(), + def_id_with_constraints + .constraint_ids + .map(ScopedConstraintId::as_u32) + .map(|idx| idx.to_string()) + .collect::>() + .join(", ") + ) + }) + .collect::>(); + assert_eq!(actual, expected); + } + } + + #[test] + fn unbound() { + let cd = SymbolState::unbound(); + + cd.assert(true, &[]); + } + + #[test] + fn with() { + let cd = SymbolState::with(ScopedDefinitionId::from_u32(0)); + + cd.assert(false, &["0<>"]); + } + + #[test] + fn add_unbound() { + let mut cd = SymbolState::with(ScopedDefinitionId::from_u32(0)); + cd.add_unbound(); + + cd.assert(true, &["0<>"]); + } + + #[test] + fn add_constraint() { + let mut cd = SymbolState::with(ScopedDefinitionId::from_u32(0)); + cd.add_constraint(ScopedConstraintId::from_u32(0)); + + cd.assert(false, &["0<0>"]); + } + + #[test] + fn merge() { + // merging the same definition with the same constraint keeps the constraint + let mut cd0a = SymbolState::with(ScopedDefinitionId::from_u32(0)); + cd0a.add_constraint(ScopedConstraintId::from_u32(0)); + + let mut cd0b = 
SymbolState::with(ScopedDefinitionId::from_u32(0)); + cd0b.add_constraint(ScopedConstraintId::from_u32(0)); + + cd0a.merge(cd0b); + let mut cd0 = cd0a; + cd0.assert(false, &["0<0>"]); + + // merging the same definition with differing constraints drops all constraints + let mut cd1a = SymbolState::with(ScopedDefinitionId::from_u32(1)); + cd1a.add_constraint(ScopedConstraintId::from_u32(1)); + + let mut cd1b = SymbolState::with(ScopedDefinitionId::from_u32(1)); + cd1b.add_constraint(ScopedConstraintId::from_u32(2)); + + cd1a.merge(cd1b); + let cd1 = cd1a; + cd1.assert(false, &["1<>"]); + + // merging a constrained definition with unbound keeps both + let mut cd2a = SymbolState::with(ScopedDefinitionId::from_u32(2)); + cd2a.add_constraint(ScopedConstraintId::from_u32(3)); + + let cd2b = SymbolState::unbound(); + + cd2a.merge(cd2b); + let cd2 = cd2a; + cd2.assert(true, &["2<3>"]); + + // merging different definitions keeps them each with their existing constraints + cd0.merge(cd2); + let cd = cd0; + cd.assert(true, &["0<0>", "2<3>"]); + } +} diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index b59d7a7f2513d..bf6230d50fcb0 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -4,15 +4,22 @@ use ruff_python_ast::name::Name; use crate::builtins::builtins_scope; use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId}; -use crate::semantic_index::{global_scope, symbol_table, use_def_map}; +use crate::semantic_index::{ + global_scope, symbol_table, use_def_map, DefinitionWithConstraints, + DefinitionWithConstraintsIterator, +}; +use crate::types::narrow::narrowing_constraint; use crate::{Db, FxOrderSet}; mod builder; mod display; mod infer; +mod narrow; -pub(crate) use self::builder::UnionBuilder; -pub(crate) use self::infer::{infer_definition_types, infer_scope_types}; +pub(crate) use 
self::builder::{IntersectionBuilder, UnionBuilder}; +pub(crate) use self::infer::{ + infer_definition_types, infer_expression_types, infer_scope_types, TypeInference, +}; /// Infer the public type of a symbol (its type as seen from outside its scope). pub(crate) fn symbol_ty<'db>( @@ -82,10 +89,31 @@ pub(crate) fn definition_ty<'db>(db: &'db dyn Db, definition: Definition<'db>) - /// provide an `unbound_ty`. pub(crate) fn definitions_ty<'db>( db: &'db dyn Db, - definitions: &[Definition<'db>], + definitions_with_constraints: DefinitionWithConstraintsIterator<'_, 'db>, unbound_ty: Option>, ) -> Type<'db> { - let def_types = definitions.iter().map(|def| definition_ty(db, *def)); + let def_types = definitions_with_constraints.map( + |DefinitionWithConstraints { + definition, + constraints, + }| { + let mut constraint_tys = + constraints.filter_map(|test| narrowing_constraint(db, test, definition)); + let definition_ty = definition_ty(db, definition); + if let Some(first_constraint_ty) = constraint_tys.next() { + let mut builder = IntersectionBuilder::new(db); + builder = builder + .add_positive(definition_ty) + .add_positive(first_constraint_ty); + for constraint_ty in constraint_tys { + builder = builder.add_positive(constraint_ty); + } + builder.build() + } else { + definition_ty + } + }, + ); let mut all_types = unbound_ty.into_iter().chain(def_types); let Some(first) = all_types.next() else { diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 8581ff546434d..e08a9d7e2d103 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -65,7 +65,6 @@ impl<'db> UnionBuilder<'db> { } } -#[allow(unused)] #[derive(Clone)] pub(crate) struct IntersectionBuilder<'db> { // Really this builds a union-of-intersections, because we always keep our set-theoretic types @@ -78,8 +77,7 @@ pub(crate) struct IntersectionBuilder<'db> { } impl<'db> 
IntersectionBuilder<'db> { - #[allow(dead_code)] - fn new(db: &'db dyn Db) -> Self { + pub(crate) fn new(db: &'db dyn Db) -> Self { Self { db, intersections: vec![InnerIntersectionBuilder::new()], @@ -93,8 +91,7 @@ impl<'db> IntersectionBuilder<'db> { } } - #[allow(dead_code)] - fn add_positive(mut self, ty: Type<'db>) -> Self { + pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self { if let Type::Union(union) = ty { // Distribute ourself over this union: for each union element, clone ourself and // intersect with that union element, then create a new union-of-intersections with all @@ -122,8 +119,7 @@ impl<'db> IntersectionBuilder<'db> { } } - #[allow(dead_code)] - fn add_negative(mut self, ty: Type<'db>) -> Self { + pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self { // See comments above in `add_positive`; this is just the negated version. if let Type::Union(union) = ty { union @@ -142,8 +138,7 @@ impl<'db> IntersectionBuilder<'db> { } } - #[allow(dead_code)] - fn build(mut self) -> Type<'db> { + pub(crate) fn build(mut self) -> Type<'db> { // Avoid allocating the UnionBuilder unnecessarily if we have just one intersection: if self.intersections.len() == 1 { self.intersections.pop().unwrap().build(self.db) @@ -157,7 +152,6 @@ impl<'db> IntersectionBuilder<'db> { } } -#[allow(unused)] #[derive(Debug, Clone, Default)] struct InnerIntersectionBuilder<'db> { positive: FxOrderSet>, @@ -223,6 +217,16 @@ impl<'db> InnerIntersectionBuilder<'db> { self.positive.retain(Type::is_unbound); self.negative.clear(); } + + // None intersects only with object + for pos in &self.positive { + if let Type::Instance(_) = pos { + // could be `object` type + } else { + self.negative.remove(&Type::None); + break; + } + } } fn build(mut self, db: &'db dyn Db) -> Type<'db> { @@ -453,4 +457,15 @@ mod tests { assert_eq!(ty, Type::IntLiteral(1)); } + + #[test] + fn build_intersection_simplify_negative_none() { + let db = setup_db(); + let ty = IntersectionBuilder::new(&db) + 
.add_negative(Type::None) + .add_positive(Type::IntLiteral(1)) + .build(); + + assert_eq!(ty, Type::IntLiteral(1)); + } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 8156dc6e73e71..290e063b06151 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -319,7 +319,7 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_region_expression(&mut self, expression: Expression<'db>) { - self.infer_expression(expression.node(self.db)); + self.infer_expression(expression.node_ref(self.db)); } fn infer_module(&mut self, module: &ast::ModModule) { @@ -2587,6 +2587,26 @@ mod tests { Ok(()) } + #[test] + fn narrow_not_none() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + x = None if flag else 1 + y = 0 + if x is not None: + y = x + ", + )?; + + assert_public_ty(&db, "/src/a.py", "x", "Literal[1] | None"); + assert_public_ty(&db, "/src/a.py", "y", "Literal[0, 1]"); + + Ok(()) + } + #[test] fn while_loop() -> anyhow::Result<()> { let mut db = setup_db(); @@ -2684,10 +2704,11 @@ mod tests { fn first_public_def<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> { let scope = global_scope(db, file); - *use_def_map(db, scope) + use_def_map(db, scope) .public_definitions(symbol_table(db, scope).symbol_id_by_name(name).unwrap()) - .first() + .next() .unwrap() + .definition } #[test] diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs new file mode 100644 index 0000000000000..381c6effa7171 --- /dev/null +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -0,0 +1,115 @@ +use crate::semantic_index::ast_ids::HasScopedAstId; +use crate::semantic_index::definition::Definition; +use crate::semantic_index::expression::Expression; +use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; +use 
crate::semantic_index::symbol_table; +use crate::types::{infer_expression_types, IntersectionBuilder, Type, TypeInference}; +use crate::Db; +use ruff_python_ast as ast; +use rustc_hash::FxHashMap; +use std::sync::Arc; + +/// Return the type constraint that `test` (if true) would place on `definition`, if any. +/// +/// For example, if we have this code: +/// +/// ```python +/// y = 1 if flag else None +/// x = 1 if flag else None +/// if x is not None: +/// ... +/// ``` +/// +/// The `test` expression `x is not None` places the constraint "not None" on the definition of +/// `x`, so in that case we'd return `Some(Type::Intersection(negative=[Type::None]))`. +/// +/// But if we called this with the same `test` expression, but the `definition` of `y`, no +/// constraint is applied to that definition, so we'd just return `None`. +pub(crate) fn narrowing_constraint<'db>( + db: &'db dyn Db, + test: Expression<'db>, + definition: Definition<'db>, +) -> Option> { + all_narrowing_constraints(db, test) + .get(&definition.symbol(db)) + .copied() +} + +#[salsa::tracked(return_ref)] +fn all_narrowing_constraints<'db>( + db: &'db dyn Db, + test: Expression<'db>, +) -> NarrowingConstraints<'db> { + NarrowingConstraintsBuilder::new(db, test).finish() +} + +type NarrowingConstraints<'db> = FxHashMap>; + +struct NarrowingConstraintsBuilder<'db> { + db: &'db dyn Db, + expression: Expression<'db>, + constraints: NarrowingConstraints<'db>, +} + +impl<'db> NarrowingConstraintsBuilder<'db> { + fn new(db: &'db dyn Db, expression: Expression<'db>) -> Self { + Self { + db, + expression, + constraints: NarrowingConstraints::default(), + } + } + + fn finish(mut self) -> NarrowingConstraints<'db> { + if let ast::Expr::Compare(expr_compare) = self.expression.node_ref(self.db).node() { + self.add_expr_compare(expr_compare); + } + // TODO other test expression kinds + + self.constraints.shrink_to_fit(); + self.constraints + } + + fn symbols(&self) -> Arc { + symbol_table(self.db, self.scope()) + 
} + + fn scope(&self) -> ScopeId<'db> { + self.expression.scope(self.db) + } + + fn inference(&self) -> &'db TypeInference<'db> { + infer_expression_types(self.db, self.expression) + } + + fn add_expr_compare(&mut self, expr_compare: &ast::ExprCompare) { + let ast::ExprCompare { + range: _, + left, + ops, + comparators, + } = expr_compare; + + if let ast::Expr::Name(ast::ExprName { + range: _, + id, + ctx: _, + }) = left.as_ref() + { + // SAFETY: we should always have a symbol for every Name node. + let symbol = self.symbols().symbol_id_by_name(id).unwrap(); + let scope = self.scope(); + let inference = self.inference(); + for (op, comparator) in std::iter::zip(&**ops, &**comparators) { + let comp_ty = inference.expression_ty(comparator.scoped_ast_id(self.db, scope)); + if matches!(op, ast::CmpOp::IsNot) { + let ty = IntersectionBuilder::new(self.db) + .add_negative(comp_ty) + .build(); + self.constraints.insert(symbol, ty); + }; + // TODO other comparison types + } + } + } +}