diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_trio/TRIO115.py b/crates/ruff_linter/resources/test/fixtures/flake8_trio/TRIO115.py index d7466beb0f5d3..764b5c1d6e9f5 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_trio/TRIO115.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_trio/TRIO115.py @@ -29,8 +29,8 @@ async def func(): trio.sleep(e) # TRIO115 m_x, m_y = 0 - trio.sleep(m_y) # TRIO115 - trio.sleep(m_x) # TRIO115 + trio.sleep(m_y) # OK + trio.sleep(m_x) # OK m_a = m_b = 0 trio.sleep(m_a) # TRIO115 @@ -43,6 +43,8 @@ async def func(): def func(): + import trio + trio.run(trio.sleep(0)) # TRIO115 @@ -55,3 +57,10 @@ def func(): async def func(): await sleep(seconds=0) # TRIO115 + + +def func(): + import trio + + if (walrus := 0) == 0: + trio.sleep(walrus) # TRIO115 diff --git a/crates/ruff_linter/src/rules/flake8_trio/snapshots/ruff_linter__rules__flake8_trio__tests__TRIO115_TRIO115.py.snap b/crates/ruff_linter/src/rules/flake8_trio/snapshots/ruff_linter__rules__flake8_trio__tests__TRIO115_TRIO115.py.snap index 1ade9f757bbaa..7710be928504a 100644 --- a/crates/ruff_linter/src/rules/flake8_trio/snapshots/ruff_linter__rules__flake8_trio__tests__TRIO115_TRIO115.py.snap +++ b/crates/ruff_linter/src/rules/flake8_trio/snapshots/ruff_linter__rules__flake8_trio__tests__TRIO115_TRIO115.py.snap @@ -143,47 +143,7 @@ TRIO115.py:29:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.s 29 |+ trio.lowlevel.checkpoint() # TRIO115 30 30 | 31 31 | m_x, m_y = 0 -32 32 | trio.sleep(m_y) # TRIO115 - -TRIO115.py:32:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` - | -31 | m_x, m_y = 0 -32 | trio.sleep(m_y) # TRIO115 - | ^^^^^^^^^^^^^^^ TRIO115 -33 | trio.sleep(m_x) # TRIO115 - | - = help: Replace with `trio.lowlevel.checkpoint()` - -ℹ Safe fix -29 29 | trio.sleep(e) # TRIO115 -30 30 | -31 31 | m_x, m_y = 0 -32 |- trio.sleep(m_y) # TRIO115 - 32 |+ trio.lowlevel.checkpoint() # TRIO115 -33 33 | trio.sleep(m_x) # TRIO115 -34 34 | -35 35 | m_a = m_b = 0 - -TRIO115.py:33:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` - | -31 | m_x, m_y = 0 -32 | trio.sleep(m_y) # TRIO115 -33 | trio.sleep(m_x) # TRIO115 - | ^^^^^^^^^^^^^^^ TRIO115 -34 | -35 | m_a = m_b = 0 - | - = help: Replace with `trio.lowlevel.checkpoint()` - -ℹ Safe fix -30 30 | -31 31 | m_x, m_y = 0 -32 32 | trio.sleep(m_y) # TRIO115 -33 |- trio.sleep(m_x) # TRIO115 - 33 |+ trio.lowlevel.checkpoint() # TRIO115 -34 34 | -35 35 | m_a = m_b = 0 -36 36 | trio.sleep(m_a) # TRIO115 +32 32 | trio.sleep(m_y) # OK TRIO115.py:36:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` | @@ -195,7 +155,7 @@ TRIO115.py:36:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.s = help: Replace with `trio.lowlevel.checkpoint()` ℹ Safe fix -33 33 | trio.sleep(m_x) # TRIO115 +33 33 | trio.sleep(m_x) # OK 34 34 | 35 35 | m_a = m_b = 0 36 |- trio.sleep(m_a) # TRIO115 @@ -264,51 +224,88 @@ TRIO115.py:42:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.s 44 44 | 45 45 | def func(): -TRIO115.py:53:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` +TRIO115.py:48:14: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` | -52 | def func(): -53 | sleep(0) # TRIO115 - | ^^^^^^^^ TRIO115 +46 | import trio +47 | +48 | trio.run(trio.sleep(0)) # TRIO115 + | ^^^^^^^^^^^^^ TRIO115 | = help: Replace with `trio.lowlevel.checkpoint()` ℹ Safe fix -46 46 | trio.run(trio.sleep(0)) # TRIO115 +45 45 | def func(): +46 46 | import trio 47 47 | -48 48 | -49 |-from trio import Event, sleep - 49 |+from trio import Event, sleep, lowlevel +48 |- trio.run(trio.sleep(0)) # TRIO115 + 48 |+ trio.run(trio.lowlevel.checkpoint()) # TRIO115 +49 49 | +50 50 | +51 51 | from trio import Event, sleep + +TRIO115.py:55:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` + | +54 | def func(): +55 | sleep(0) # TRIO115 + | ^^^^^^^^ TRIO115 + | + = help: Replace with `trio.lowlevel.checkpoint()` + +ℹ Safe fix +48 48 | trio.run(trio.sleep(0)) # TRIO115 +49 49 | 50 50 | -51 51 | -52 52 | def func(): -53 |- sleep(0) # TRIO115 - 53 |+ lowlevel.checkpoint() # TRIO115 -54 54 | -55 55 | -56 56 | async def func(): +51 |-from trio import Event, sleep + 51 |+from trio import Event, sleep, lowlevel +52 52 | +53 53 | +54 54 | def func(): +55 |- sleep(0) # TRIO115 + 55 |+ lowlevel.checkpoint() # TRIO115 +56 56 | +57 57 | +58 58 | async def func(): -TRIO115.py:57:11: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` +TRIO115.py:59:11: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` | -56 | async def func(): -57 | await sleep(seconds=0) # TRIO115 +58 | async def func(): +59 | await sleep(seconds=0) # TRIO115 | ^^^^^^^^^^^^^^^^ TRIO115 | = help: Replace with `trio.lowlevel.checkpoint()` ℹ Safe fix -46 46 | trio.run(trio.sleep(0)) # TRIO115 -47 47 | -48 48 | -49 |-from trio import Event, sleep - 49 |+from trio import Event, sleep, lowlevel +48 48 | trio.run(trio.sleep(0)) # TRIO115 +49 49 | 50 50 | -51 51 | -52 52 | def func(): +51 |-from trio import Event, sleep + 51 |+from trio import Event, sleep, lowlevel +52 52 | +53 53 | +54 54 | def func(): -------------------------------------------------------------------------------- -54 54 | -55 55 | -56 56 | async def func(): -57 |- await sleep(seconds=0) # TRIO115 - 57 |+ await lowlevel.checkpoint() # TRIO115 +56 56 | +57 57 | +58 58 | async def func(): +59 |- await sleep(seconds=0) # TRIO115 + 59 |+ await lowlevel.checkpoint() # TRIO115 +60 60 | +61 61 | +62 62 | def func(): + +TRIO115.py:66:9: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` + | +65 | if (walrus := 0) == 0: +66 | trio.sleep(walrus) # TRIO115 + | ^^^^^^^^^^^^^^^^^^ TRIO115 + | + = help: Replace with `trio.lowlevel.checkpoint()` + +ℹ Safe fix +63 63 | import trio +64 64 | +65 65 | if (walrus := 0) == 0: +66 |- trio.sleep(walrus) # TRIO115 + 66 |+ trio.lowlevel.checkpoint() # TRIO115 diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs b/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs index f238198ec1e88..1fc4ade6fda9a 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs @@ -206,8 +206,8 @@ pub(crate) fn is_singledispatch_implementation( /// This requires more than just wrapping the reference itself in quotes. For example: /// - When quoting `Series` in `Series[pd.Timestamp]`, we want `"Series[pd.Timestamp]"`. /// - When quoting `kubernetes` in `kubernetes.SecurityContext`, we want `"kubernetes.SecurityContext"`. -/// - When quoting `Series` in `Series["pd.Timestamp"]`, we want `"Series[pd.Timestamp]"`. -/// - When quoting `Series` in `Series[Literal["pd.Timestamp"]]`, we want `"Series[Literal['pd.Timestamp']]"`. +/// - When quoting `Series` in `Series["pd.Timestamp"]`, we want `"Series[pd.Timestamp]"`. (This is currently unsupported.) +/// - When quoting `Series` in `Series[Literal["pd.Timestamp"]]`, we want `"Series[Literal['pd.Timestamp']]"`. (This is currently unsupported.) /// /// In general, when expanding a component of a call chain, we want to quote the entire call chain. pub(crate) fn quote_annotation( diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs b/crates/ruff_linter/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs index 0bec7f317009c..5eee0365e12fa 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs @@ -86,6 +86,16 @@ impl Violation for RuntimeImportInTypeCheckingBlock { } } +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +enum Action { + /// The import should be moved out of the type-checking block. + Move, + /// All usages of the import should be wrapped in quotes. + Quote, + /// The import should be ignored. + Ignore, +} + /// TCH004 pub(crate) fn runtime_import_in_type_checking_block( checker: &Checker, @@ -93,9 +103,7 @@ pub(crate) fn runtime_import_in_type_checking_block( diagnostics: &mut Vec, ) { // Collect all runtime imports by statement. - let mut moves_by_statement: FxHashMap> = FxHashMap::default(); - let mut quotes_by_statement: FxHashMap> = FxHashMap::default(); - let mut ignores_by_statement: FxHashMap> = FxHashMap::default(); + let mut actions: FxHashMap<(NodeId, Action), Vec> = FxHashMap::default(); for binding_id in scope.binding_ids() { let binding = checker.semantic().binding(binding_id); @@ -137,8 +145,8 @@ pub(crate) fn runtime_import_in_type_checking_block( ) }) { - ignores_by_statement - .entry(node_id) + actions + .entry((node_id, Action::Ignore)) .or_default() .push(import); } else { @@ -151,92 +159,104 @@ pub(crate) fn runtime_import_in_type_checking_block( || reference.in_runtime_evaluated_annotation() }) { - quotes_by_statement.entry(node_id).or_default().push(import); + actions + .entry((node_id, Action::Quote)) + .or_default() + .push(import); } else { - moves_by_statement.entry(node_id).or_default().push(import); + actions + .entry((node_id, Action::Move)) + .or_default() + .push(import); } } } } - // Generate a diagnostic for every import, but share a fix across all imports within the same - // statement (excluding those that are ignored). - for (node_id, imports) in moves_by_statement { - let fix = move_imports(checker, node_id, &imports).ok(); + for ((node_id, action), imports) in actions { + match action { + // Generate a diagnostic for every import, but share a fix across all imports within the same + // statement (excluding those that are ignored). + Action::Move => { + let fix = move_imports(checker, node_id, &imports).ok(); - for ImportBinding { - import, - range, - parent_range, - .. - } in imports - { - let mut diagnostic = Diagnostic::new( - RuntimeImportInTypeCheckingBlock { - qualified_name: import.qualified_name(), - strategy: Strategy::MoveImport, - }, - range, - ); - if let Some(range) = parent_range { - diagnostic.set_parent(range.start()); - } - if let Some(fix) = fix.as_ref() { - diagnostic.set_fix(fix.clone()); + for ImportBinding { + import, + range, + parent_range, + .. + } in imports + { + let mut diagnostic = Diagnostic::new( + RuntimeImportInTypeCheckingBlock { + qualified_name: import.qualified_name(), + strategy: Strategy::MoveImport, + }, + range, + ); + if let Some(range) = parent_range { + diagnostic.set_parent(range.start()); + } + if let Some(fix) = fix.as_ref() { + diagnostic.set_fix(fix.clone()); + } + diagnostics.push(diagnostic); + } } - diagnostics.push(diagnostic); - } - } - // Generate a diagnostic for every import, but share a fix across all imports within the same - // statement (excluding those that are ignored). - for (node_id, imports) in quotes_by_statement { - let fix = quote_imports(checker, node_id, &imports).ok(); + // Generate a diagnostic for every import, but share a fix across all imports within the same + // statement (excluding those that are ignored). + Action::Quote => { + let fix = quote_imports(checker, node_id, &imports).ok(); - for ImportBinding { - import, - range, - parent_range, - .. - } in imports - { - let mut diagnostic = Diagnostic::new( - RuntimeImportInTypeCheckingBlock { - qualified_name: import.qualified_name(), - strategy: Strategy::QuoteUsages, - }, - range, - ); - if let Some(range) = parent_range { - diagnostic.set_parent(range.start()); - } - if let Some(fix) = fix.as_ref() { - diagnostic.set_fix(fix.clone()); + for ImportBinding { + import, + range, + parent_range, + .. + } in imports + { + let mut diagnostic = Diagnostic::new( + RuntimeImportInTypeCheckingBlock { + qualified_name: import.qualified_name(), + strategy: Strategy::QuoteUsages, + }, + range, + ); + if let Some(range) = parent_range { + diagnostic.set_parent(range.start()); + } + if let Some(fix) = fix.as_ref() { + diagnostic.set_fix(fix.clone()); + } + diagnostics.push(diagnostic); + } } - diagnostics.push(diagnostic); - } - } - // Separately, generate a diagnostic for every _ignored_ import, to ensure that the - // suppression comments aren't marked as unused. - for ImportBinding { - import, - range, - parent_range, - .. - } in ignores_by_statement.into_values().flatten() - { - let mut diagnostic = Diagnostic::new( - RuntimeImportInTypeCheckingBlock { - qualified_name: import.qualified_name(), - strategy: Strategy::MoveImport, - }, - range, - ); - if let Some(range) = parent_range { - diagnostic.set_parent(range.start()); + // Separately, generate a diagnostic for every _ignored_ import, to ensure that the + // suppression comments aren't marked as unused. + Action::Ignore => { + for ImportBinding { + import, + range, + parent_range, + .. + } in imports + { + let mut diagnostic = Diagnostic::new( + RuntimeImportInTypeCheckingBlock { + qualified_name: import.qualified_name(), + strategy: Strategy::MoveImport, + }, + range, + ); + if let Some(range) = parent_range { + diagnostic.set_parent(range.start()); + } + diagnostics.push(diagnostic); + } + } } - diagnostics.push(diagnostic); } } diff --git a/crates/ruff_python_semantic/src/analyze/typing.rs b/crates/ruff_python_semantic/src/analyze/typing.rs index 4ff2e27e3221c..2dd7f1003e398 100644 --- a/crates/ruff_python_semantic/src/analyze/typing.rs +++ b/crates/ruff_python_semantic/src/analyze/typing.rs @@ -582,42 +582,64 @@ pub fn resolve_assignment<'a>( pub fn find_assigned_value<'a>(symbol: &str, semantic: &'a SemanticModel<'a>) -> Option<&'a Expr> { let binding_id = semantic.lookup_symbol(symbol)?; let binding = semantic.binding(binding_id); - if binding.kind.is_assignment() || binding.kind.is_named_expr_assignment() { - let parent_id = binding.source?; - let parent = semantic.statement(parent_id); - match parent { - Stmt::Assign(ast::StmtAssign { value, targets, .. }) => match value.as_ref() { - Expr::Tuple(ast::ExprTuple { elts, .. }) - | Expr::List(ast::ExprList { elts, .. }) => { + match binding.kind { + // Ex) `x := 1` + BindingKind::NamedExprAssignment => { + let parent_id = binding.source?; + let parent = semantic + .expressions(parent_id) + .find_map(|expr| expr.as_named_expr_expr()); + if let Some(ast::ExprNamedExpr { target, value, .. }) = parent { + return match_value(symbol, target.as_ref(), value.as_ref()); + } + } + // Ex) `x = 1` + BindingKind::Assignment => { + let parent_id = binding.source?; + let parent = semantic.statement(parent_id); + match parent { + Stmt::Assign(ast::StmtAssign { value, targets, .. }) => { if let Some(target) = targets.iter().find(|target| defines(symbol, target)) { - return match target { - Expr::Tuple(ast::ExprTuple { - elts: target_elts, .. - }) - | Expr::List(ast::ExprList { - elts: target_elts, .. - }) - | Expr::Set(ast::ExprSet { - elts: target_elts, .. - }) => get_value_by_id(symbol, target_elts, elts), - _ => Some(value.as_ref()), - }; + return match_value(symbol, target, value.as_ref()); } } - _ => return Some(value.as_ref()), - }, - Stmt::AnnAssign(ast::StmtAnnAssign { - value: Some(value), .. - }) => { - return Some(value.as_ref()); + Stmt::AnnAssign(ast::StmtAnnAssign { + value: Some(value), + target, + .. + }) => { + return match_value(symbol, target, value.as_ref()); + } + _ => {} } - Stmt::AugAssign(_) => return None, - _ => return None, } + _ => {} } None } +/// Given a target and value, find the value that's assigned to the given symbol. +fn match_value<'a>(symbol: &str, target: &Expr, value: &'a Expr) -> Option<&'a Expr> { + match target { + Expr::Name(ast::ExprName { id, .. }) if id.as_str() == symbol => Some(value), + Expr::Tuple(ast::ExprTuple { elts, .. }) | Expr::List(ast::ExprList { elts, .. }) => { + match value { + Expr::Tuple(ast::ExprTuple { + elts: value_elts, .. + }) + | Expr::List(ast::ExprList { + elts: value_elts, .. + }) + | Expr::Set(ast::ExprSet { + elts: value_elts, .. + }) => get_value_by_id(symbol, elts, value_elts), + _ => None, + } + } + _ => None, + } +} + /// Returns `true` if the [`Expr`] defines the symbol. fn defines(symbol: &str, expr: &Expr) -> bool { match expr { @@ -629,11 +651,7 @@ fn defines(symbol: &str, expr: &Expr) -> bool { } } -fn get_value_by_id<'a>( - target_id: &str, - targets: &'a [Expr], - values: &'a [Expr], -) -> Option<&'a Expr> { +fn get_value_by_id<'a>(target_id: &str, targets: &[Expr], values: &'a [Expr]) -> Option<&'a Expr> { for (target, value) in targets.iter().zip(values.iter()) { match target { Expr::Tuple(ast::ExprTuple { diff --git a/crates/ruff_python_semantic/src/model.rs b/crates/ruff_python_semantic/src/model.rs index dbcc021b71f22..d92dab84d0f5c 100644 --- a/crates/ruff_python_semantic/src/model.rs +++ b/crates/ruff_python_semantic/src/model.rs @@ -993,14 +993,6 @@ impl<'a> SemanticModel<'a> { &self.nodes[node_id] } - /// Return the [`Expr`] corresponding to the given [`NodeId`]. - #[inline] - pub fn expression(&self, node_id: NodeId) -> Option<&'a Expr> { - self.nodes - .ancestor_ids(node_id) - .find_map(|id| self.nodes[id].as_expression()) - } - /// Given a [`Expr`], return its parent, if any. #[inline] pub fn parent_expression(&self, node_id: NodeId) -> Option<&'a Expr> { @@ -1044,6 +1036,22 @@ impl<'a> SemanticModel<'a> { .nth(1) } + /// Return the [`Expr`] corresponding to the given [`NodeId`]. + #[inline] + pub fn expression(&self, node_id: NodeId) -> Option<&'a Expr> { + self.nodes + .ancestor_ids(node_id) + .find_map(|id| self.nodes[id].as_expression()) + } + + /// Returns an [`Iterator`] over the expressions, starting from the given [`NodeId`]. + /// through to any parents. + pub fn expressions(&self, node_id: NodeId) -> impl Iterator + '_ { + self.nodes + .ancestor_ids(node_id) + .filter_map(move |id| self.nodes[id].as_expression()) + } + /// Set the [`Globals`] for the current [`Scope`]. pub fn set_globals(&mut self, globals: Globals<'a>) { // If any global bindings don't already exist in the global scope, add them.