Skip to content

Commit

Permalink
[red-knot] Add symbol and definition for parameters (#12862)
Browse files Browse the repository at this point in the history
## Summary

This PR adds support for adding symbols and definitions for function and
lambda parameters to the semantic index.

### Notes

* The default expression of a parameter is evaluated in the enclosing
scope (not the type parameter or function scope).
* The annotation expression of a parameter is evaluated in the type
parameter scope if they're present other in the enclosing scope.
* The symbols and definitions are added in the function parameter scope.

### Type Inference

There are two definitions `Parameter` and `ParameterWithDefault` and
their respective `*_definition` methods on the type inference builder.
These methods are preferred and are re-used when checking from a
different region.

## Test Plan

Add test case for validating that the parameters are defined in the
function / lambda scope.

### Benchmark update

Validated the difference in diagnostics for benchmark code between
`main` and this branch. All of them are either directly or indirectly
referencing one of the function parameters. The diff is in the PR description.
  • Loading branch information
dhruvmanila committed Aug 16, 2024
1 parent f121f8b commit bd4a947
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 5 deletions.
97 changes: 97 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,103 @@ y = 2
));
}

#[test]
fn function_parameter_symbols() {
let TestCase { db, file } = test_case(
"
def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
pass
",
);

let index = semantic_index(&db, file);
let global_table = symbol_table(&db, global_scope(&db, file));

assert_eq!(names(&global_table), vec!["f", "str", "int"]);

let [(function_scope_id, _function_scope)] = index
.child_scopes(FileScopeId::global())
.collect::<Vec<_>>()[..]
else {
panic!("Expected a function scope")
};

let function_table = index.symbol_table(function_scope_id);
assert_eq!(
names(&function_table),
vec!["a", "b", "c", "args", "d", "kwargs"],
);

let use_def = index.use_def_map(function_scope_id);
for name in ["a", "b", "c", "d"] {
let [definition] = use_def.public_definitions(
function_table
.symbol_id_by_name(name)
.expect("symbol exists"),
) else {
panic!("Expected parameter definition for {name}");
};
assert!(matches!(
definition.node(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let [definition] = use_def.public_definitions(
function_table
.symbol_id_by_name(name)
.expect("symbol exists"),
) else {
panic!("Expected parameter definition for {name}");
};
assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_)));
}
}

#[test]
fn lambda_parameter_symbols() {
let TestCase { db, file } = test_case("lambda a, b, c=1, *args, d=2, **kwargs: None");

let index = semantic_index(&db, file);
let global_table = symbol_table(&db, global_scope(&db, file));

assert!(names(&global_table).is_empty());

let [(lambda_scope_id, _lambda_scope)] = index
.child_scopes(FileScopeId::global())
.collect::<Vec<_>>()[..]
else {
panic!("Expected a lambda scope")
};

let lambda_table = index.symbol_table(lambda_scope_id);
assert_eq!(
names(&lambda_table),
vec!["a", "b", "c", "args", "d", "kwargs"],
);

let use_def = index.use_def_map(lambda_scope_id);
for name in ["a", "b", "c", "d"] {
let [definition] = use_def
.public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
else {
panic!("Expected parameter definition for {name}");
};
assert!(matches!(
definition.node(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let [definition] = use_def
.public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
else {
panic!("Expected parameter definition for {name}");
};
assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_)));
}
}

/// Test case to validate that the comprehension scope is correctly identified and that the target
/// variable is defined only in the comprehension scope and not in the global scope.
#[test]
Expand Down
48 changes: 48 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,16 @@ where
.add_or_update_symbol(function_def.name.id.clone(), SymbolFlags::IS_DEFINED);
self.add_definition(symbol, function_def);

// The default value of the parameters needs to be evaluated in the
// enclosing scope.
for default in function_def
.parameters
.iter_non_variadic_params()
.filter_map(|param| param.default.as_deref())
{
self.visit_expr(default);
}

self.with_type_params(
NodeWithScopeRef::FunctionTypeParameters(function_def),
function_def.type_params.as_deref(),
Expand All @@ -378,6 +388,16 @@ where
}

builder.push_scope(NodeWithScopeRef::Function(function_def));

// Add symbols and definitions for the parameters to the function scope.
for parameter in &*function_def.parameters {
let symbol = builder.add_or_update_symbol(
parameter.name().id().clone(),
SymbolFlags::IS_DEFINED,
);
builder.add_definition(symbol, parameter);
}

builder.visit_body(&function_def.body);
builder.pop_scope()
},
Expand Down Expand Up @@ -574,9 +594,29 @@ where
}
ast::Expr::Lambda(lambda) => {
if let Some(parameters) = &lambda.parameters {
// The default value of the parameters needs to be evaluated in the
// enclosing scope.
for default in parameters
.iter_non_variadic_params()
.filter_map(|param| param.default.as_deref())
{
self.visit_expr(default);
}
self.visit_parameters(parameters);
}
self.push_scope(NodeWithScopeRef::Lambda(lambda));

// Add symbols and definitions for the parameters to the lambda scope.
if let Some(parameters) = &lambda.parameters {
for parameter in &**parameters {
let symbol = self.add_or_update_symbol(
parameter.name().id().clone(),
SymbolFlags::IS_DEFINED,
);
self.add_definition(symbol, parameter);
}
}

self.visit_expr(lambda.body.as_ref());
}
ast::Expr::If(ast::ExprIf {
Expand Down Expand Up @@ -654,6 +694,14 @@ where
self.pop_scope();
}
}

