Skip to content

Commit

Permalink
[red-knot] Add support for unpacking for target
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvmanila committed Dec 23, 2024
1 parent b6c8f5d commit 6a116dc
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 83 deletions.
101 changes: 101 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/unpacking.md
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,104 @@ def _(arg: tuple[int, str] | Iterable):
reveal_type(a) # revealed: int | bytes
reveal_type(b) # revealed: str | bytes
```

## For statement

Unpacking in a `for` statement.

### Same types

```py
def _(arg: tuple[tuple[int, int], tuple[int, int]]):
for a, b in arg:
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int
```

### Mixed types (1)

```py
def _(arg: tuple[tuple[int, int], tuple[int, str]]):
for a, b in arg:
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int | str
```

### Mixed types (2)

```py
def _(arg: tuple[tuple[int, str], tuple[str, int]]):
for a, b in arg:
reveal_type(a) # revealed: int | str
reveal_type(b) # revealed: str | int
```

### Mixed types (3)

```py
def _(arg: tuple[tuple[int, int, int], tuple[int, str, bytes], tuple[int, int, str]]):
for a, b, c in arg:
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int | str
reveal_type(c) # revealed: int | bytes | str
```

### Same literal values

```py
for a, b in ((1, 2), (3, 4)):
reveal_type(a) # revealed: Literal[1, 3]
reveal_type(b) # revealed: Literal[2, 4]
```

### Mixed literal values (1)

```py
for a, b in ((1, 2), ("a", "b")):
reveal_type(a) # revealed: Literal[1] | Literal["a"]
reveal_type(b) # revealed: Literal[2] | Literal["b"]
```

### Mixed literals values (2)

```py
# error: "Object of type `Literal[1]` is not iterable"
# error: "Object of type `Literal[2]` is not iterable"
# error: "Object of type `Literal[4]` is not iterable"
for a, b in (1, 2, (3, "a"), 4, (5, "b"), "c"):
reveal_type(a) # revealed: Unknown | Literal[3, 5] | LiteralString
reveal_type(b) # revealed: Unknown | Literal["a", "b"]
```

### Custom iterator (1)

```py
class Iterator:
def __next__(self) -> tuple[int, int]:
return (1, 2)

class Iterable:
def __iter__(self) -> Iterator:
return Iterator()

for a, b in Iterable():
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int
```

### Custom iterator (2)

```py
class Iterator:
def __next__(self) -> bytes:
return b""

class Iterable:
def __iter__(self) -> Iterator:
return Iterator()

