Skip to content

Commit

Permalink
Add support for match statements
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Feb 21, 2023
1 parent cdc4e86 commit 8516b81
Show file tree
Hide file tree
Showing 31 changed files with 589 additions and 25 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ libcst = { git = "https://github.com/charliermarsh/LibCST", rev = "f2f0b7a487a87
once_cell = { version = "1.16.0" }
regex = { version = "1.6.0" }
rustc-hash = { version = "1.1.0" }
rustpython-common = { git = "https://github.com/RustPython/RustPython.git", rev = "ef873b4b606f0a58e3640b6186416631fdeead26" }
rustpython-parser = { features = ["lalrpop"], git = "https://github.com/RustPython/RustPython.git", rev = "ef873b4b606f0a58e3640b6186416631fdeead26" }
rustpython-common = { git = "https://github.com/RustPython/RustPython.git", rev = "ddf497623ae56d21aa4166ff1c0725a7db67e955" }
rustpython-parser = { features = ["lalrpop"], git = "https://github.com/RustPython/RustPython.git", rev = "ddf497623ae56d21aa4166ff1c0725a7db67e955" }
schemars = { version = "0.8.11" }
serde = { version = "1.0.147", features = ["derive"] }
serde_json = { version = "1.0.87" }
Expand Down
22 changes: 22 additions & 0 deletions crates/ruff/resources/test/fixtures/flake8_bugbear/B012.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,25 @@ def k():
pass
finally:
break # warning


while True:
try:
pass
finally:
match *0, 1, *2:
case 0,:
y = 0
case 0, *x:
break # warning


while True:
try:
pass
finally:
match *0, 1, *2:
case 0,:
y = 0
case 0, *x:
pass # no warning
8 changes: 8 additions & 0 deletions crates/ruff/resources/test/fixtures/flake8_bugbear/B904.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,11 @@ def context_switch():
raise RuntimeError("boom!")
else:
raise RuntimeError("bang!")


try:
...
except Exception as e:
match 0:
case 0:
raise RuntimeError("boom!")
9 changes: 9 additions & 0 deletions crates/ruff/resources/test/fixtures/flake8_return/RET503.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,12 @@ def while_true():
if y > 0:
return 1
y += 1


# match
def x(y):
match y:
case 0:
return 1
case 1:
print() # error
3 changes: 3 additions & 0 deletions crates/ruff/resources/test/fixtures/pycodestyle/E70.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,6 @@ def f(): ...
class C: ...; x = 1
#: E701:1:8 E702:1:13
class C: ...; ...
#: E701:2:12
match *0, 1, *2:
case 0,: y = 0
7 changes: 7 additions & 0 deletions crates/ruff/resources/test/fixtures/pyflakes/F401_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ def b(self) -> None:


CustomInt: TypeAlias = "np.int8 | np.int16"


# Test: match statements.
match *0, 1, *2:
case 0,:
import x
import y
6 changes: 3 additions & 3 deletions crates/ruff/resources/test/fixtures/pyflakes/F811_20.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Test that shadowing a global with a class attribute does not produce a
warning.
"""
Test that shadowing a global with a class attribute does not produce a
warning.
"""

import fu

Expand Down
27 changes: 26 additions & 1 deletion crates/ruff/resources/test/fixtures/pyflakes/F841_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def f():
def connect():
return None, None

with (connect() as (connection, cursor)):
with connect() as (connection, cursor):
cursor.execute("SELECT * FROM users")


Expand All @@ -94,3 +94,28 @@ def f():
(exponential := (exponential * base_multiplier) % 3): i + 1 for i in range(2)
}
return hash_map


def f(x: int):
msg1 = "Hello, world!"
msg2 = "Hello, world!"
msg3 = "Hello, world!"
match x:
case 1:
print(msg1)
case 2:
print(msg2)


def f(x: int):
import enum

