From 1de8ff3308ed7dfbcc384f6d0062b6f7d04b30d0 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Wed, 17 Jul 2024 12:03:36 -0400 Subject: [PATCH] Detect enumerate iterations in `loop-iterator-mutation` (#12366) ## Summary Closes https://github.com/astral-sh/ruff/issues/12164. --- .../test/fixtures/flake8_bugbear/B909.py | 12 ++++ .../rules/loop_iterator_mutation.rs | 64 +++++++++++++++---- ...__flake8_bugbear__tests__B909_B909.py.snap | 10 +++ 3 files changed, 74 insertions(+), 12 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_bugbear/B909.py b/crates/ruff_linter/resources/test/fixtures/flake8_bugbear/B909.py index 68afaf87fb257..1a9d76ecf8e77 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_bugbear/B909.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_bugbear/B909.py @@ -158,3 +158,15 @@ def __init__(self, ls): some_list[elem] = 1 some_list.remove(elem) some_list.discard(elem) + +# should error +for i, elem in enumerate(some_list): + some_list.pop(0) + +# should not error (list) +for i, elem in enumerate(some_list): + some_list[i] = 1 + +# should not error (dict) +for i, elem in enumerate(some_list): + some_list[elem] = 1 diff --git a/crates/ruff_linter/src/rules/flake8_bugbear/rules/loop_iterator_mutation.rs b/crates/ruff_linter/src/rules/flake8_bugbear/rules/loop_iterator_mutation.rs index d48466813acb7..210152c3bc4ac 100644 --- a/crates/ruff_linter/src/rules/flake8_bugbear/rules/loop_iterator_mutation.rs +++ b/crates/ruff_linter/src/rules/flake8_bugbear/rules/loop_iterator_mutation.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use ruff_diagnostics::Diagnostic; use ruff_diagnostics::Violation; use ruff_macros::{derive_message_formats, violation}; @@ -7,10 +5,12 @@ use ruff_python_ast::comparable::ComparableExpr; use ruff_python_ast::name::UnqualifiedName; use ruff_python_ast::{ visitor::{self, Visitor}, - Arguments, Expr, ExprAttribute, ExprCall, ExprSubscript, Stmt, StmtAssign, StmtAugAssign, - StmtBreak, StmtDelete, StmtFor, StmtIf, + Arguments, Expr, ExprAttribute, ExprCall, ExprSubscript, ExprTuple, Stmt, StmtAssign, + StmtAugAssign, StmtBreak, StmtDelete, StmtFor, StmtIf, }; use ruff_text_size::TextRange; +use std::collections::HashMap; +use std::fmt::Debug; use crate::checkers::ast::Checker; use crate::fix::snippet::SourceCodeSnippet; @@ -64,13 +64,44 @@ pub(crate) fn loop_iterator_mutation(checker: &mut Checker, stmt_for: &StmtFor) range: _, } = stmt_for; - if !matches!(iter.as_ref(), Expr::Name(_) | Expr::Attribute(_)) { - return; - } + let (index, target, iter) = match iter.as_ref() { + Expr::Name(_) | Expr::Attribute(_) => { + // Ex) Given, `for item in items:`, `item` is the index and `items` is the iterable. + (&**target, &**target, &**iter) + } + Expr::Call(ExprCall { + func, arguments, .. + }) => { + // Ex) Given `for i, item in enumerate(items):`, `i` is the index and `items` is the + // iterable. + if checker.semantic().match_builtin_expr(func, "enumerate") { + // Ex) `items` + let Some(iter) = arguments.args.first() else { + return; + }; + + let Expr::Tuple(ExprTuple { elts, .. }) = &**target else { + return; + }; + + let [index, target] = elts.as_slice() else { + return; + }; + + // Ex) `i` + (index, target, iter) + } else { + return; + } + } + _ => { + return; + } + }; // Collect mutations to the iterable. let mutations = { - let mut visitor = LoopMutationsVisitor::new(iter, target); + let mut visitor = LoopMutationsVisitor::new(iter, target, index); visitor.visit_body(body); visitor.mutations }; @@ -114,6 +145,7 @@ fn is_mutating_function(function_name: &str) -> bool { struct LoopMutationsVisitor<'a> { iter: &'a Expr, target: &'a Expr, + index: &'a Expr, mutations: HashMap>, branches: Vec, branch: u32, @@ -121,10 +153,11 @@ struct LoopMutationsVisitor<'a> { impl<'a> LoopMutationsVisitor<'a> { /// Initialize the visitor. - fn new(iter: &'a Expr, target: &'a Expr) -> Self { + fn new(iter: &'a Expr, target: &'a Expr, index: &'a Expr) -> Self { Self { iter, target, + index, mutations: HashMap::new(), branches: vec![0], branch: 0, @@ -149,7 +182,9 @@ impl<'a> LoopMutationsVisitor<'a> { // Find, e.g., `del items[0]`. if ComparableExpr::from(self.iter) == ComparableExpr::from(value) { // But allow, e.g., `for item in items: del items[item]`. - if ComparableExpr::from(self.target) != ComparableExpr::from(slice) { + if ComparableExpr::from(self.index) != ComparableExpr::from(slice) + && ComparableExpr::from(self.target) != ComparableExpr::from(slice) + { self.add_mutation(range); } } @@ -170,7 +205,9 @@ impl<'a> LoopMutationsVisitor<'a> { // Find, e.g., `items[0] = 1`. if ComparableExpr::from(self.iter) == ComparableExpr::from(value) { // But allow, e.g., `for item in items: items[item] = 1`. - if ComparableExpr::from(self.target) != ComparableExpr::from(slice) { + if ComparableExpr::from(self.index) != ComparableExpr::from(slice) + && ComparableExpr::from(self.target) != ComparableExpr::from(slice) + { self.add_mutation(range); } } @@ -201,7 +238,10 @@ impl<'a> LoopMutationsVisitor<'a> { if matches!(attr.as_str(), "remove" | "discard" | "pop") { if arguments.len() == 1 { if let [arg] = &*arguments.args { - if ComparableExpr::from(self.target) == ComparableExpr::from(arg) { + if ComparableExpr::from(self.index) == ComparableExpr::from(arg) + || ComparableExpr::from(self.target) + == ComparableExpr::from(arg) + { return; } } diff --git a/crates/ruff_linter/src/rules/flake8_bugbear/snapshots/ruff_linter__rules__flake8_bugbear__tests__B909_B909.py.snap b/crates/ruff_linter/src/rules/flake8_bugbear/snapshots/ruff_linter__rules__flake8_bugbear__tests__B909_B909.py.snap index 7f70841c6c066..a0fadcf86520f 100644 --- a/crates/ruff_linter/src/rules/flake8_bugbear/snapshots/ruff_linter__rules__flake8_bugbear__tests__B909_B909.py.snap +++ b/crates/ruff_linter/src/rules/flake8_bugbear/snapshots/ruff_linter__rules__flake8_bugbear__tests__B909_B909.py.snap @@ -339,3 +339,13 @@ B909.py:150:8: B909 Mutation to loop iterable `some_list` during iteration 151 | pass 152 | else: | + +B909.py:164:5: B909 Mutation to loop iterable `some_list` during iteration + | +162 | # should error +163 | for i, elem in enumerate(some_list): +164 | some_list.pop(0) + | ^^^^^^^^^^^^^ B909 +165 | +166 | # should not error (list) + |