From dd60a3865c3067f22e692104249a941fee8431e9 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 4 Jul 2023 22:11:29 -0400 Subject: [PATCH] Avoid triggering `unnecessary-map` (`C417`) for late-bound lambdas (#5520) Closes https://github.com/astral-sh/ruff/issues/5502. --- .../fixtures/flake8_comprehensions/C417.py | 7 +- .../rules/unnecessary_map.rs | 100 +++++++++++++++++- ...8_comprehensions__tests__C417_C417.py.snap | 18 ++-- 3 files changed, 110 insertions(+), 15 deletions(-) diff --git a/crates/ruff/resources/test/fixtures/flake8_comprehensions/C417.py b/crates/ruff/resources/test/fixtures/flake8_comprehensions/C417.py index 5cf7548cc3632..7077acfad6bc5 100644 --- a/crates/ruff/resources/test/fixtures/flake8_comprehensions/C417.py +++ b/crates/ruff/resources/test/fixtures/flake8_comprehensions/C417.py @@ -32,5 +32,8 @@ def func(arg1: int, arg2: int = 4): # Non-error: `func` is not a lambda. list(map(func, nums)) -# False positive: need to preserve the late-binding of `x`. -callbacks = map(lambda x: lambda: x, range(4)) +# False positive: need to preserve the late-binding of `x` in the inner lambda. +map(lambda x: lambda: x, range(4)) + +# Error: the `x` is overridden by the inner lambda. +map(lambda x: lambda x: x, range(4)) diff --git a/crates/ruff/src/rules/flake8_comprehensions/rules/unnecessary_map.rs b/crates/ruff/src/rules/flake8_comprehensions/rules/unnecessary_map.rs index 4f29021ff99df..0505d08028bdc 100644 --- a/crates/ruff/src/rules/flake8_comprehensions/rules/unnecessary_map.rs +++ b/crates/ruff/src/rules/flake8_comprehensions/rules/unnecessary_map.rs @@ -1,10 +1,13 @@ use std::fmt; -use rustpython_parser::ast::{self, Expr, Ranged}; +use rustpython_parser::ast::{self, Arguments, Expr, ExprContext, Ranged, Stmt}; use ruff_diagnostics::{AutofixKind, Violation}; use ruff_diagnostics::{Diagnostic, Fix}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::helpers::includes_arg_name; +use ruff_python_ast::visitor; +use ruff_python_ast::visitor::Visitor; use crate::checkers::ast::Checker; use crate::registry::AsRule; @@ -95,7 +98,11 @@ pub(crate) fn unnecessary_map( }; // Only flag, e.g., `map(lambda x: x + 1, iterable)`. - if !matches!(args, [Expr::Lambda(_), _]) { + let [Expr::Lambda(ast::ExprLambda { args, body, .. }), _] = args else { + return; + }; + + if late_binding(args, body) { return; } } @@ -114,7 +121,11 @@ pub(crate) fn unnecessary_map( return; }; - if !argument.is_lambda_expr() { + let Expr::Lambda(ast::ExprLambda { args, body, .. }) = argument else { + return; + }; + + if late_binding(args, body) { return; } } @@ -129,7 +140,7 @@ pub(crate) fn unnecessary_map( return; }; - let Expr::Lambda(ast::ExprLambda { body, .. }) = argument else { + let Expr::Lambda(ast::ExprLambda { args, body, .. }) = argument else { return; }; @@ -142,6 +153,10 @@ pub(crate) fn unnecessary_map( if elts.len() != 2 { return; } + + if late_binding(args, body) { + return; + } } } @@ -173,3 +188,80 @@ impl fmt::Display for ObjectType { } } } + +/// Returns `true` if the lambda defined by the given arguments and body contains any names that +/// are late-bound within nested lambdas. +/// +/// For example, given: +/// +/// ```python +/// map(lambda x: lambda: x, range(4)) # (0, 1, 2, 3) +/// ``` +/// +/// The `x` in the inner lambda is "late-bound". Specifically, rewriting the above as: +/// +/// ```python +/// (lambda: x for x in range(4)) # (3, 3, 3, 3) +/// ``` +/// +/// Would yield an incorrect result, as the `x` in the inner lambda would be bound to the last +/// value of `x` in the comprehension. +fn late_binding(args: &Arguments, body: &Expr) -> bool { + let mut visitor = LateBindingVisitor::new(args); + visitor.visit_expr(body); + visitor.late_bound +} + +#[derive(Debug)] +struct LateBindingVisitor<'a> { + /// The arguments to the current lambda. + args: &'a Arguments, + /// The arguments to any lambdas within the current lambda body. + lambdas: Vec<&'a Arguments>, + /// Whether any names within the current lambda body are late-bound within nested lambdas. + late_bound: bool, +} + +impl<'a> LateBindingVisitor<'a> { + fn new(args: &'a Arguments) -> Self { + Self { + args, + lambdas: Vec::new(), + late_bound: false, + } + } +} + +impl<'a> Visitor<'a> for LateBindingVisitor<'a> { + fn visit_stmt(&mut self, _stmt: &'a Stmt) {} + + fn visit_expr(&mut self, expr: &'a Expr) { + match expr { + Expr::Lambda(ast::ExprLambda { args, .. }) => { + self.lambdas.push(args); + visitor::walk_expr(self, expr); + self.lambdas.pop(); + } + Expr::Name(ast::ExprName { + id, + ctx: ExprContext::Load, + .. + }) => { + // If we're within a nested lambda... + if !self.lambdas.is_empty() { + // If the name is defined in the current lambda... + if includes_arg_name(id, self.args) { + // And isn't overridden by any nested lambdas... + if !self.lambdas.iter().any(|args| includes_arg_name(id, args)) { + // Then it's late-bound. + self.late_bound = true; + } + } + } + } + _ => visitor::walk_expr(self, expr), + } + } + + fn visit_body(&mut self, _body: &'a [Stmt]) {} +} diff --git a/crates/ruff/src/rules/flake8_comprehensions/snapshots/ruff__rules__flake8_comprehensions__tests__C417_C417.py.snap b/crates/ruff/src/rules/flake8_comprehensions/snapshots/ruff__rules__flake8_comprehensions__tests__C417_C417.py.snap index 318af2d9d28ff..acb6700d3fbed 100644 --- a/crates/ruff/src/rules/flake8_comprehensions/snapshots/ruff__rules__flake8_comprehensions__tests__C417_C417.py.snap +++ b/crates/ruff/src/rules/flake8_comprehensions/snapshots/ruff__rules__flake8_comprehensions__tests__C417_C417.py.snap @@ -260,19 +260,19 @@ C417.py:21:1: C417 Unnecessary `map` usage (rewrite using a generator expression | = help: Replace `map` with a generator expression -C417.py:36:13: C417 [*] Unnecessary `map` usage (rewrite using a generator expression) +C417.py:39:1: C417 [*] Unnecessary `map` usage (rewrite using a generator expression) | -35 | # False positive: need to preserve the late-binding of `x`. -36 | callbacks = map(lambda x: lambda: x, range(4)) - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ C417 +38 | # Error: the `x` is overridden by the inner lambda. +39 | map(lambda x: lambda x: x, range(4)) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ C417 | = help: Replace `map` with a generator expression ℹ Suggested fix -33 33 | list(map(func, nums)) -34 34 | -35 35 | # False positive: need to preserve the late-binding of `x`. -36 |-callbacks = map(lambda x: lambda: x, range(4)) - 36 |+callbacks = (lambda: x for x in range(4)) +36 36 | map(lambda x: lambda: x, range(4)) +37 37 | +38 38 | # Error: the `x` is overridden by the inner lambda. +39 |-map(lambda x: lambda x: x, range(4)) + 39 |+(lambda x: x for x in range(4))