Foo = enum.Enum("Foo", "A B")
Bar = enum.Enum("Bar", "A B")
Baz = enum.Enum("Baz", "A B")

match x:
case (Foo.A):
print("A")
case [Bar.A, *_]:
print("A")
117 changes: 114 additions & 3 deletions crates/ruff/src/ast/comparable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
use num_bigint::BigInt;
use rustpython_parser::ast::{
Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, Excepthandler,
ExcepthandlerKind, Expr, ExprContext, ExprKind, Keyword, Operator, Stmt, StmtKind, Unaryop,
Withitem,
ExcepthandlerKind, Expr, ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern,
PatternKind, Stmt, StmtKind, Unaryop, Withitem,
};

#[derive(Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -157,6 +157,110 @@ impl<'a> From<&'a Withitem> for ComparableWithitem<'a> {
}
}

#[allow(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum ComparablePattern<'a> {
MatchValue {
value: ComparableExpr<'a>,
},
MatchSingleton {
value: ComparableConstant<'a>,
},
MatchSequence {
patterns: Vec<ComparablePattern<'a>>,
},
MatchMapping {
keys: Vec<ComparableExpr<'a>>,
patterns: Vec<ComparablePattern<'a>>,
rest: Option<&'a str>,
},
MatchClass {
cls: ComparableExpr<'a>,
patterns: Vec<ComparablePattern<'a>>,
kwd_attrs: Vec<&'a str>,
kwd_patterns: Vec<ComparablePattern<'a>>,
},
MatchStar {
name: Option<&'a str>,
},
MatchAs {
pattern: Option<Box<ComparablePattern<'a>>>,
name: Option<&'a str>,
},
MatchOr {
patterns: Vec<ComparablePattern<'a>>,
},
}

impl<'a> From<&'a Pattern> for ComparablePattern<'a> {
fn from(pattern: &'a Pattern) -> Self {
match &pattern.node {
PatternKind::MatchValue { value } => Self::MatchValue {
value: value.into(),
},
PatternKind::MatchSingleton { value } => Self::MatchSingleton {
value: value.into(),
},
PatternKind::MatchSequence { patterns } => Self::MatchSequence {
patterns: patterns.iter().map(Into::into).collect(),
},
PatternKind::MatchMapping {
keys,
patterns,
rest,
} => Self::MatchMapping {
keys: keys.iter().map(Into::into).collect(),
patterns: patterns.iter().map(Into::into).collect(),
rest: rest.as_deref(),
},
PatternKind::MatchClass {
cls,
patterns,
kwd_attrs,
kwd_patterns,
} => Self::MatchClass {
cls: cls.into(),
patterns: patterns.iter().map(Into::into).collect(),
kwd_attrs: kwd_attrs.iter().map(String::as_str).collect(),
kwd_patterns: kwd_patterns.iter().map(Into::into).collect(),
},
PatternKind::MatchStar { name } => Self::MatchStar {
name: name.as_deref(),
},
PatternKind::MatchAs { pattern, name } => Self::MatchAs {
pattern: pattern.as_ref().map(Into::into),
name: name.as_deref(),
},
PatternKind::MatchOr { patterns } => Self::MatchOr {
patterns: patterns.iter().map(Into::into).collect(),
},
}
}
}

impl<'a> From<&'a Box<Pattern>> for Box<ComparablePattern<'a>> {
fn from(pattern: &'a Box<Pattern>) -> Self {
Box::new((&**pattern).into())
}
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ComparableMatchCase<'a> {
pub pattern: ComparablePattern<'a>,
pub guard: Option<ComparableExpr<'a>>,
pub body: Vec<ComparableStmt<'a>>,
}

