Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

red-knot: Change resolve_global_symbol to take Module as an argument #11723

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions crates/red_knot/src/lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use ruff_python_parser::Parsed;
use crate::cache::KeyValueCache;
use crate::db::{LintDb, LintJar, QueryResult};
use crate::files::FileId;
use crate::module::ModuleName;
use crate::module::{resolve_module, ModuleName};
use crate::parse::parse;
use crate::source::{source_text, Source};
use crate::symbols::{
Expand Down Expand Up @@ -145,9 +145,7 @@ fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> {
// TODO we should have a special marker on the real typing module (from typeshed) so if you
// have your own "typing" module in your project, we don't consider it THE typing module (and
// same for other stdlib modules that our lint rules care about)
let Some(typing_override) =
resolve_global_symbol(context.db.upcast(), ModuleName::new("typing"), "override")?
else {
let Some(typing_override) = context.resolve_global_symbol("typing", "override")? else {
// TODO once we bundle typeshed, this should be unreachable!()
return Ok(());
};
Expand Down Expand Up @@ -235,6 +233,18 @@ impl<'a> SemanticLintContext<'a> {
pub fn extend_diagnostics(&mut self, diagnostics: impl IntoIterator<Item = String>) {
self.diagnostics.get_mut().extend(diagnostics);
}

pub fn resolve_global_symbol(
&self,
module: &str,
symbol_name: &str,
) -> QueryResult<Option<GlobalSymbolId>> {
let Some(module) = resolve_module(self.db.upcast(), ModuleName::new(module))? else {
return Ok(None);
};

resolve_global_symbol(self.db.upcast(), module, symbol_name)
}
}

#[derive(Debug)]
Expand Down
9 changes: 3 additions & 6 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::ast_ids::{NodeKey, TypedNodeKey};
use crate::cache::KeyValueCache;
use crate::db::{QueryResult, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::module::{resolve_module, ModuleName};
use crate::module::{Module, ModuleName};
use crate::parse::parse;
use crate::Name;

Expand All @@ -35,13 +35,10 @@ pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult<Arc<Sym
#[tracing::instrument(level = "debug", skip(db))]
pub fn resolve_global_symbol(
db: &dyn SemanticDb,
module: ModuleName,
module: Module,
name: &str,
) -> QueryResult<Option<GlobalSymbolId>> {
let Some(typing_module) = resolve_module(db, module)? else {
return Ok(None);
};
let typing_file = typing_module.path(db)?.file();
let typing_file = module.path(db)?.file();
let typing_table = symbol_table(db, typing_file)?;
let Some(typing_override) = typing_table.root_symbol_id_by_name(name) else {
return Ok(None);
Expand Down
2 changes: 1 addition & 1 deletion crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ impl ModuleTypeId {
}

fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult<Option<Type>> {
if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? {
if let Some(symbol_id) = resolve_global_symbol(db, self.module, name)? {
Ok(Some(infer_symbol_public_type(db, symbol_id)?))
} else {
Ok(None)
Expand Down
34 changes: 21 additions & 13 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ pub fn infer_definition_type(
// TODO relative imports
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? {
let Some(module) = resolve_module(db, module_name.clone())? else {
return Ok(Type::Unknown);
};

if let Some(remote_symbol) = resolve_global_symbol(db, module, &name)? {
infer_symbol_public_type(db, remote_symbol)
} else {
Ok(Type::Unknown)
Expand Down Expand Up @@ -248,30 +252,34 @@ mod tests {
Ok(TestCase { temp_dir, db, src })
}

fn write_to_path(case: &TestCase, relpath: &str, contents: &str) -> anyhow::Result<()> {
let path = case.src.path().join(relpath);
fn write_to_path(case: &TestCase, relative_path: &str, contents: &str) -> anyhow::Result<()> {
let path = case.src.path().join(relative_path);
std::fs::write(path, contents)?;
Ok(())
}

fn get_public_type(case: &TestCase, modname: &str, varname: &str) -> anyhow::Result<Type> {
fn get_public_type(
case: &TestCase,
module_name: &str,
variable_name: &str,
) -> anyhow::Result<Type> {
let db = &case.db;
let symbol =
resolve_global_symbol(db, ModuleName::new(modname), varname)?.expect("symbol to exist");
let module = resolve_module(db, ModuleName::new(module_name))?.expect("Module to exist");
let symbol = resolve_global_symbol(db, module, variable_name)?.expect("symbol to exist");

Ok(infer_symbol_public_type(db, symbol)?)
}

fn assert_public_type(
case: &TestCase,
modname: &str,
varname: &str,
tyname: &str,
module_name: &str,
variable_name: &str,
type_name: &str,
) -> anyhow::Result<()> {
let ty = get_public_type(case, modname, varname)?;
let ty = get_public_type(case, module_name, variable_name)?;

let jar = HasJar::<SemanticJar>::jar(&case.db)?;
assert_eq!(format!("{}", ty.display(&jar.type_store)), tyname);
assert_eq!(format!("{}", ty.display(&jar.type_store)), type_name);
Ok(())
}

Expand Down Expand Up @@ -399,8 +407,8 @@ mod tests {
.expect("module should be found")
.path(db)?
.file();
let syms = symbol_table(db, file)?;
let x_sym = syms
let symbols = symbol_table(db, file)?;
let x_sym = symbols
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

Expand Down
Loading