From b1fcabc2bb7820f1c4520c0089747f7a7c68ea99 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 3 Jun 2024 17:19:08 +0200 Subject: [PATCH] red-knot: Change `resolve_global_symbol` to take `Module` as an argument --- crates/red_knot/src/lint.rs | 18 ++++++++++++---- crates/red_knot/src/symbols.rs | 9 +++----- crates/red_knot/src/types.rs | 2 +- crates/red_knot/src/types/infer.rs | 34 ++++++++++++++++++------------ 4 files changed, 39 insertions(+), 24 deletions(-) diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index cad38f814d26d..6d9834ec03339 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -10,7 +10,7 @@ use ruff_python_parser::Parsed; use crate::cache::KeyValueCache; use crate::db::{LintDb, LintJar, QueryResult}; use crate::files::FileId; -use crate::module::ModuleName; +use crate::module::{resolve_module, ModuleName}; use crate::parse::parse; use crate::source::{source_text, Source}; use crate::symbols::{ @@ -145,9 +145,7 @@ fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> { // TODO we should have a special marker on the real typing module (from typeshed) so if you // have your own "typing" module in your project, we don't consider it THE typing module (and // same for other stdlib modules that our lint rules care about) - let Some(typing_override) = - resolve_global_symbol(context.db.upcast(), ModuleName::new("typing"), "override")? - else { + let Some(typing_override) = context.resolve_global_symbol("typing", "override")? else { // TODO once we bundle typeshed, this should be unreachable!() return Ok(()); }; @@ -235,6 +233,18 @@ impl<'a> SemanticLintContext<'a> { pub fn extend_diagnostics(&mut self, diagnostics: impl IntoIterator) { self.diagnostics.get_mut().extend(diagnostics); } + + pub fn resolve_global_symbol( + &self, + module: &str, + symbol_name: &str, + ) -> QueryResult> { + let Some(module) = resolve_module(self.db.upcast(), ModuleName::new(module))? else { + return Ok(None); + }; + + resolve_global_symbol(self.db.upcast(), module, symbol_name) + } } #[derive(Debug)] diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 4ac5c67694b57..84ee74c93975b 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -18,7 +18,7 @@ use crate::ast_ids::{NodeKey, TypedNodeKey}; use crate::cache::KeyValueCache; use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; -use crate::module::{resolve_module, ModuleName}; +use crate::module::{Module, ModuleName}; use crate::parse::parse; use crate::Name; @@ -35,13 +35,10 @@ pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult QueryResult> { - let Some(typing_module) = resolve_module(db, module)? else { - return Ok(None); - }; - let typing_file = typing_module.path(db)?.file(); + let typing_file = module.path(db)?.file(); let typing_table = symbol_table(db, typing_file)?; let Some(typing_override) = typing_table.root_symbol_id_by_name(name) else { return Ok(None); diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index 8628a8549ce35..539aaf0d92f4e 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -392,7 +392,7 @@ impl ModuleTypeId { } fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { - if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? { + if let Some(symbol_id) = resolve_global_symbol(db, self.module, name)? { Ok(Some(infer_symbol_public_type(db, symbol_id)?)) } else { Ok(None) diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index b5697ef14c246..8472f9e646f2f 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -96,7 +96,11 @@ pub fn infer_definition_type( // TODO relative imports assert!(matches!(level, 0)); let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); - if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? { + let Some(module) = resolve_module(db, module_name.clone())? else { + return Ok(Type::Unknown); + }; + + if let Some(remote_symbol) = resolve_global_symbol(db, module, &name)? { infer_symbol_public_type(db, remote_symbol) } else { Ok(Type::Unknown) @@ -248,30 +252,34 @@ mod tests { Ok(TestCase { temp_dir, db, src }) } - fn write_to_path(case: &TestCase, relpath: &str, contents: &str) -> anyhow::Result<()> { - let path = case.src.path().join(relpath); + fn write_to_path(case: &TestCase, relative_path: &str, contents: &str) -> anyhow::Result<()> { + let path = case.src.path().join(relative_path); std::fs::write(path, contents)?; Ok(()) } - fn get_public_type(case: &TestCase, modname: &str, varname: &str) -> anyhow::Result { + fn get_public_type( + case: &TestCase, + module_name: &str, + variable_name: &str, + ) -> anyhow::Result { let db = &case.db; - let symbol = - resolve_global_symbol(db, ModuleName::new(modname), varname)?.expect("symbol to exist"); + let module = resolve_module(db, ModuleName::new(module_name))?.expect("Module to exist"); + let symbol = resolve_global_symbol(db, module, variable_name)?.expect("symbol to exist"); Ok(infer_symbol_public_type(db, symbol)?) } fn assert_public_type( case: &TestCase, - modname: &str, - varname: &str, - tyname: &str, + module_name: &str, + variable_name: &str, + type_name: &str, ) -> anyhow::Result<()> { - let ty = get_public_type(case, modname, varname)?; + let ty = get_public_type(case, module_name, variable_name)?; let jar = HasJar::::jar(&case.db)?; - assert_eq!(format!("{}", ty.display(&jar.type_store)), tyname); + assert_eq!(format!("{}", ty.display(&jar.type_store)), type_name); Ok(()) } @@ -399,8 +407,8 @@ mod tests { .expect("module should be found") .path(db)? .file(); - let syms = symbol_table(db, file)?; - let x_sym = syms + let symbols = symbol_table(db, file)?; + let x_sym = symbols .root_symbol_id_by_name("x") .expect("x symbol should be found");