impl<'a> From<&'a MatchCase> for ComparableMatchCase<'a> {
fn from(match_case: &'a MatchCase) -> Self {
Self {
pattern: (&match_case.pattern).into(),
guard: match_case.guard.as_ref().map(Into::into),
body: match_case.body.iter().map(Into::into).collect(),
}
}
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub enum ComparableConstant<'a> {
None,
Expand Down Expand Up @@ -644,6 +748,10 @@ pub enum ComparableStmt<'a> {
body: Vec<ComparableStmt<'a>>,
type_comment: Option<&'a str>,
},
Match {
subject: ComparableExpr<'a>,
cases: Vec<ComparableMatchCase<'a>>,
},
Raise {
exc: Option<ComparableExpr<'a>>,
cause: Option<ComparableExpr<'a>>,
Expand Down Expand Up @@ -817,7 +925,10 @@ impl<'a> From<&'a Stmt> for ComparableStmt<'a> {
body: body.iter().map(Into::into).collect(),
type_comment: type_comment.as_ref().map(String::as_str),
},
StmtKind::Match { .. } => unreachable!("StmtKind::Match is not supported"),
StmtKind::Match { subject, cases } => Self::Match {
subject: subject.into(),
cases: cases.iter().map(Into::into).collect(),
},
StmtKind::Raise { exc, cause } => Self::Raise {
exc: exc.as_ref().map(Into::into),
cause: cause.as_ref().map(Into::into),
Expand Down
59 changes: 56 additions & 3 deletions crates/ruff/src/ast/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use regex::Regex;
use rustc_hash::{FxHashMap, FxHashSet};
use rustpython_parser::ast::{
Arguments, Constant, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Keyword, KeywordData,
Located, Location, Stmt, StmtKind,
Located, Location, MatchCase, Pattern, PatternKind, Stmt, StmtKind,
};
use rustpython_parser::lexer;
use rustpython_parser::lexer::Tok;
Expand Down Expand Up @@ -249,6 +249,46 @@ where
}
}

pub fn any_over_pattern<F>(pattern: &Pattern, func: &F) -> bool
where
F: Fn(&Expr) -> bool,
{
match &pattern.node {
PatternKind::MatchValue { value } => any_over_expr(value, func),
PatternKind::MatchSingleton { .. } => false,
PatternKind::MatchSequence { patterns } => patterns
.iter()
.any(|pattern| any_over_pattern(pattern, func)),
PatternKind::MatchMapping { keys, patterns, .. } => {
keys.iter().any(|key| any_over_expr(key, func))
|| patterns
.iter()
.any(|pattern| any_over_pattern(pattern, func))
}
PatternKind::MatchClass {
cls,
patterns,
kwd_patterns,
..
} => {
any_over_expr(cls, func)
|| patterns
.iter()
.any(|pattern| any_over_pattern(pattern, func))
|| kwd_patterns
.iter()
.any(|pattern| any_over_pattern(pattern, func))
}
PatternKind::MatchStar { .. } => false,
PatternKind::MatchAs { pattern, .. } => pattern
.as_ref()
.map_or(false, |pattern| any_over_pattern(pattern, func)),
PatternKind::MatchOr { patterns } => patterns
.iter()
.any(|pattern| any_over_pattern(pattern, func)),
}
}

pub fn any_over_stmt<F>(stmt: &Stmt, func: &F) -> bool
where
F: Fn(&Expr) -> bool,
Expand Down Expand Up @@ -415,8 +455,21 @@ where
.as_ref()
.map_or(false, |value| any_over_expr(value, func))
}
// TODO(charlie): Handle match statements.
StmtKind::Match { .. } => false,
StmtKind::Match { subject, cases } => {
any_over_expr(subject, func)
|| cases.iter().any(|case| {
let MatchCase {
pattern,
guard,
body,
} = case;
any_over_pattern(pattern, func)
|| guard
.as_ref()
.map_or(false, |expr| any_over_expr(expr, func))
|| any_over_body(body, func)
})
}
StmtKind::Import { .. } => false,
StmtKind::ImportFrom { .. } => false,
StmtKind::Global { .. } => false,
Expand Down
Loading

0 comments on commit 8516b81

Please sign in to comment.