Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] infer_symbol_public_type infers union of all definitions #11667

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions crates/red_knot/src/lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::source::{source_text, Source};
use crate::symbols::{
resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, SymbolId, SymbolTable,
};
use crate::types::{infer_definition_type, infer_symbol_type, Type};
use crate::types::{infer_definition_type, infer_symbol_public_type, Type};

#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn lint_syntax(db: &dyn LintDb, file_id: FileId) -> QueryResult<Diagnostics> {
Expand Down Expand Up @@ -104,14 +104,14 @@ fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> {
for (symbol, definition) in context.symbols().all_definitions() {
match definition {
Definition::Import(import) => {
let ty = context.infer_symbol_type(symbol)?;
let ty = context.infer_symbol_public_type(symbol)?;

if ty.is_unknown() {
context.push_diagnostic(format!("Unresolved module {}", import.module));
}
}
Definition::ImportFrom(import) => {
let ty = context.infer_symbol_type(symbol)?;
let ty = context.infer_symbol_public_type(symbol)?;

if ty.is_unknown() {
let module_name = import.module().map(Deref::deref).unwrap_or_default();
Expand Down Expand Up @@ -217,8 +217,8 @@ impl<'a> SemanticLintContext<'a> {
&self.symbols
}

pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_type(
pub fn infer_symbol_public_type(&self, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_public_type(
self.db.upcast(),
GlobalSymbolId {
file_id: self.file_id,
Expand Down
20 changes: 10 additions & 10 deletions crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap;

pub(crate) mod infer;

pub(crate) use infer::{infer_definition_type, infer_symbol_type};
pub(crate) use infer::{infer_definition_type, infer_symbol_public_type};

/// unique ID for a type
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -119,7 +119,7 @@ impl TypeStore {
self.modules.remove(&file_id);
}

pub fn cache_symbol_type(&self, symbol: GlobalSymbolId, ty: Type) {
pub fn cache_symbol_public_type(&self, symbol: GlobalSymbolId, ty: Type) {
self.add_or_get_module(symbol.file_id)
.symbol_types
.insert(symbol.symbol_id, ty);
Expand All @@ -131,7 +131,7 @@ impl TypeStore {
.insert(node_key, ty);
}

pub fn get_cached_symbol_type(&self, symbol: GlobalSymbolId) -> Option<Type> {
pub fn get_cached_symbol_public_type(&self, symbol: GlobalSymbolId) -> Option<Type> {
self.try_get_module(symbol.file_id)?
.symbol_types
.get(&symbol.symbol_id)
Expand Down Expand Up @@ -182,12 +182,12 @@ impl TypeStore {
.add_class(name, scope_id, bases)
}

fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
self.add_or_get_module(file_id).add_union(elems)
}

fn add_intersection(
&mut self,
&self,
file_id: FileId,
positive: &[Type],
negative: &[Type],
Expand Down Expand Up @@ -393,7 +393,7 @@ impl ModuleTypeId {

fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult<Option<Type>> {
if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? {
Ok(Some(infer_symbol_type(db, symbol_id)?))
Ok(Some(infer_symbol_public_type(db, symbol_id)?))
} else {
Ok(None)
}
Expand Down Expand Up @@ -441,7 +441,7 @@ impl ClassTypeId {
let ClassType { scope_id, .. } = *self.class(db)?;
let table = symbol_table(db, self.file_id)?;
if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) {
Ok(Some(infer_symbol_type(
Ok(Some(infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: self.file_id,
Expand Down Expand Up @@ -497,7 +497,7 @@ struct ModuleTypeStore {
unions: IndexVec<ModuleUnionTypeId, UnionType>,
/// arena of all intersection types created in this module
intersections: IndexVec<ModuleIntersectionTypeId, IntersectionType>,
/// cached types of symbols in this module
/// cached public types of symbols in this module
symbol_types: FxHashMap<SymbolId, Type>,
/// cached types of AST nodes in this module
node_types: FxHashMap<NodeKey, Type>,
Expand Down Expand Up @@ -777,7 +777,7 @@ mod tests {

#[test]
fn add_union() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
Expand All @@ -794,7 +794,7 @@ mod tests {

#[test]
fn add_intersection() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
Expand Down
73 changes: 58 additions & 15 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,30 @@ use crate::{FileId, Name};

// FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`.
#[tracing::instrument(level = "trace", skip(db))]
pub fn infer_symbol_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult<Type> {
pub fn infer_symbol_public_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult<Type> {
let symbols = symbol_table(db, symbol.file_id)?;
let defs = symbols.definitions(symbol.symbol_id);
let jar: &SemanticJar = db.jar()?;

if let Some(ty) = jar.type_store.get_cached_symbol_type(symbol) {
if let Some(ty) = jar.type_store.get_cached_symbol_public_type(symbol) {
return Ok(ty);
}

// TODO handle multiple defs, conditional defs...
assert_eq!(defs.len(), 1);

let ty = infer_definition_type(db, symbol, defs[0].clone())?;
// The public type of a symbol is the union of all of its definitions. This is the most
// cautious/sound approach, though it can lead to a broader-than-desired type in a case like
// `x = 1; x = str(x)`, where the first definition of `x` can never be visible. TODO prune
// definitions that we can prove can't be visible.
let tys = defs
.iter()
.map(|def| infer_definition_type(db, symbol, def.clone()))
.collect::<QueryResult<Vec<Type>>>()?;
let ty = match tys.len() {
0 => Type::Unknown,
1 => tys[0],
_ => Type::Union(jar.type_store.add_union(symbol.file_id, &tys)),
};

jar.type_store.cache_symbol_type(symbol, ty);
jar.type_store.cache_symbol_public_type(symbol, ty);

// TODO record dependencies
Ok(ty)
Expand Down Expand Up @@ -65,7 +74,7 @@ pub fn infer_definition_type(
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? {
infer_symbol_type(db, remote_symbol)
infer_symbol_public_type(db, remote_symbol)
} else {
Ok(Type::Unknown)
}
Expand Down Expand Up @@ -158,7 +167,8 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu
ast::Expr::Name(name) => {
// TODO look up in the correct scope, don't assume global
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
infer_symbol_type(db, GlobalSymbolId { file_id, symbol_id })
// TODO should use only reachable definitions, not public type
infer_symbol_public_type(db, GlobalSymbolId { file_id, symbol_id })
} else {
Ok(Type::Unknown)
}
Expand All @@ -182,7 +192,7 @@ mod tests {
resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind,
};
use crate::symbols::{symbol_table, GlobalSymbolId};
use crate::types::{infer_symbol_type, Type};
use crate::types::{infer_symbol_public_type, Type};
use crate::Name;

// TODO with virtual filesystem we shouldn't have to write files to disk for these
Expand Down Expand Up @@ -228,7 +238,7 @@ mod tests {
.root_symbol_id_by_name("E")
.expect("E symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: a_file,
Expand Down Expand Up @@ -259,7 +269,7 @@ mod tests {
.root_symbol_id_by_name("Sub")
.expect("Sub symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
Expand Down Expand Up @@ -300,7 +310,7 @@ mod tests {
.root_symbol_id_by_name("C")
.expect("C symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
Expand Down Expand Up @@ -345,7 +355,7 @@ mod tests {
.root_symbol_id_by_name("D")
.expect("D symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: a_file,
Expand Down Expand Up @@ -375,7 +385,7 @@ mod tests {
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
Expand All @@ -388,4 +398,37 @@ mod tests {
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[1]");
Ok(())
}

#[test]
fn resolve_union() -> anyhow::Result<()> {
let case = create_test()?;
let db = &case.db;

let path = case.src.path().join("a.py");
std::fs::write(path, "if flag:\n x = 1\nelse:\n x = 2")?;
let file = resolve_module(db, ModuleName::new("a"))?
.expect("module should be found")
.path(db)?
.file();
let syms = symbol_table(db, file)?;
let x_sym = syms
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
symbol_id: x_sym,
},
)?;

let jar = HasJar::<SemanticJar>::jar(db)?;
assert!(matches!(ty, Type::Union(_)));
assert_eq!(
format!("{}", ty.display(&jar.type_store)),
"(Literal[1] | Literal[2])"
);
Ok(())
}
}
Loading