From ff182466339d2e502b0ef59d98ef124c8826c4e5 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Mon, 5 Aug 2024 16:34:12 -0700 Subject: [PATCH] [red-knot] add type narrowing --- .../src/semantic_index.rs | 67 +-- .../src/semantic_index/builder.rs | 11 +- .../src/semantic_index/use_def.rs | 286 +++++++------ .../src/semantic_index/use_def/bitset.rs | 402 ++++++++++++++++++ .../semantic_index/use_def/symbol_state.rs | 362 ++++++++++++++++ crates/red_knot_python_semantic/src/types.rs | 31 +- .../src/types/builder.rs | 35 +- .../src/types/infer.rs | 25 +- .../src/types/narrow.rs | 100 +++++ 9 files changed, 1143 insertions(+), 176 deletions(-) create mode 100644 crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs create mode 100644 crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs create mode 100644 crates/red_knot_python_semantic/src/types/narrow.rs diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index a3626c0bdc3b94..6b38404db04d53 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -16,10 +16,9 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolTable, }; +use crate::semantic_index::use_def::UseDefMap; use crate::Db; -pub(crate) use self::use_def::UseDefMap; - pub mod ast_ids; mod builder; pub mod definition; @@ -27,6 +26,8 @@ pub mod expression; pub mod symbol; mod use_def; +pub(crate) use self::use_def::{DefinitionWithConstraints, DefinitionWithConstraintsIterator}; + type SymbolMap = hashbrown::HashMap; /// Returns the semantic index for `file`. 
@@ -313,6 +314,7 @@ mod tests { use crate::semantic_index::ast_ids::HasScopedUseId; use crate::semantic_index::definition::DefinitionKind; use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable}; + use crate::semantic_index::use_def::DefinitionWithConstraints; use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map}; use crate::Db; @@ -374,7 +376,9 @@ mod tests { let foo = global_table.symbol_id_by_name("foo").unwrap(); let use_def = use_def_map(&db, scope); - let [definition] = use_def.public_definitions(foo) else { + let [DefinitionWithConstraints { definition, .. }] = + use_def.public_definitions(foo).collect::>()[..] + else { panic!("expected one definition"); }; assert!(matches!(definition.node(&db), DefinitionKind::Import(_))); @@ -411,11 +415,14 @@ mod tests { ); let use_def = use_def_map(&db, scope); - let [definition] = use_def.public_definitions( - global_table - .symbol_id_by_name("foo") - .expect("symbol to exist"), - ) else { + let [DefinitionWithConstraints { definition, .. }] = use_def + .public_definitions( + global_table + .symbol_id_by_name("foo") + .expect("symbol to exist"), + ) + .collect::>()[..] + else { panic!("expected one definition"); }; assert!(matches!( @@ -438,8 +445,9 @@ mod tests { "a symbol used but not defined in a scope should have only the used flag" ); let use_def = use_def_map(&db, scope); - let [definition] = - use_def.public_definitions(global_table.symbol_id_by_name("x").expect("symbol exists")) + let [DefinitionWithConstraints { definition, .. }] = use_def + .public_definitions(global_table.symbol_id_by_name("x").expect("symbol exists")) + .collect::>()[..] 
else { panic!("expected one definition"); }; @@ -477,8 +485,9 @@ y = 2 assert_eq!(names(&class_table), vec!["x"]); let use_def = index.use_def_map(class_scope_id); - let [definition] = - use_def.public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists")) + let [DefinitionWithConstraints { definition, .. }] = use_def + .public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists")) + .collect::>()[..] else { panic!("expected one definition"); }; @@ -515,11 +524,14 @@ y = 2 assert_eq!(names(&function_table), vec!["x"]); let use_def = index.use_def_map(function_scope_id); - let [definition] = use_def.public_definitions( - function_table - .symbol_id_by_name("x") - .expect("symbol exists"), - ) else { + let [DefinitionWithConstraints { definition, .. }] = use_def + .public_definitions( + function_table + .symbol_id_by_name("x") + .expect("symbol exists"), + ) + .collect::>()[..] + else { panic!("expected one definition"); }; assert!(matches!( @@ -594,7 +606,9 @@ y = 2 let element_use_id = element.scoped_use_id(&db, comprehension_scope_id.to_scope_id(&db, file)); - let [definition] = use_def.use_definitions(element_use_id) else { + let [DefinitionWithConstraints { definition, .. }] = + use_def.use_definitions(element_use_id).collect::>()[..] + else { panic!("expected one definition") }; let DefinitionKind::Comprehension(comprehension) = definition.node(&db) else { @@ -693,11 +707,14 @@ def func(): assert_eq!(names(&func2_table), vec!["y"]); let use_def = index.use_def_map(FileScopeId::global()); - let [definition] = use_def.public_definitions( - global_table - .symbol_id_by_name("func") - .expect("symbol exists"), - ) else { + let [DefinitionWithConstraints { definition, .. }] = use_def + .public_definitions( + global_table + .symbol_id_by_name("func") + .expect("symbol exists"), + ) + .collect::>()[..] 
+ else { panic!("expected one definition"); }; assert!(matches!(definition.node(&db), DefinitionKind::Function(_))); @@ -800,7 +817,9 @@ class C[T]: }; let x_use_id = x_use_expr_name.scoped_use_id(&db, scope); let use_def = use_def_map(&db, scope); - let [definition] = use_def.use_definitions(x_use_id) else { + let [DefinitionWithConstraints { definition, .. }] = + use_def.use_definitions(x_use_id).collect::>()[..] + else { panic!("expected one definition"); }; let DefinitionKind::Assignment(assignment) = definition.node(&db) else { diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index ee17e228d9a347..a59d8e70131639 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -195,9 +195,16 @@ impl<'db> SemanticIndexBuilder<'db> { definition } + fn add_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> { + let expression = self.add_standalone_expression(constraint_node); + self.current_use_def_map_mut().record_constraint(expression); + + expression + } + /// Record an expression that needs to be a Salsa ingredient, because we need to infer its type /// standalone (type narrowing tests, RHS of an assignment.) 
- fn add_standalone_expression(&mut self, expression_node: &ast::Expr) { + fn add_standalone_expression(&mut self, expression_node: &ast::Expr) -> Expression<'db> { let expression = Expression::new( self.db, self.file, @@ -210,6 +217,7 @@ impl<'db> SemanticIndexBuilder<'db> { ); self.expressions_by_node .insert(expression_node.into(), expression); + expression } fn with_type_params( @@ -456,6 +464,7 @@ where ast::Stmt::If(node) => { self.visit_expr(&node.test); let pre_if = self.flow_snapshot(); + self.add_constraint(&node.test); self.visit_body(&node.body); let mut post_clauses: Vec = vec![]; for clause in &node.elif_else_clauses { diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index f3e1afe98273e8..d80c777adb5293 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -1,4 +1,5 @@ -//! Build a map from each use of a symbol to the definitions visible from that use. +//! Build a map from each use of a symbol to the definitions visible from that use, and the +//! type-narrowing constraints that apply to each definition. //! //! Let's take this code sample: //! @@ -6,7 +7,7 @@ //! x = 1 //! x = 2 //! y = x -//! if flag: +//! if y is not None: //! x = 3 //! else: //! x = 4 @@ -34,8 +35,8 @@ //! [`AstIds`](crate::semantic_index::ast_ids::AstIds) we number all uses (that means a `Name` node //! with `Load` context) so we have a `ScopedUseId` to efficiently represent each use. //! -//! The other case we need to handle is when a symbol is referenced from a different scope (the -//! most obvious example of this is an import). We call this "public" use of a symbol. So the other +//! Another case we need to handle is when a symbol is referenced from a different scope (the most +//! obvious example of this is an import). We call this "public" use of a symbol. So the other //! 
question we need to be able to answer is, what are the publicly-visible definitions of each //! symbol? //! @@ -53,42 +54,55 @@ //! start.) //! //! So this means that the publicly-visible definitions of a symbol are the definitions still -//! visible at the end of the scope. +//! visible at the end of the scope; effectively we have an implicit "use" of every symbol at the +//! end of the scope. //! -//! The data structure we build to answer these two questions is the `UseDefMap`. It has a +//! We also need to know, for a given definition of a symbol, what type-narrowing constraints apply +//! to it. For instance, in this code sample: +//! +//! ```python +//! x = 1 if flag else None +//! if x is not None: +//! y = x +//! ``` +//! +//! At the use of `x` in `y = x`, the visible definition of `x` is `1 if flag else None`, which +//! would infer as the type `Literal[1] | None`. But the constraint `x is not None` dominates this +//! use, which means we can rule out the possibility that `x` is `None` here, which should give us +//! the type `Literal[1]` for this use. +//! +//! The data structure we build to answer these questions is the `UseDefMap`. It has a //! `definitions_by_use` vector indexed by [`ScopedUseId`] and a `public_definitions` vector //! indexed by [`ScopedSymbolId`]. The values in each of these vectors are (in principle) a list of -//! visible definitions at that use, or at the end of the scope for that symbol. +//! visible definitions at that use, or at the end of the scope for that symbol, with a list of the +//! dominating constraints for each of those definitions. //! -//! In order to avoid vectors-of-vectors and all the allocations that would entail, we don't -//! actually store these "list of visible definitions" as a vector of [`Definition`] IDs. Instead, -//! the values in `definitions_by_use` and `public_definitions` are a [`Definitions`] struct that -//! keeps a [`Range`] into a third vector of [`Definition`] IDs, `all_definitions`. 
The trick with -//! this representation is that it requires that the definitions visible at any given use of a -//! symbol are stored sequentially in `all_definitions`. +//! In order to avoid vectors-of-vectors-of-vectors and all the allocations that would entail, we +//! don't actually store these "list of visible definitions" as a vector of [`Definition`]. +//! Instead, the values in `definitions_by_use` and `public_definitions` are a [`SymbolState`] +//! struct which uses bit-sets to track definitions and constraints in terms of +//! [`ScopedDefinitionId`] and [`ScopedConstraintId`], which are indices into the `all_definitions` +//! and `all_constraints` indexvecs in the [`UseDefMap`]. //! -//! There is another special kind of possible "definition" for a symbol: it might be unbound in the -//! scope. (This isn't equivalent to "zero visible definitions", since we may go through an `if` -//! that has a definition for the symbol, leaving us with one visible definition, but still also -//! the "unbound" possibility, since we might not have taken the `if` branch.) +//! There is another special kind of possible "definition" for a symbol: there might be a path from +//! the scope entry to a given use in which the symbol is never bound. //! //! The simplest way to model "unbound" would be as an actual [`Definition`] itself: the initial //! visible [`Definition`] for each symbol in a scope. But actually modeling it this way would -//! dramatically increase the number of [`Definition`] that Salsa must track. Since "unbound" is a +//! unnecessarily increase the number of [`Definition`] that Salsa must track. Since "unbound" is a //! special definition in that all symbols share it, and it doesn't have any additional per-symbol -//! state, we can represent it more efficiently: we use the `may_be_unbound` boolean on the -//! [`Definitions`] struct. If this flag is `true`, it means the symbol/use really has one -//! additional visible "definition", which is the unbound state. 
If this flag is `false`, it means -//! we've eliminated the possibility of unbound: every path we've followed includes a definition -//! for this symbol. +//! state, and constraints are irrelevant to it, we can represent it more efficiently: we use the +//! `may_be_unbound` boolean on the [`SymbolState`] struct. If this flag is `true`, it means the +//! symbol/use really has one additional visible "definition", which is the unbound state. If this +//! flag is `false`, it means we've eliminated the possibility of unbound: every path we've +//! followed includes a definition for this symbol. //! -//! To build a [`UseDefMap`], the [`UseDefMapBuilder`] is notified of each new use and definition -//! as they are encountered by the +//! To build a [`UseDefMap`], the [`UseDefMapBuilder`] is notified of each new use, definition, and +//! constraint as they are encountered by the //! [`SemanticIndexBuilder`](crate::semantic_index::builder::SemanticIndexBuilder) AST visit. For -//! each symbol, the builder tracks the currently-visible definitions for that symbol. When we hit -//! a use of a symbol, it records the currently-visible definitions for that symbol as the visible -//! definitions for that use. When we reach the end of the scope, it records the currently-visible -//! definitions for each symbol as the public definitions of that symbol. +//! each symbol, the builder tracks the `SymbolState` for that symbol. When we hit a use of a +//! symbol, it records the current state for that symbol for that use. When we reach the end of the +//! scope, it records the state for each symbol as the public definitions of that symbol. //! //! Let's walk through the above example. Initially we record for `x` that it has no visible //! definitions, and may be unbound. When we see `x = 1`, we record that as the sole visible @@ -98,10 +112,11 @@ //! //! Then we hit the `if` branch. We visit the `test` node (`flag` in this case), since that will //! happen regardless. 
Then we take a pre-branch snapshot of the currently visible definitions for -//! all symbols, which we'll need later. Then we go ahead and visit the `if` body. When we see `x = -//! 3`, it replaces `x = 2` as the sole visible definition of `x`. At the end of the `if` body, we -//! take another snapshot of the currently-visible definitions; we'll call this the post-if-body -//! snapshot. +//! all symbols, which we'll need later. Then we record `flag` as a possible constraint on the +//! currently visible definition (`x = 2`), and go ahead and visit the `if` body. When we see `x = +//! 3`, it replaces `x = 2` (constrained by `flag`) as the sole visible definition of `x`. At the +//! end of the `if` body, we take another snapshot of the currently-visible definitions; we'll call +//! this the post-if-body snapshot. //! //! Now we need to visit the `else` clause. The conditions when entering the `else` clause should //! be the pre-if conditions; if we are entering the `else` clause, we know that the `if` test @@ -125,98 +140,140 @@ //! (In the future we may have some other questions we want to answer as well, such as "is this //! definition used?", which will require tracking a bit more info in our map, e.g. a "used" bit //! for each [`Definition`] which is flipped to true when we record that definition for a use.) +use self::symbol_state::{ + ConstraintIdIterator, DefinitionIdWithConstraintsIterator, ScopedConstraintId, + ScopedDefinitionId, SymbolState, +}; use crate::semantic_index::ast_ids::ScopedUseId; use crate::semantic_index::definition::Definition; +use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::ScopedSymbolId; use ruff_index::IndexVec; -use std::ops::Range; -/// All definitions that can reach a given use of a name. +mod bitset; +mod symbol_state; + +/// Applicable definitions and constraints for every use of a name. 
#[derive(Debug, PartialEq, Eq)] pub(crate) struct UseDefMap<'db> { - // TODO store constraints with definitions for type narrowing - /// Definition IDs array for `definitions_by_use` and `public_definitions` to slice into. - all_definitions: Vec>, + /// Array of [`Definition`] in this scope. + all_definitions: IndexVec>, + + /// Array of constraints (as [`Expression`]) in this scope. + all_constraints: IndexVec>, - /// Definitions that can reach a [`ScopedUseId`]. - definitions_by_use: IndexVec, + /// [`SymbolState`] visible at a [`ScopedUseId`]. + definitions_by_use: IndexVec, - /// Definitions of each symbol visible at end of scope. - public_definitions: IndexVec, + /// [`SymbolState`] visible at end of scope for each symbol. + public_definitions: IndexVec, } impl<'db> UseDefMap<'db> { - pub(crate) fn use_definitions(&self, use_id: ScopedUseId) -> &[Definition<'db>] { - &self.all_definitions[self.definitions_by_use[use_id].definitions_range.clone()] + pub(crate) fn use_definitions( + &self, + use_id: ScopedUseId, + ) -> DefinitionWithConstraintsIterator<'_, 'db> { + DefinitionWithConstraintsIterator { + all_definitions: &self.all_definitions, + all_constraints: &self.all_constraints, + wrapped: self.definitions_by_use[use_id].iter_visible_definitions(), + } } pub(crate) fn use_may_be_unbound(&self, use_id: ScopedUseId) -> bool { - self.definitions_by_use[use_id].may_be_unbound + self.definitions_by_use[use_id].may_be_unbound() } - pub(crate) fn public_definitions(&self, symbol: ScopedSymbolId) -> &[Definition<'db>] { - &self.all_definitions[self.public_definitions[symbol].definitions_range.clone()] + pub(crate) fn public_definitions( + &self, + symbol: ScopedSymbolId, + ) -> DefinitionWithConstraintsIterator<'_, 'db> { + DefinitionWithConstraintsIterator { + all_definitions: &self.all_definitions, + all_constraints: &self.all_constraints, + wrapped: self.public_definitions[symbol].iter_visible_definitions(), + } } pub(crate) fn public_may_be_unbound(&self, 
symbol: ScopedSymbolId) -> bool { - self.public_definitions[symbol].may_be_unbound + self.public_definitions[symbol].may_be_unbound() } } -/// Definitions visible for a symbol at a particular use (or end-of-scope). -#[derive(Clone, Debug, PartialEq, Eq)] -struct Definitions { - /// [`Range`] in `all_definitions` of the visible definition IDs. - definitions_range: Range, - /// Is the symbol possibly unbound at this point? - may_be_unbound: bool, +#[derive(Debug)] +pub(crate) struct DefinitionWithConstraintsIterator<'map, 'db> { + all_definitions: &'map IndexVec>, + all_constraints: &'map IndexVec>, + wrapped: DefinitionIdWithConstraintsIterator<'map>, } -impl Definitions { - /// The default state of a symbol is "no definitions, may be unbound", aka definitely-unbound. - fn unbound() -> Self { - Self { - definitions_range: Range::default(), - may_be_unbound: true, - } +impl<'map, 'db> Iterator for DefinitionWithConstraintsIterator<'map, 'db> { + type Item = DefinitionWithConstraints<'map, 'db>; + + fn next(&mut self) -> Option { + self.wrapped + .next() + .map(|def_id_with_constraints| DefinitionWithConstraints { + definition: self.all_definitions[def_id_with_constraints.definition], + constraints: ConstraintsIterator { + all_constraints: self.all_constraints, + constraint_ids: def_id_with_constraints.constraint_ids, + }, + }) } } -impl Default for Definitions { - fn default() -> Self { - Definitions::unbound() +pub(crate) struct DefinitionWithConstraints<'map, 'db> { + pub(crate) definition: Definition<'db>, + pub(crate) constraints: ConstraintsIterator<'map, 'db>, +} + +pub(crate) struct ConstraintsIterator<'map, 'db> { + all_constraints: &'map IndexVec>, + constraint_ids: ConstraintIdIterator<'map>, +} + +impl<'map, 'db> Iterator for ConstraintsIterator<'map, 'db> { + type Item = Expression<'db>; + + fn next(&mut self) -> Option { + self.constraint_ids + .next() + .map(|constraint_id| self.all_constraints[constraint_id]) } } -/// A snapshot of the visible 
definitions for each symbol at a particular point in control flow. +impl std::iter::FusedIterator for ConstraintsIterator<'_, '_> {} + +/// A snapshot of the definitions and constraints state at a particular point in control flow. #[derive(Clone, Debug)] pub(super) struct FlowSnapshot { - definitions_by_symbol: IndexVec, + definitions_by_symbol: IndexVec, } -#[derive(Debug)] +#[derive(Debug, Default)] pub(super) struct UseDefMapBuilder<'db> { - /// Definition IDs array for `definitions_by_use` and `definitions_by_symbol` to slice into. - all_definitions: Vec>, + /// Append-only array of [`Definition`]; None is unbound. + all_definitions: IndexVec>, + + /// Append-only array of constraints (as [`Expression`]). + all_constraints: IndexVec>, /// Visible definitions at each so-far-recorded use. - definitions_by_use: IndexVec, + definitions_by_use: IndexVec, /// Currently visible definitions for each symbol. - definitions_by_symbol: IndexVec, + definitions_by_symbol: IndexVec, } impl<'db> UseDefMapBuilder<'db> { pub(super) fn new() -> Self { - Self { - all_definitions: Vec::new(), - definitions_by_use: IndexVec::new(), - definitions_by_symbol: IndexVec::new(), - } + Self::default() } pub(super) fn add_symbol(&mut self, symbol: ScopedSymbolId) { - let new_symbol = self.definitions_by_symbol.push(Definitions::unbound()); + let new_symbol = self.definitions_by_symbol.push(SymbolState::unbound()); debug_assert_eq!(symbol, new_symbol); } @@ -227,13 +284,15 @@ impl<'db> UseDefMapBuilder<'db> { ) { // We have a new definition of a symbol; this replaces any previous definitions in this // path. 
- let def_idx = self.all_definitions.len(); - self.all_definitions.push(definition); - self.definitions_by_symbol[symbol] = Definitions { - #[allow(clippy::range_plus_one)] - definitions_range: def_idx..(def_idx + 1), - may_be_unbound: false, - }; + let def_id = self.all_definitions.push(definition); + self.definitions_by_symbol[symbol] = SymbolState::with(def_id); + } + + pub(super) fn record_constraint(&mut self, constraint: Expression<'db>) { + let constraint_id = self.all_constraints.push(constraint); + for definitions in &mut self.definitions_by_symbol { + definitions.add_constraint(constraint_id); + } } pub(super) fn record_use(&mut self, symbol: ScopedSymbolId, use_id: ScopedUseId) { @@ -265,9 +324,9 @@ impl<'db> UseDefMapBuilder<'db> { // If the snapshot we are restoring is missing some symbols we've recorded since, we need // to fill them in so the symbol IDs continue to line up. Since they don't exist in the - // snapshot, the correct state to fill them in with is "unbound", the default. + // snapshot, the correct state to fill them in with is "unbound". self.definitions_by_symbol - .resize(num_symbols, Definitions::unbound()); + .resize(num_symbols, SymbolState::unbound()); } /// Merge the given snapshot into the current state, reflecting that we might have taken either @@ -288,65 +347,24 @@ impl<'db> UseDefMapBuilder<'db> { debug_assert!(self.definitions_by_symbol.len() >= snapshot.definitions_by_symbol.len()); for (symbol_id, current) in self.definitions_by_symbol.iter_mut_enumerated() { - let Some(snapshot) = snapshot.definitions_by_symbol.get(symbol_id) else { - // Symbol not present in snapshot, so it's unbound from that path. - current.may_be_unbound = true; - continue; - }; - - // If the symbol can be unbound in either predecessor, it can be unbound post-merge. - current.may_be_unbound |= snapshot.may_be_unbound; - - // Merge the definition ranges. 
- let current = &mut current.definitions_range; - let snapshot = &snapshot.definitions_range; - - // We never create reversed ranges. - debug_assert!(current.end >= current.start); - debug_assert!(snapshot.end >= snapshot.start); - - if current == snapshot { - // Ranges already identical, nothing to do. - } else if snapshot.is_empty() { - // Merging from an empty range; nothing to do. - } else if (*current).is_empty() { - // Merging to an empty range; just use the incoming range. - *current = snapshot.clone(); - } else if snapshot.end >= current.start && snapshot.start <= current.end { - // Ranges are adjacent or overlapping, merge them in-place. - *current = current.start.min(snapshot.start)..current.end.max(snapshot.end); - } else if current.end == self.all_definitions.len() { - // Ranges are not adjacent or overlapping, `current` is at the end of - // `all_definitions`, we need to copy `snapshot` to the end so they are adjacent - // and can be merged into one range. - self.all_definitions.extend_from_within(snapshot.clone()); - current.end = self.all_definitions.len(); - } else if snapshot.end == self.all_definitions.len() { - // Ranges are not adjacent or overlapping, `snapshot` is at the end of - // `all_definitions`, we need to copy `current` to the end so they are adjacent and - // can be merged into one range. - self.all_definitions.extend_from_within(current.clone()); - current.start = snapshot.start; - current.end = self.all_definitions.len(); + if let Some(snapshot) = snapshot.definitions_by_symbol.get(symbol_id) { + *current = SymbolState::merge(current, snapshot); } else { - // Ranges are not adjacent and neither one is at the end of `all_definitions`, we - // have to copy both to the end so they are adjacent and we can merge them. 
- let start = self.all_definitions.len(); - self.all_definitions.extend_from_within(current.clone()); - self.all_definitions.extend_from_within(snapshot.clone()); - current.start = start; - current.end = self.all_definitions.len(); + // Symbol not present in snapshot, so it's unbound from that path. + current.add_unbound(); } } } pub(super) fn finish(mut self) -> UseDefMap<'db> { self.all_definitions.shrink_to_fit(); + self.all_constraints.shrink_to_fit(); self.definitions_by_symbol.shrink_to_fit(); self.definitions_by_use.shrink_to_fit(); UseDefMap { all_definitions: self.all_definitions, + all_constraints: self.all_constraints, definitions_by_use: self.definitions_by_use, public_definitions: self.definitions_by_symbol, } diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs new file mode 100644 index 00000000000000..1cc7820f6660d4 --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs @@ -0,0 +1,402 @@ +use std::collections::{btree_set, BTreeSet}; + +/// Ordered set of `u32`. +/// +/// Uses an inline bit-set for small values (up to 64 * B), falls back to a [`BTreeSet`] if a +/// larger value is inserted. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) enum BitSet { + /// Bit-set (in 64-bit blocks) for the first 64 * B entries. + Inline([u64; B]), + + /// Overflow beyond 64 * B. + Heap(BTreeSet), +} + +impl Default for BitSet { + fn default() -> Self { + // B * 64 must fit in a u32, or else we have unusable bits; this assertion makes the + // truncating casts to u32 below safe. This would be better as a const assertion, but + // that's not possible on stable with const generic params. (B should never really be + // anywhere close to this large.) + debug_assert!(B * 64 < (u32::MAX as usize)); + Self::Inline([0; B]) + } +} + +impl BitSet { + // SAFETY: We check above that B * 64 < u32::MAX. 
+ #[allow(clippy::cast_possible_truncation)] + const BITS: u32 = (64 * B) as u32; + + /// Create and return a new [`BitSet`] with a single `value` inserted. + pub(super) fn with(value: u32) -> Self { + let mut bitset = Self::default(); + bitset.insert(value); + bitset + } + + /// Convert from Inline to Heap representation. + fn overflow(&mut self) { + if matches!(self, Self::Inline(_)) { + *self = Self::Heap(self.iter().collect()); + } + } + + /// Insert a value into the [`BitSet`]. + /// + /// Return true if the value was newly inserted, false if already present. + pub(super) fn insert(&mut self, value: u32) -> bool { + if value >= Self::BITS { + self.overflow(); + } + match self { + Self::Inline(blocks) => { + let value_usize = value as usize; + let (block, index) = (value_usize / 64, value_usize % 64); + let missing = blocks[block] & (1_u64 << index) == 0; + blocks[block] |= 1_u64 << index; + missing + } + Self::Heap(set) => set.insert(value), + } + } + + /// Intersect in-place with another [`BitSet`]. + pub(super) fn intersect(&mut self, other: &BitSet) { + match (self, other) { + (Self::Inline(myblocks), Self::Inline(other_blocks)) => { + for i in 0..B { + myblocks[i] &= other_blocks[i]; + } + } + (Self::Heap(myset), Self::Heap(other_set)) => { + *myset = myset.intersection(other_set).copied().collect(); + } + (me, other) => { + for value in other.iter() { + me.insert(value); + } + } + } + } + + /// Return an iterator over the values (in ascending order) in this [`BitSet`]. + pub(super) fn iter(&self) -> BitSetIterator<'_, B> { + match self { + Self::Inline(blocks) => BitSetIterator::Inline { + blocks, + cur_block_index: 0, + cur_block: blocks[0], + }, + Self::Heap(set) => BitSetIterator::Heap(set.iter()), + } + } +} + +/// Iterator over values in a [`BitSet`]. +#[derive(Debug)] +pub(super) enum BitSetIterator<'a, const B: usize> { + Inline { + /// The blocks we are iterating over. 
+ blocks: &'a [u64; B], + + /// The index of the block we are currently iterating through. + cur_block_index: usize, + + /// The block we are currently iterating through (and zeroing as we go.) + cur_block: u64, + }, + Heap(btree_set::Iter<'a, u32>), +} + +impl Iterator for BitSetIterator<'_, B> { + type Item = u32; + + fn next(&mut self) -> Option { + match self { + Self::Inline { + blocks, + cur_block_index, + cur_block, + } => { + while *cur_block == 0 { + if *cur_block_index == B - 1 { + return None; + } + *cur_block_index += 1; + *cur_block = blocks[*cur_block_index]; + } + let value = cur_block.trailing_zeros(); + // reset the lowest set bit + *cur_block &= cur_block.wrapping_sub(1); + // SAFETY: `value` cannot be more than 64, `cur_block_index` cannot be more than + // `B - 1`, and we check above that `B * 64 < u32::MAX`. So both `64 * + // cur_block_index` and the final value here must fit in u32. + #[allow(clippy::cast_possible_truncation)] + Some(value + (64 * *cur_block_index) as u32) + } + Self::Heap(set_iter) => set_iter.next().copied(), + } + } +} + +impl std::iter::FusedIterator for BitSetIterator<'_, B> {} + +/// Array of [`BitSet`]. Up to N stored inline, more than that overflows to a Vector. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) enum BitSetArray { + Inline { + /// Array of N [`BitSet`]. + array: [BitSet; N], + + /// How many of the bitsets are used? + size: usize, + }, + + Heap(Vec>), +} + +impl Default for BitSetArray { + fn default() -> Self { + Self::Inline { + array: std::array::from_fn(|_| BitSet::default()), + size: 0, + } + } +} + +impl BitSetArray { + /// Create a [`BitSetArray`] of `size` empty [`BitSet`]s. 
+ pub(super) fn of_size(size: usize) -> Self { + let mut array = Self::default(); + for _ in 0..size { + array.push(BitSet::default()); + } + array + } + + fn overflow(&mut self) { + match self { + Self::Inline { array, size } => { + let mut vec: Vec> = vec![]; + for bitset in array.iter().take(*size) { + vec.push(bitset.clone()); + } + *self = Self::Heap(vec); + } + Self::Heap(_) => {} + } + } + + /// Push a [`BitSet`] onto the end of the array. + pub(super) fn push(&mut self, new: BitSet) { + match self { + Self::Inline { array, size } => { + *size += 1; + if *size > N { + self.overflow(); + self.push(new); + } else { + array[*size - 1] = new; + } + } + Self::Heap(vec) => vec.push(new), + } + } + + /// Return a mutable reference to the last [`BitSet`] in the array, or None. + pub(super) fn last_mut(&mut self) -> Option<&mut BitSet> { + match self { + Self::Inline { array, size } => { + if *size == 0 { + None + } else { + Some(&mut array[*size - 1]) + } + } + Self::Heap(vec) => vec.last_mut(), + } + } + + /// Insert `value` into every [`BitSet`] in this [`BitSetArray`]. + pub(super) fn insert_in_each(&mut self, value: u32) { + match self { + Self::Inline { array, size } => { + for bitset in array.iter_mut().take(*size) { + bitset.insert(value); + } + } + Self::Heap(vec) => { + for bitset in vec { + bitset.insert(value); + } + } + } + } + + /// Return an iterator over each [`BitSet`] in this [`BitSetArray`]. + pub(super) fn iter(&self) -> BitSetArrayIterator<'_, B, N> { + match self { + Self::Inline { array, size } => BitSetArrayIterator::Inline { + array, + index: 0, + size: *size, + }, + Self::Heap(vec) => BitSetArrayIterator::Heap(vec.iter()), + } + } +} + +/// Iterator over a [`BitSetArray`]. 
+#[derive(Debug)]
+pub(super) enum BitSetArrayIterator<'a, const B: usize, const N: usize> {
+    Inline {
+        array: &'a [BitSet<B>; N],
+        index: usize,
+        size: usize,
+    },
+
+    Heap(core::slice::Iter<'a, BitSet<B>>),
+}
+
+impl<'a, const B: usize, const N: usize> Iterator for BitSetArrayIterator<'a, B, N> {
+    type Item = &'a BitSet<B>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self {
+            Self::Inline { array, index, size } => {
+                if *index >= *size {
+                    return None;
+                }
+                let ret = Some(&array[*index]);
+                *index += 1;
+                ret
+            }
+            Self::Heap(iter) => iter.next(),
+        }
+    }
+}
+
+impl<const B: usize, const N: usize> std::iter::FusedIterator for BitSetArrayIterator<'_, B, N> {}
+
+#[cfg(test)]
+mod tests {
+    use super::{BitSet, BitSetArray};
+
+    fn assert_bitset<const B: usize>(bitset: &BitSet<B>, contents: &[u32]) {
+        assert_eq!(bitset.iter().collect::<Vec<_>>(), contents);
+    }
+
+    mod bitset {
+        use super::{assert_bitset, BitSet};
+
+        #[test]
+        fn iter() {
+            let mut b = BitSet::<1>::with(3);
+            b.insert(27);
+            b.insert(6);
+            assert!(matches!(b, BitSet::Inline(_)));
+            assert_bitset(&b, &[3, 6, 27]);
+        }
+
+        #[test]
+        fn iter_overflow() {
+            let mut b = BitSet::<1>::with(140);
+            b.insert(100);
+            b.insert(129);
+            assert!(matches!(b, BitSet::Heap(_)));
+            assert_bitset(&b, &[100, 129, 140]);
+        }
+
+        #[test]
+        fn intersect() {
+            let mut b1 = BitSet::<1>::with(4);
+            let mut b2 = BitSet::<1>::with(4);
+            b1.insert(23);
+            b2.insert(5);
+
+            b1.intersect(&b2);
+            assert_bitset(&b1, &[4]);
+        }
+
+        #[test]
+        fn multiple_blocks() {
+            let mut b = BitSet::<2>::with(120);
+            b.insert(45);
+            assert!(matches!(b, BitSet::Inline(_)));
+            assert_bitset(&b, &[45, 120]);
+        }
+    }
+
+    fn assert_array<const B: usize, const N: usize>(
+        array: &BitSetArray<B, N>,
+        contents: &[Vec<u32>],
+    ) {
+        assert_eq!(
+            array
+                .iter()
+                .map(|bitset| bitset.iter().collect::<Vec<_>>())
+                .collect::<Vec<_>>(),
+            contents
+        );
+    }
+
+    mod bitset_array {
+        use super::{assert_array, BitSet, BitSetArray};
+
+        #[test]
+        fn insert_in_each() {
+            let mut ba = BitSetArray::<1, 2>::default();
+            assert_array(&ba, &[]);
+
+            ba.push(BitSet::default());
+ assert_array(&ba, &[vec![]]); + + ba.insert_in_each(3); + assert_array(&ba, &[vec![3]]); + + ba.push(BitSet::default()); + assert_array(&ba, &[vec![3], vec![]]); + + ba.insert_in_each(79); + assert_array(&ba, &[vec![3, 79], vec![79]]); + + assert!(matches!(ba, BitSetArray::Inline { .. })); + + ba.push(BitSet::default()); + assert!(matches!(ba, BitSetArray::Heap(_))); + assert_array(&ba, &[vec![3, 79], vec![79], vec![]]); + + ba.insert_in_each(130); + assert_array(&ba, &[vec![3, 79, 130], vec![79, 130], vec![130]]); + } + + #[test] + fn of_size() { + let mut ba = BitSetArray::<1, 2>::of_size(1); + ba.insert_in_each(5); + assert_array(&ba, &[vec![5]]); + } + + #[test] + fn last_mut() { + let mut ba = BitSetArray::<1, 2>::of_size(1); + ba.insert_in_each(3); + ba.insert_in_each(5); + + ba.last_mut() + .expect("last to not be None") + .intersect(&BitSet::with(3)); + + assert_array(&ba, &[vec![3]]); + } + + #[test] + fn last_mut_none() { + let mut ba = BitSetArray::<1, 1>::default(); + + assert!(ba.last_mut().is_none()); + } + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs new file mode 100644 index 00000000000000..d3c89b237cf73c --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs @@ -0,0 +1,362 @@ +//! Track visible definitions of a symbol, and applicable constraints per definition. +//! +//! These data structures operate entirely on scope-local newtype-indices for definitions and +//! constraints, referring to their location in the `all_definitions` and `all_constraints` +//! indexvecs in [`super::UseDefMapBuilder`]. +//! +//! We need to track arbitrary associations between definitions and constraints, not just a single +//! set of currently dominating constraints (where "dominating" means "control flow must have +//! 
passed through it to reach this point"), because we can have dominating constraints that apply +//! to some definitions but not others, as in this code: +//! +//! ```python +//! x = 1 if flag else None +//! if x is not None: +//! if flag2: +//! x = 2 if flag else None +//! x +//! ``` +//! +//! From the point of view of the final use of `x`, the `x is not None` constraint is "currently +//! visible" (in the sense that control flow must have passed through it to reach this point), but +//! it applies only to the first definition of `x`, not the second, so `None` is a possible value +//! for `x`. +//! +//! And we can't just track an index into a list of dominating constraints for each definition, +//! either, because we can have definitions which are still visible, but subject to constraints +//! that are no longer dominating, as in this code: +//! +//! ```python +//! x = 0 +//! if flag1: +//! x = 1 if flag2 else None +//! assert x is not None +//! x +//! ``` +//! +//! From the point of view of the final use of `x`, the `x is not None` constraint no longer +//! dominates, but it does dominate the `x = 1 if flag2 else None` definition, so we have to keep +//! track of that. +//! +//! The data structures used here ([`BitSet`] and [`BitSetArray`]) optimize for keeping all data +//! inline (avoiding lots of scattered allocations) in small-to-medium cases, and falling back to +//! heap allocation to be able to scale to arbitrary numbers of definitions and constraints when +//! needed. +use super::bitset::{BitSet, BitSetArray, BitSetArrayIterator, BitSetIterator}; +use ruff_index::newtype_index; + +/// A newtype-index for a definition in a particular scope. +#[newtype_index] +pub(super) struct ScopedDefinitionId; + +/// A newtype-index for a constraint expression in a particular scope. +#[newtype_index] +pub(super) struct ScopedConstraintId; + +/// Can reference this * 64 total definitions inline; more will fall back to the heap. 
+const INLINE_DEFINITION_BLOCKS: usize = 1;
+
+/// A [`BitSet`] of [`ScopedDefinitionId`], representing visible definitions of a symbol in a scope.
+type Definitions = BitSet<INLINE_DEFINITION_BLOCKS>;
+type DefinitionsIterator<'a> = BitSetIterator<'a, INLINE_DEFINITION_BLOCKS>;
+
+/// Can reference this * 64 total constraints inline; more will fall back to the heap.
+const INLINE_CONSTRAINT_BLOCKS: usize = 1;
+
+/// Can handle this many visible definitions per symbol at a given time inline.
+const INLINE_VISIBLE_DEFINITIONS_PER_SYMBOL: usize = 4;
+
+/// A [`BitSetArray`]; one [`BitSet`] of applicable [`ScopedConstraintId`] per visible definition.
+type Constraints = BitSetArray<INLINE_CONSTRAINT_BLOCKS, INLINE_VISIBLE_DEFINITIONS_PER_SYMBOL>;
+type ConstraintsIterator<'a> =
+    BitSetArrayIterator<'a, INLINE_CONSTRAINT_BLOCKS, INLINE_VISIBLE_DEFINITIONS_PER_SYMBOL>;
+
+/// Visible definitions and narrowing constraints for a single symbol at some point in control flow.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub(super) struct SymbolState {
+    /// [`BitSet`]: which [`ScopedDefinitionId`] are visible for this symbol?
+    visible_definitions: Definitions,
+
+    /// For each definition, which [`ScopedConstraintId`] apply?
+    ///
+    /// This is a [`BitSetArray`] which should always have one [`BitSet`] of constraints per
+    /// definition in `visible_definitions`.
+    constraints: Constraints,
+
+    /// Could the symbol be unbound at this point?
+    may_be_unbound: bool,
+}
+
+/// A single [`ScopedDefinitionId`] with an iterator of its applicable [`ScopedConstraintId`].
+#[derive(Debug)]
+pub(super) struct DefinitionIdWithConstraints<'a> {
+    pub(super) definition: ScopedDefinitionId,
+    pub(super) constraint_ids: ConstraintIdIterator<'a>,
+}
+
+impl SymbolState {
+    /// Return a new [`SymbolState`] representing an unbound symbol.
+    pub(super) fn unbound() -> Self {
+        Self {
+            visible_definitions: Definitions::default(),
+            constraints: Constraints::default(),
+            may_be_unbound: true,
+        }
+    }
+
+    /// Return a new [`SymbolState`] representing a symbol with a single visible definition.
+    pub(super) fn with(definition_id: ScopedDefinitionId) -> Self {
+        Self {
+            visible_definitions: Definitions::with(definition_id.into()),
+            constraints: Constraints::of_size(1),
+            may_be_unbound: false,
+        }
+    }
+
+    /// Add Unbound as a possibility for this symbol.
+    pub(super) fn add_unbound(&mut self) {
+        self.may_be_unbound = true;
+    }
+
+    /// Add given constraint to all currently-visible definitions.
+    pub(super) fn add_constraint(&mut self, constraint_id: ScopedConstraintId) {
+        self.constraints.insert_in_each(constraint_id.into());
+    }
+
+    /// Merge two [`SymbolState`] and return the result.
+    pub(super) fn merge(a: &SymbolState, b: &SymbolState) -> SymbolState {
+        let mut merged = Self {
+            visible_definitions: Definitions::default(),
+            constraints: Constraints::default(),
+            may_be_unbound: a.may_be_unbound || b.may_be_unbound,
+        };
+        let mut a_defs_iter = a.visible_definitions.iter();
+        let mut b_defs_iter = b.visible_definitions.iter();
+        let mut a_constraints_iter = a.constraints.iter();
+        let mut b_constraints_iter = b.constraints.iter();
+
+        let mut opt_a_def: Option<u32> = a_defs_iter.next();
+        let mut opt_b_def: Option<u32> = b_defs_iter.next();
+
+        // Iterate through the definitions from `a` and `b`, always processing the lower definition
+        // ID first, and pushing each definition onto the merged `SymbolState` with its
+        // constraints. If a definition is found in both `a` and `b`, push it with the intersection
+        // of the constraints from the two paths; a constraint that applies from only one possible
+        // path is irrelevant.
+
+        // Helper to push `def`, with constraints in `constraints_iter`, onto `merged`.
+ let push = |def, constraints_iter: &mut ConstraintsIterator, merged: &mut Self| { + merged.visible_definitions.insert(def); + let Some(constraints) = constraints_iter.next() else { + // SAFETY: we only ever create SymbolState with either no definitions and no + // constraint bitsets (`::unbound`) or one definition and one constraint bitset + // (`::with`), and `::merge` always pushes one definition and one constraint bitset + // together (just below), so the number of definitions and the number of constraint + // bitsets can never get out of sync. + unreachable!("definitions and constraints length mismatch"); + }; + merged.constraints.push(constraints.clone()); + }; + + loop { + match (opt_a_def, opt_b_def) { + (Some(a_def), Some(b_def)) => match a_def.cmp(&b_def) { + std::cmp::Ordering::Less => { + // Next definition ID is only in `a`, push it to `merged` and advance `a`. + push(a_def, &mut a_constraints_iter, &mut merged); + opt_a_def = a_defs_iter.next(); + } + std::cmp::Ordering::Greater => { + // Next definition ID is only in `b`, push it to `merged` and advance `b`. + push(b_def, &mut b_constraints_iter, &mut merged); + opt_b_def = b_defs_iter.next(); + } + std::cmp::Ordering::Equal => { + // Next definition is in both; push to `merged` and intersect constraints. + push(a_def, &mut a_constraints_iter, &mut merged); + let Some(b_constraints) = b_constraints_iter.next() else { + // SAFETY: see above. + unreachable!("definitions and constraints length mismatch"); + }; + // If the same definition is visible through both paths, any constraint + // that applies on only one path is irrelevant to the resulting type from + // unioning the two paths, so we intersect the constraints. + merged + .constraints + .last_mut() + .unwrap() + .intersect(b_constraints); + opt_a_def = a_defs_iter.next(); + opt_b_def = b_defs_iter.next(); + } + }, + (Some(a_def), None) => { + // We've exhausted `b`, just push the def from `a` and move on to the next. 
+                    push(a_def, &mut a_constraints_iter, &mut merged);
+                    opt_a_def = a_defs_iter.next();
+                }
+                (None, Some(b_def)) => {
+                    // We've exhausted `a`, just push the def from `b` and move on to the next.
+                    push(b_def, &mut b_constraints_iter, &mut merged);
+                    opt_b_def = b_defs_iter.next();
+                }
+                (None, None) => break,
+            }
+        }
+        merged
+    }
+
+    /// Get iterator over visible definitions with constraints.
+    pub(super) fn iter_visible_definitions(&self) -> DefinitionIdWithConstraintsIterator {
+        DefinitionIdWithConstraintsIterator {
+            definitions: self.visible_definitions.iter(),
+            constraints: self.constraints.iter(),
+        }
+    }
+
+    /// Could the symbol be unbound?
+    pub(super) fn may_be_unbound(&self) -> bool {
+        self.may_be_unbound
+    }
+}
+
+/// The default state of a symbol (if we've seen no definitions of it) is unbound.
+impl Default for SymbolState {
+    fn default() -> Self {
+        SymbolState::unbound()
+    }
+}
+
+#[derive(Debug)]
+pub(super) struct DefinitionIdWithConstraintsIterator<'a> {
+    definitions: DefinitionsIterator<'a>,
+    constraints: ConstraintsIterator<'a>,
+}
+
+impl<'a> Iterator for DefinitionIdWithConstraintsIterator<'a> {
+    type Item = DefinitionIdWithConstraints<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match (self.definitions.next(), self.constraints.next()) {
+            (None, None) => None,
+            (Some(def), Some(constraints)) => Some(DefinitionIdWithConstraints {
+                definition: ScopedDefinitionId::from_u32(def),
+                constraint_ids: ConstraintIdIterator {
+                    wrapped: constraints.iter(),
+                },
+            }),
+            // SAFETY: see above.
+ _ => unreachable!("definitions and constraints length mismatch"), + } + } +} + +impl std::iter::FusedIterator for DefinitionIdWithConstraintsIterator<'_> {} + +#[derive(Debug)] +pub(super) struct ConstraintIdIterator<'a> { + wrapped: BitSetIterator<'a, INLINE_CONSTRAINT_BLOCKS>, +} + +impl Iterator for ConstraintIdIterator<'_> { + type Item = ScopedConstraintId; + + fn next(&mut self) -> Option { + self.wrapped.next().map(ScopedConstraintId::from_u32) + } +} + +impl std::iter::FusedIterator for ConstraintIdIterator<'_> {} + +#[cfg(test)] +mod tests { + use super::{ScopedConstraintId, ScopedDefinitionId, SymbolState}; + + impl SymbolState { + pub(crate) fn assert(&self, may_be_unbound: bool, expected: &[&str]) { + assert_eq!(self.may_be_unbound(), may_be_unbound); + let actual = self + .iter_visible_definitions() + .map(|def_id_with_constraints| { + format!( + "{}<{}>", + def_id_with_constraints.definition.as_u32(), + def_id_with_constraints + .constraint_ids + .map(ScopedConstraintId::as_u32) + .map(|idx| idx.to_string()) + .collect::>() + .join(", ") + ) + }) + .collect::>(); + assert_eq!(actual, expected); + } + } + + #[test] + fn unbound() { + let cd = SymbolState::unbound(); + + cd.assert(true, &[]); + } + + #[test] + fn with() { + let cd = SymbolState::with(ScopedDefinitionId::from_u32(0)); + + cd.assert(false, &["0<>"]); + } + + #[test] + fn add_unbound() { + let mut cd = SymbolState::with(ScopedDefinitionId::from_u32(0)); + cd.add_unbound(); + + cd.assert(true, &["0<>"]); + } + + #[test] + fn add_constraint() { + let mut cd = SymbolState::with(ScopedDefinitionId::from_u32(0)); + cd.add_constraint(ScopedConstraintId::from_u32(0)); + + cd.assert(false, &["0<0>"]); + } + + #[test] + fn merge() { + // merging the same definition with the same constraint keeps the constraint + let mut cd0a = SymbolState::with(ScopedDefinitionId::from_u32(0)); + cd0a.add_constraint(ScopedConstraintId::from_u32(0)); + + let mut cd0b = 
SymbolState::with(ScopedDefinitionId::from_u32(0)); + cd0b.add_constraint(ScopedConstraintId::from_u32(0)); + + let cd0 = SymbolState::merge(&cd0a, &cd0b); + cd0.assert(false, &["0<0>"]); + + // merging the same definition with differing constraints drops all constraints + let mut cd1a = SymbolState::with(ScopedDefinitionId::from_u32(1)); + cd1a.add_constraint(ScopedConstraintId::from_u32(1)); + + let mut cd1b = SymbolState::with(ScopedDefinitionId::from_u32(1)); + cd1b.add_constraint(ScopedConstraintId::from_u32(2)); + + let cd1 = SymbolState::merge(&cd1a, &cd1b); + cd1.assert(false, &["1<>"]); + + // merging a constrained definition with unbound keeps both + let mut cd2a = SymbolState::with(ScopedDefinitionId::from_u32(2)); + cd2a.add_constraint(ScopedConstraintId::from_u32(3)); + + let cd2b = SymbolState::unbound(); + + let cd2 = SymbolState::merge(&cd2a, &cd2b); + cd2.assert(true, &["2<3>"]); + + // merging different definitions keeps them each with their existing constraints + let cd = SymbolState::merge(&cd0, &cd2); + cd.assert(true, &["0<0>", "2<3>"]); + } +} diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 37430d95c3aa6b..1b4f57416b0569 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -4,15 +4,22 @@ use ruff_python_ast::name::Name; use crate::builtins::builtins_scope; use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId}; -use crate::semantic_index::{global_scope, symbol_table, use_def_map}; +use crate::semantic_index::{ + global_scope, symbol_table, use_def_map, DefinitionWithConstraints, + DefinitionWithConstraintsIterator, +}; +use crate::types::narrow::narrowing_constraint; use crate::{Db, FxOrderSet}; mod builder; mod display; mod infer; +mod narrow; -pub(crate) use self::builder::UnionBuilder; -pub(crate) use self::infer::{infer_definition_types, infer_scope_types}; 
+pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder}; +pub(crate) use self::infer::{ + infer_definition_types, infer_expression_types, infer_scope_types, TypeInference, +}; /// Infer the public type of a symbol (its type as seen from outside its scope). pub(crate) fn symbol_ty<'db>( @@ -82,10 +89,24 @@ pub(crate) fn definition_ty<'db>(db: &'db dyn Db, definition: Definition<'db>) - /// provide an `unbound_ty`. pub(crate) fn definitions_ty<'db>( db: &'db dyn Db, - definitions: &[Definition<'db>], + definitions_with_constraints: DefinitionWithConstraintsIterator<'_, 'db>, unbound_ty: Option>, ) -> Type<'db> { - let def_types = definitions.iter().map(|def| definition_ty(db, *def)); + let def_types = definitions_with_constraints.map( + |DefinitionWithConstraints { + definition, + constraints, + }| { + let mut builder = IntersectionBuilder::new(db); + builder = builder.add_positive(definition_ty(db, definition)); + for constraint_ty in + constraints.filter_map(|test| narrowing_constraint(db, test, definition)) + { + builder = builder.add_positive(constraint_ty); + } + builder.build() + }, + ); let mut all_types = unbound_ty.into_iter().chain(def_types); let Some(first) = all_types.next() else { diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 9f8af0f2951606..e8c1c0dd6e29fe 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -65,7 +65,6 @@ impl<'db> UnionBuilder<'db> { } } -#[allow(unused)] #[derive(Clone)] pub(crate) struct IntersectionBuilder<'db> { // Really this builds a union-of-intersections, because we always keep our set-theoretic types @@ -78,8 +77,7 @@ pub(crate) struct IntersectionBuilder<'db> { } impl<'db> IntersectionBuilder<'db> { - #[allow(dead_code)] - fn new(db: &'db dyn Db) -> Self { + pub(crate) fn new(db: &'db dyn Db) -> Self { Self { db, intersections: 
vec![InnerIntersectionBuilder::new()], @@ -93,8 +91,7 @@ impl<'db> IntersectionBuilder<'db> { } } - #[allow(dead_code)] - fn add_positive(mut self, ty: Type<'db>) -> Self { + pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self { if let Type::Union(union) = ty { // Distribute ourself over this union: for each union element, clone ourself and // intersect with that union element, then create a new union-of-intersections with all @@ -122,8 +119,7 @@ impl<'db> IntersectionBuilder<'db> { } } - #[allow(dead_code)] - fn add_negative(mut self, ty: Type<'db>) -> Self { + pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self { // See comments above in `add_positive`; this is just the negated version. if let Type::Union(union) = ty { union @@ -142,8 +138,7 @@ impl<'db> IntersectionBuilder<'db> { } } - #[allow(dead_code)] - fn build(mut self) -> Type<'db> { + pub(crate) fn build(mut self) -> Type<'db> { // Avoid allocating the UnionBuilder unnecessarily if we have just one intersection: if self.intersections.len() == 1 { self.intersections.pop().unwrap().build(self.db) @@ -157,7 +152,6 @@ impl<'db> IntersectionBuilder<'db> { } } -#[allow(unused)] #[derive(Debug, Clone, Default)] struct InnerIntersectionBuilder<'db> { positive: FxOrderSet>, @@ -218,6 +212,16 @@ impl<'db> InnerIntersectionBuilder<'db> { self.negative.clear(); self.positive.insert(Type::Never); } + + // None intersects only with object + for pos in &self.positive { + if let Type::Instance(_) = pos { + // could be `object` type + } else { + self.negative.remove(&Type::None); + break; + } + } } fn build(mut self, db: &'db dyn Db) -> Type<'db> { @@ -426,4 +430,15 @@ mod tests { assert_eq!(ty, Type::Never); } + + #[test] + fn build_intersection_simplify_negative_none() { + let db = setup_db(); + let ty = IntersectionBuilder::new(&db) + .add_negative(Type::None) + .add_positive(Type::IntLiteral(1)) + .build(); + + assert_eq!(ty, Type::IntLiteral(1)); + } } diff --git 
a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index ea39ee0725736c..6bafc85eda85d8 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2291,6 +2291,26 @@ mod tests { Ok(()) } + #[test] + fn narrow_not_none() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + x = None if flag else 1 + y = 0 + if x is not None: + y = x + ", + )?; + + assert_public_ty(&db, "/src/a.py", "x", "Literal[1] | None"); + assert_public_ty(&db, "/src/a.py", "y", "Literal[0, 1]"); + + Ok(()) + } + #[test] fn while_loop() -> anyhow::Result<()> { let mut db = setup_db(); @@ -2388,10 +2408,11 @@ mod tests { fn first_public_def<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> { let scope = global_scope(db, file); - *use_def_map(db, scope) + use_def_map(db, scope) .public_definitions(symbol_table(db, scope).symbol_id_by_name(name).unwrap()) - .first() + .next() .unwrap() + .definition } #[test] diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs new file mode 100644 index 00000000000000..8256c35ad45f5a --- /dev/null +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -0,0 +1,100 @@ +use crate::semantic_index::ast_ids::HasScopedAstId; +use crate::semantic_index::definition::Definition; +use crate::semantic_index::expression::Expression; +use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; +use crate::semantic_index::symbol_table; +use crate::types::{infer_expression_types, IntersectionBuilder, Type, TypeInference}; +use crate::Db; +use ruff_python_ast as ast; +use rustc_hash::FxHashMap; +use std::sync::Arc; + +/// Return type constraint, if any, on `definition` applied by `test`. 
+pub(crate) fn narrowing_constraint<'db>(
+    db: &'db dyn Db,
+    test: Expression<'db>,
+    definition: Definition<'db>,
+) -> Option<Type<'db>> {
+    all_narrowing_constraints(db, test)
+        .get(&definition.symbol(db))
+        .copied()
+}
+
+#[salsa::tracked]
+fn all_narrowing_constraints<'db>(
+    db: &'db dyn Db,
+    test: Expression<'db>,
+) -> NarrowingConstraints<'db> {
+    NarrowingConstraintsBuilder::new(db, test).finish()
+}
+
+type NarrowingConstraints<'db> = FxHashMap<ScopedSymbolId, Type<'db>>;
+
+struct NarrowingConstraintsBuilder<'db> {
+    db: &'db dyn Db,
+    expression: Expression<'db>,
+    constraints: NarrowingConstraints<'db>,
+}
+
+impl<'db> NarrowingConstraintsBuilder<'db> {
+    fn new(db: &'db dyn Db, expression: Expression<'db>) -> Self {
+        Self {
+            db,
+            expression,
+            constraints: NarrowingConstraints::default(),
+        }
+    }
+
+    fn finish(mut self) -> NarrowingConstraints<'db> {
+        if let ast::Expr::Compare(expr_compare) = self.expression.node(self.db).node() {
+            self.add_expr_compare(expr_compare);
+        }
+        // TODO other test expression kinds
+
+        self.constraints.shrink_to_fit();
+        self.constraints
+    }
+
+    fn symbols(&self) -> Arc<SymbolTable> {
+        symbol_table(self.db, self.scope())
+    }
+
+    fn scope(&self) -> ScopeId<'db> {
+        self.expression.scope(self.db)
+    }
+
+    fn inference(&self) -> &TypeInference<'db> {
+        infer_expression_types(self.db, self.expression)
+    }
+
+    fn add_expr_compare(&mut self, expr_compare: &ast::ExprCompare) {
+        let ast::ExprCompare {
+            range: _,
+            left,
+            ops,
+            comparators,
+        } = expr_compare;
+
+        if let ast::Expr::Name(ast::ExprName {
+            range: _,
+            id,
+            ctx: _,
+        }) = left.as_ref()
+        {
+            // SAFETY: we should always have a symbol for every Name node.
+ let symbol = self.symbols().symbol_id_by_name(id).unwrap(); + for (op, comparator) in std::iter::zip(ops, comparators) { + let comp_ty = self + .inference() + .expression_ty(comparator.scoped_ast_id(self.db, self.scope())); + if matches!(op, ast::CmpOp::IsNot) { + let ty = IntersectionBuilder::new(self.db) + .add_negative(comp_ty) + .build(); + self.constraints.insert(symbol, ty); + }; + // TODO other comparison types + } + } + } +}