Skip to content

Commit

Permalink
[pylint] Correct ordering of arguments in fix for if-stmt-min-max
Browse files Browse the repository at this point in the history
… (`PLR1730`) (#16080)

The PR addresses the issue #16040 .

---

The logic used into the rule is the following:

Suppose to have an expression of the form 

```python
if a cmp b:
    c = d
```
where `a`,` b`, `c` and `d` are Python obj and `cmp` one of `<`, `>`,
`<=`, `>=`.

Then:

- `if a=c and b=d`
    
    - if `<=` fix with `a = max(b, a)`
    - if `>=`  fix with `a = min(b, a)`
    - if `>` fix with `a = min(a, b)`
    - if `<` fix with `a = max(a, b)`

- `if a=d and b=c`

    - if `<=` fix with `b = min(a, b)`
    - if `>=`  fix with `b = max(a, b)`
    - if `>` fix with `b = max(b, a)`
    - if `<` fix with `b = min(b, a)`
 
- do nothing, i.e., we cannot fix this case.

---

In total we have 8 different and possible cases.

```

| Case  | Expression       | Fix           |
|-------|------------------|---------------|
| 1     | if a >= b: a = b | a = min(b, a) |
| 2     | if a <= b: a = b | a = max(b, a) |
| 3     | if a <= b: b = a | b = min(a, b) |
| 4     | if a >= b: b = a | b = max(a, b) |
| 5     | if a > b: a = b  | a = min(a, b) |
| 6     | if a < b: a = b  | a = max(a, b) |
| 7     | if a < b: b = a  | b = min(b, a) |
| 8     | if a > b: b = a  | b = max(b, a) |
```

I added them in the tests. 

Please double-check that I didn't make any mistakes. It's quite easy to
mix up > and <.

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
  • Loading branch information
VascoSch92 and MichaReiser authored Feb 12, 2025
1 parent 366ae1f commit ae1b381
Show file tree
Hide file tree
Showing 3 changed files with 745 additions and 415 deletions.
Original file line number Diff line number Diff line change
@@ -1,39 +1,98 @@
# pylint: disable=missing-docstring, invalid-name, too-few-public-methods, redefined-outer-name


# the rule take care of the following cases:
#
# | Case | Expression | Fix |
# |-------|------------------|---------------|
# | 1 | if a >= b: a = b | a = min(b, a) |
# | 2 | if a <= b: a = b | a = max(b, a) |
# | 3 | if a <= b: b = a | b = min(a, b) |
# | 4 | if a >= b: b = a | b = max(a, b) |
# | 5 | if a > b: a = b | a = min(a, b) |
# | 6 | if a < b: a = b | a = max(a, b) |
# | 7 | if a < b: b = a | b = min(b, a) |
# | 8 | if a > b: b = a | b = max(b, a) |

# the 8 base cases
a, b = [], []

# case 1: a = min(b, a)
if a >= b:
a = b

# case 2: a = max(b, a)
if a <= b:
a = b

# case 3: b = min(a, b)
if a <= b:
b = a

# case 4: b = max(a, b)
if a >= b:
b = a

# case 5: a = min(a, b)
if a > b:
a = b

# case 6: a = max(a, b)
if a < b:
a = b

# case 7: b = min(b, a)
if a < b:
b = a

# case 8: b = max(b, a)
if a > b:
b = a


# test cases with assigned variables and primitives
value = 10
value2 = 0
value3 = 3

# Positive
if value < 10: # [max-instead-of-if]
# base case 6: value = max(value, 10)
if value < 10:
value = 10

if value <= 10: # [max-instead-of-if]
# base case 2: value = max(10, value)
if value <= 10:
value = 10

if value < value2: # [max-instead-of-if]
# base case 6: value = max(value, value2)
if value < value2:
value = value2

if value > 10: # [min-instead-of-if]
# base case 5: value = min(value, 10)
if value > 10:
value = 10

if value >= 10: # [min-instead-of-if]
# base case 1: value = min(10, value)
if value >= 10:
value = 10

if value > value2: # [min-instead-of-if]
# base case 5: value = min(value, value2)
if value > value2:
value = value2


# cases with calls
class A:
def __init__(self):
self.value = 13


A1 = A()
if A1.value < 10: # [max-instead-of-if]


if A1.value < 10:
A1.value = 10

if A1.value > 10: # [min-instead-of-if]
if A1.value > 10:
A1.value = 10


Expand Down Expand Up @@ -159,3 +218,22 @@ def foo(self, value) -> None:
self._min = value
if self._max >= value:
self._max = value


counter = {"a": 0, "b": 0}

# base case 2: counter["a"] = max(counter["b"], counter["a"])
if counter["a"] <= counter["b"]:
counter["a"] = counter["b"]

# case 3: counter["b"] = min(counter["a"], counter["b"])
if counter["a"] <= counter["b"]:
counter["b"] = counter["a"]

# case 5: counter["a"] = min(counter["a"], counter["b"])
if counter["a"] > counter["b"]:
counter["b"] = counter["a"]

# case 8: counter["a"] = max(counter["b"], counter["a"])
if counter["a"] > counter["b"]:
counter["b"] = counter["a"]
61 changes: 29 additions & 32 deletions crates/ruff_linter/src/rules/pylint/rules/if_stmt_min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,47 +106,44 @@ pub(crate) fn if_stmt_min_max(checker: &Checker, stmt_if: &ast::StmtIf) {
let [op] = &**ops else {
return;
};

let [right] = &**comparators else {
return;
};

let left_cmp = ComparableExpr::from(left);
let body_target_cmp = ComparableExpr::from(body_target);
let right_cmp = ComparableExpr::from(right);
let body_value_cmp = ComparableExpr::from(body_value);

let left_is_target = left_cmp == body_target_cmp;
let right_is_target = right_cmp == body_target_cmp;
let left_is_value = left_cmp == body_value_cmp;
let right_is_value = right_cmp == body_value_cmp;

let min_max = match (
left_is_target,
right_is_target,
left_is_value,
right_is_value,
) {
(true, false, false, true) => match op {
CmpOp::Lt | CmpOp::LtE => MinMax::Max,
CmpOp::Gt | CmpOp::GtE => MinMax::Min,
// extract helpful info from expression of the form
// `if cmp_left op cmp_right: target = assignment_value`
let cmp_left = ComparableExpr::from(left);
let cmp_right = ComparableExpr::from(right);
let target = ComparableExpr::from(body_target);
let assignment_value = ComparableExpr::from(body_value);

// Ex): if a < b: a = b
let (min_max, flip_args) = if cmp_left == target && cmp_right == assignment_value {
match op {
CmpOp::Lt => (MinMax::Max, false),
CmpOp::LtE => (MinMax::Max, true),
CmpOp::Gt => (MinMax::Min, false),
CmpOp::GtE => (MinMax::Min, true),
_ => return,
},
(false, true, true, false) => match op {
CmpOp::Lt | CmpOp::LtE => MinMax::Min,
CmpOp::Gt | CmpOp::GtE => MinMax::Max,
}
}
// Ex): `if a < b: b = a`
else if cmp_left == assignment_value && cmp_right == target {
match op {
CmpOp::Lt => (MinMax::Min, true),
CmpOp::LtE => (MinMax::Min, false),
CmpOp::Gt => (MinMax::Max, true),
CmpOp::GtE => (MinMax::Max, false),
_ => return,
},
_ => return,
}
} else {
return;
};

// Determine whether to use `min()` or `max()`, and make sure that the first
// arg of the `min()` or `max()` method is equal to the target of the comparison.
// This is to be consistent with the Python implementation of the methods `min()` and `max()`.
let (arg1, arg2) = if left_is_target {
(&**left, right)
} else {
let (arg1, arg2) = if flip_args {
(right, &**left)
} else {
(&**left, right)
};

let replacement = format!(
Expand Down
Loading

0 comments on commit ae1b381

Please sign in to comment.