diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 4a65ecfa15bc6f..875d09ca3e2030 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -934,7 +934,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_import_definition(&mut self, alias: &ast::Alias, definition: Definition<'db>) { + fn infer_import_definition(&mut self, alias: &'db ast::Alias, definition: Definition<'db>) { let ast::Alias { range: _, name, @@ -945,11 +945,7 @@ impl<'db> TypeInferenceBuilder<'db> { if let Some(module) = self.module_ty_from_name(module_name) { module } else { - self.add_diagnostic( - AnyNodeRef::Alias(alias), - "unresolved-import", - format_args!("Cannot resolve import '{name}'."), - ); + self.unresolved_module_diagnostic(alias, 0, Some(name)); Type::Unknown } } else { @@ -994,18 +990,18 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_optional_expression(cause.as_deref()); } - fn unresolved_module_name( + fn unresolved_module_diagnostic( &mut self, - import_from: &ast::StmtImportFrom, - level: usize, + import_node: impl Into>, + level: u32, module: Option<&str>, ) { self.add_diagnostic( - AnyNodeRef::StmtImportFrom(import_from), + import_node.into(), "unresolved-import", format_args!( "Cannot resolve import '{}{}'.", - ".".repeat(level), + ".".repeat(level as usize), module.unwrap_or_default() ), ); @@ -1022,46 +1018,32 @@ impl<'db> TypeInferenceBuilder<'db> { /// - `from .foo import bar` => `tail == "foo"` /// - `from ..foo.bar import baz` => `tail == "foo.bar"` fn relative_module_name( - &mut self, + &self, tail: Option<&str>, level: NonZeroU32, - import_from: &ast::StmtImportFrom, - ) -> Option { + ) -> Result { + let module = file_to_module(self.db, self.file) + .ok_or(ModuleNameResolutionError::UnknownCurrentModule)?; let mut level = level.get(); - let Some(module) = file_to_module(self.db, self.file) else { - tracing::debug!( - "Relative module resolution '{}' failed; could not resolve file '{}' to a module", - format_import_from_module(level, tail), - self.file.path(self.db) - ); - self.unresolved_module_name(import_from, level as usize, tail); - return None; - }; if module.kind().is_package() { level -= 1; } let mut module_name = module.name().to_owned(); for _ in 0..level { - let Some(parent) = module_name.parent() else { - self.unresolved_module_name(import_from, level as usize, tail); - return None; - }; - module_name = parent; + module_name = module_name + .parent() + .ok_or(ModuleNameResolutionError::TooManyDots)?; } if let Some(tail) = tail { - if let Some(valid_tail) = ModuleName::new(tail) { - module_name.extend(&valid_tail); - } else { - tracing::debug!("Relative module resolution failed: invalid syntax"); - return None; - } + let tail = ModuleName::new(tail).ok_or(ModuleNameResolutionError::InvalidSyntax)?; + module_name.extend(&tail); } - Some(module_name) + Ok(module_name) } fn infer_import_from_definition( &mut self, - import_from: &ast::StmtImportFrom, + import_from: &'db ast::StmtImportFrom, alias: &ast::Alias, definition: Definition<'db>, ) { @@ -1077,6 +1059,7 @@ impl<'db> TypeInferenceBuilder<'db> { let ast::StmtImportFrom { module, level, .. } = import_from; tracing::trace!("Resolving imported object {alias:?} from statement {import_from:?}"); let module = module.as_deref(); + let module_name = if let Some(level) = NonZeroU32::new(*level) { tracing::trace!( "Resolving imported object '{}' from module '{}' relative to file '{}'", @@ -1084,26 +1067,49 @@ impl<'db> TypeInferenceBuilder<'db> { format_import_from_module(level.get(), module), self.file.path(self.db), ); - self.relative_module_name(module, level, import_from) + self.relative_module_name(module, level) } else { tracing::trace!( "Resolving imported object '{}' from module '{}'", alias.name, format_import_from_module(*level, module), ); - module.and_then(ModuleName::new) + module + .and_then(ModuleName::new) + .ok_or(ModuleNameResolutionError::InvalidSyntax) }; - let module_ty = if let Some(module_name) = module_name { - if let Some(module_ty) = self.module_ty_from_name(module_name) { - module_ty - } else { - self.unresolved_module_name(import_from, *level as usize, module); + let module_ty = match module_name { + Ok(name) => { + if let Some(ty) = self.module_ty_from_name(name) { + ty + } else { + self.unresolved_module_diagnostic(import_from, *level, module); + Type::Unknown + } + } + Err(ModuleNameResolutionError::InvalidSyntax) => { + tracing::debug!("Failed to resolve import due to invalid syntax"); + // Invalid syntax diagnostics are emitted elsewhere. + Type::Unknown + } + Err(ModuleNameResolutionError::TooManyDots) => { + tracing::trace!( + "Relative module resolution '{}' failed: too many leading dots", + format_import_from_module(*level, module), + ); + self.unresolved_module_diagnostic(import_from, *level, module); + Type::Unknown + } + Err(ModuleNameResolutionError::UnknownCurrentModule) => { + tracing::debug!( + "Relative module resolution '{}' failed; could not resolve file '{}' to a module", + format_import_from_module(*level, module), + self.file.path(self.db) + ); + self.unresolved_module_diagnostic(import_from, *level, module); Type::Unknown } - } else { - tracing::debug!("Failed to resolve import due to invalid syntax"); - Type::Unknown }; let ast::Alias { @@ -1926,6 +1932,23 @@ fn format_import_from_module(level: u32, module: Option<&str>) -> String { ) } +/// Various ways in which resolving a [`ModuleName`] +/// from an [`ast::StmtImport`] or [`ast::StmtImportFrom`] node might fail +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum ModuleNameResolutionError { + /// The import statement has invalid syntax + InvalidSyntax, + + /// We couldn't resolve the file we're currently analyzing back to a module + /// (Only necessary for relative import statements) + UnknownCurrentModule, + + /// The relative import statement seems to take us outside of the module search path + /// (e.g. our current module is `foo.bar`, and the relative import statement in `foo.bar` + /// is `from ....baz import spam`) + TooManyDots, +} + #[cfg(test)] mod tests { use anyhow::Context; diff --git a/crates/ruff_benchmark/benches/red_knot.rs b/crates/ruff_benchmark/benches/red_knot.rs index 94a7a5edaaf03a..6996c42fadca78 100644 --- a/crates/ruff_benchmark/benches/red_knot.rs +++ b/crates/ruff_benchmark/benches/red_knot.rs @@ -21,7 +21,7 @@ struct Case { const TOMLLIB_312_URL: &str = "https://raw.githubusercontent.com/python/cpython/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib"; -// This first "unresolved import" is because we don't understand `*` imports yet. +// The "unresolved import" is because we don't understand `*` imports yet. static EXPECTED_DIAGNOSTICS: &[&str] = &[ "/src/tomllib/_parser.py:7:29: Module 'collections.abc' has no member 'Iterable'", "Line 69 is too long (89 characters)",