def _(arg: tuple[tuple[int, str], Iterable]):
for a, b in arg:
reveal_type(a) # revealed: int | bytes
reveal_type(b) # revealed: str | bytes
```
67 changes: 51 additions & 16 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::semantic_index::use_def::{
FlowSnapshot, ScopedConstraintId, ScopedVisibilityConstraintId, UseDefMapBuilder,
};
use crate::semantic_index::SemanticIndex;
use crate::unpack::Unpack;
use crate::unpack::{Unpack, UnpackValue};
use crate::visibility_constraints::VisibilityConstraint;
use crate::Db;

Expand Down Expand Up @@ -810,7 +810,7 @@ where
unsafe {
AstNodeRef::new(self.module.clone(), target)
},
value,
UnpackValue::Assign(value),
countme::Count::default(),
)),
})
Expand Down Expand Up @@ -1021,18 +1021,47 @@ where
orelse,
},
) => {
self.add_standalone_expression(iter);
debug_assert_eq!(&self.current_assignments, &[]);

let iter_expr = self.add_standalone_expression(iter);
self.visit_expr(iter);

self.record_ambiguous_visibility();

let pre_loop = self.flow_snapshot();
let saved_break_states = std::mem::take(&mut self.loop_break_states);

debug_assert_eq!(&self.current_assignments, &[]);
self.push_assignment(for_stmt.into());
let current_assignment = match &**target {
ast::Expr::List(_) | ast::Expr::Tuple(_) => Some(CurrentAssignment::For {
node: for_stmt,
first: true,
unpack: Some(Unpack::new(
self.db,
self.file,
self.current_scope(),
#[allow(unsafe_code)]
unsafe {
AstNodeRef::new(self.module.clone(), target)
},
UnpackValue::Iterable(iter_expr),
countme::Count::default(),
)),
}),
ast::Expr::Name(_) => Some(CurrentAssignment::For {
node: for_stmt,
unpack: None,
first: false,
}),
_ => None,
};

if let Some(current_assignment) = current_assignment {
self.push_assignment(current_assignment);
}
self.visit_expr(target);
self.pop_assignment();
if current_assignment.is_some() {
self.pop_assignment();
}

// TODO: Definitions created by loop variables
// (and definitions created inside the body)
Expand Down Expand Up @@ -1283,12 +1312,18 @@ where
Some(CurrentAssignment::AugAssign(aug_assign)) => {
self.add_definition(symbol, aug_assign);
}
Some(CurrentAssignment::For(node)) => {
Some(CurrentAssignment::For {
node,
first,
unpack,
}) => {
self.add_definition(
symbol,
ForStmtDefinitionNodeRef {
unpack,
first,
iterable: &node.iter,
target: name_node,
name: name_node,
is_async: node.is_async,
},
);
Expand Down Expand Up @@ -1324,7 +1359,9 @@ where
}
}

if let Some(CurrentAssignment::Assign { first, .. }) = self.current_assignment_mut()
if let Some(
CurrentAssignment::Assign { first, .. } | CurrentAssignment::For { first, .. },
) = self.current_assignment_mut()
{
*first = false;
}
Expand Down Expand Up @@ -1566,7 +1603,11 @@ enum CurrentAssignment<'a> {
},
AnnAssign(&'a ast::StmtAnnAssign),
AugAssign(&'a ast::StmtAugAssign),
For(&'a ast::StmtFor),
For {
node: &'a ast::StmtFor,
first: bool,
unpack: Option<Unpack<'a>>,
},
Named(&'a ast::ExprNamed),
Comprehension {
node: &'a ast::Comprehension,
Expand All @@ -1590,12 +1631,6 @@ impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> {
}
}

impl<'a> From<&'a ast::StmtFor> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtFor) -> Self {
Self::For(value)
}
}

impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
fn from(value: &'a ast::ExprNamed) -> Self {
Self::Named(value)
Expand Down
50 changes: 31 additions & 19 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,10 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> {

#[derive(Copy, Clone, Debug)]
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
pub(crate) unpack: Option<Unpack<'a>>,
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
pub(crate) name: &'a ast::ExprName,
pub(crate) first: bool,
pub(crate) is_async: bool,
}

Expand Down Expand Up @@ -298,12 +300,16 @@ impl<'db> DefinitionNodeRef<'db> {
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
}
DefinitionNodeRef::For(ForStmtDefinitionNodeRef {
unpack,
iterable,
target,
name,
first,
is_async,
}) => DefinitionKind::For(ForStmtDefinitionKind {
target: TargetKind::from(unpack),
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
name: AstNodeRef::new(parsed, name),
first,
is_async,
}),
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
Expand Down Expand Up @@ -382,10 +388,12 @@ impl<'db> DefinitionNodeRef<'db> {
Self::AnnotatedAssignment(node) => node.into(),
Self::AugmentedAssignment(node) => node.into(),
Self::For(ForStmtDefinitionNodeRef {
unpack: _,
iterable: _,
target,
name,
first: _,
is_async: _,
}) => target.into(),
}) => name.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
Self::VariadicPositionalParameter(node) => node.into(),
Self::VariadicKeywordParameter(node) => node.into(),
Expand Down Expand Up @@ -452,7 +460,7 @@ pub enum DefinitionKind<'db> {
Assignment(AssignmentDefinitionKind<'db>),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
For(ForStmtDefinitionKind),
For(ForStmtDefinitionKind<'db>),
Comprehension(ComprehensionDefinitionKind),
VariadicPositionalParameter(AstNodeRef<ast::Parameter>),
VariadicKeywordParameter(AstNodeRef<ast::Parameter>),
Expand All @@ -477,7 +485,7 @@ impl Ranged for DefinitionKind<'_> {
DefinitionKind::Assignment(assignment) => assignment.name().range(),
DefinitionKind::AnnotatedAssignment(assign) => assign.target.range(),
DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.target.range(),
DefinitionKind::For(for_stmt) => for_stmt.target().range(),
DefinitionKind::For(for_stmt) => for_stmt.name().range(),
DefinitionKind::Comprehension(comp) => comp.target().range(),
DefinitionKind::VariadicPositionalParameter(parameter) => parameter.name.range(),
DefinitionKind::VariadicKeywordParameter(parameter) => parameter.name.range(),
Expand Down Expand Up @@ -665,22 +673,32 @@ impl WithItemDefinitionKind {
}

#[derive(Clone, Debug)]
pub struct ForStmtDefinitionKind {
pub struct ForStmtDefinitionKind<'db> {
target: TargetKind<'db>,
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
name: AstNodeRef<ast::ExprName>,
first: bool,
is_async: bool,
}

impl ForStmtDefinitionKind {
impl<'db> ForStmtDefinitionKind<'db> {
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
}

pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
pub(crate) fn target(&self) -> TargetKind<'db> {
self.target
}

pub(crate) fn is_async(&self) -> bool {
pub(crate) fn name(&self) -> &ast::ExprName {
self.name.node()
}

pub(crate) const fn is_first(&self) -> bool {
self.first
}

pub(crate) const fn is_async(&self) -> bool {
self.is_async
}
}
Expand Down Expand Up @@ -756,12 +774,6 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
}
}

impl From<&ast::StmtFor> for DefinitionNodeKey {
fn from(value: &ast::StmtFor) -> Self {
Self(NodeKey::from_node(value))
}
}

impl From<&ast::Parameter> for DefinitionNodeKey {
fn from(node: &ast::Parameter) -> Self {
Self(NodeKey::from_node(node))
Expand Down
Loading

0 comments on commit 6a116dc

Please sign in to comment.