fn visit_parameters(&mut self, parameters: &'ast ruff_python_ast::Parameters) {
// Intentionally avoid walking default expressions, as we handle them in the enclosing
// scope.
for parameter in parameters.iter().map(ast::AnyParameterRef::as_parameter) {
self.visit_parameter(parameter);
}
}
}

#[derive(Copy, Clone, Debug)]
Expand Down
33 changes: 33 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
Assignment(AssignmentDefinitionNodeRef<'a>),
AnnotatedAssignment(&'a ast::StmtAnnAssign),
Comprehension(ComprehensionDefinitionNodeRef<'a>),
Parameter(ast::AnyParameterRef<'a>),
}

impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
Expand Down Expand Up @@ -95,6 +96,12 @@ impl<'a> From<ComprehensionDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
}
}

impl<'a> From<ast::AnyParameterRef<'a>> for DefinitionNodeRef<'a> {
fn from(node: ast::AnyParameterRef<'a>) -> Self {
Self::Parameter(node)
}
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtImportFrom,
Expand Down Expand Up @@ -150,6 +157,14 @@ impl DefinitionNodeRef<'_> {
first,
})
}
DefinitionNodeRef::Parameter(parameter) => match parameter {
ast::AnyParameterRef::Variadic(parameter) => {
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
}
ast::AnyParameterRef::NonVariadic(parameter) => {
DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter))
}
},
}
}

Expand All @@ -168,6 +183,10 @@ impl DefinitionNodeRef<'_> {
}) => target.into(),
Self::AnnotatedAssignment(node) => node.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(),
Self::Parameter(node) => match node {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
},
}
}
}
Expand All @@ -182,6 +201,8 @@ pub enum DefinitionKind {
Assignment(AssignmentDefinitionKind),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
Comprehension(ComprehensionDefinitionKind),
Parameter(AstNodeRef<ast::Parameter>),
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -273,3 +294,15 @@ impl From<&ast::Comprehension> for DefinitionNodeKey {
Self(NodeKey::from_node(node))
}
}

impl From<&ast::Parameter> for DefinitionNodeKey {
fn from(node: &ast::Parameter) -> Self {
Self(NodeKey::from_node(node))
}
}

impl From<&ast::ParameterWithDefault> for DefinitionNodeKey {
fn from(node: &ast::ParameterWithDefault) -> Self {
Self(NodeKey::from_node(node))
}
}
50 changes: 47 additions & 3 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ impl<'db> TypeInferenceBuilder<'db> {
definition,
);
}
DefinitionKind::Parameter(parameter) => {
self.infer_parameter_definition(parameter, definition);
}
DefinitionKind::ParameterWithDefault(parameter_with_default) => {
self.infer_parameter_with_default_definition(parameter_with_default, definition);
}
}
}

Expand Down Expand Up @@ -421,6 +427,13 @@ impl<'db> TypeInferenceBuilder<'db> {
.map(|decorator| self.infer_decorator(decorator))
.collect();

for default in parameters
.iter_non_variadic_params()
.filter_map(|param| param.default.as_deref())
{
self.infer_expression(default);
}

// If there are type params, parameters and returns are evaluated in that scope.
if type_params.is_none() {
self.infer_parameters(parameters);
Expand Down Expand Up @@ -458,10 +471,12 @@ impl<'db> TypeInferenceBuilder<'db> {
let ast::ParameterWithDefault {
range: _,
parameter,
default,
default: _,
} = parameter_with_default;
self.infer_parameter(parameter);
self.infer_optional_expression(default.as_deref());

self.infer_optional_expression(parameter.annotation.as_deref());

self.infer_definition(parameter_with_default);
}

fn infer_parameter(&mut self, parameter: &ast::Parameter) {
Expand All @@ -470,7 +485,29 @@ impl<'db> TypeInferenceBuilder<'db> {
name: _,
annotation,
} = parameter;

self.infer_optional_expression(annotation.as_deref());

self.infer_definition(parameter);
}

fn infer_parameter_with_default_definition(
&mut self,
_parameter_with_default: &ast::ParameterWithDefault,
definition: Definition<'db>,
) {
// TODO(dhruvmanila): Infer types from annotation or default expression
self.types.definitions.insert(definition, Type::Unknown);
}

fn infer_parameter_definition(
&mut self,
_parameter: &ast::Parameter,
definition: Definition<'db>,
) {
// TODO(dhruvmanila): Annotation expression is resolved at the enclosing scope, infer the
// parameter type from there
self.types.definitions.insert(definition, Type::Unknown);
}

fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) {
Expand Down Expand Up @@ -1277,6 +1314,13 @@ impl<'db> TypeInferenceBuilder<'db> {
} = lambda_expression;

if let Some(parameters) = parameters {
for default in parameters
.iter_non_variadic_params()
.filter_map(|param| param.default.as_deref())
{
self.infer_expression(default);
}

self.infer_parameters(parameters);
}

Expand Down
4 changes: 2 additions & 2 deletions crates/ruff_benchmark/benches/red_knot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn benchmark_incremental(criterion: &mut Criterion) {
let Case { db, parser, .. } = case;
let result = db.check_file(*parser).unwrap();

assert_eq!(result.len(), 402);
assert_eq!(result.len(), 111);
},
BatchSize::SmallInput,
);
Expand All @@ -104,7 +104,7 @@ fn benchmark_cold(criterion: &mut Criterion) {
let Case { db, parser, .. } = case;
let result = db.check_file(*parser).unwrap();

assert_eq!(result.len(), 402);
assert_eq!(result.len(), 111);
},
BatchSize::SmallInput,
);
Expand Down

0 comments on commit bd4a947

Please sign in to comment.