Skip to content

Commit

Permalink
Merge pull request #4449 from harupy/fix-dict-spread-in-dict
Browse files Browse the repository at this point in the history
Fix AST generated from a dict literal containing dict unpacking
  • Loading branch information
youknowone authored Jan 22, 2023
2 parents 6dba843 + f2ffe12 commit d9df131
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 43 deletions.
11 changes: 8 additions & 3 deletions ast/asdl_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,20 +227,25 @@ def visitConstructor(self, cons, parent, depth):
if cons.fields:
self.emit(f"{cons.name} {{", depth)
for f in cons.fields:
self.visit(f, parent, "", depth + 1)
self.visit(f, parent, "", depth + 1, cons.name)
self.emit("},", depth)
else:
self.emit(f"{cons.name},", depth)

def visitField(self, field, parent, vis, depth):
def visitField(self, field, parent, vis, depth, constructor=None):
typ = get_rust_type(field.type)
fieldtype = self.typeinfo.get(field.type)
if fieldtype and fieldtype.has_userdata:
typ = f"{typ}<U>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if fieldtype and fieldtype.boxed and (not (parent.product or field.seq) or field.opt):
typ = f"Box<{typ}>"
if field.opt:
if field.opt or (
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
# the expression to be unpacked goes in `values` with a `None` at the corresponding
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
constructor == "Dict" and field.name == "keys"
):
typ = f"Option<{typ}>"
if field.seq:
typ = f"Vec<{typ}>"
Expand Down
2 changes: 1 addition & 1 deletion ast/src/ast_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ pub enum ExprKind<U = ()> {
orelse: Box<Expr<U>>,
},
Dict {
keys: Vec<Expr<U>>,
keys: Vec<Option<Expr<U>>>,
values: Vec<Expr<U>>,
},
Set {
Expand Down
6 changes: 5 additions & 1 deletion ast/src/unparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ impl<'a> Unparser<'a> {
let (packed, unpacked) = values.split_at(keys.len());
for (k, v) in keys.iter().zip(packed) {
self.p_delim(&mut first, ", ")?;
write!(self, "{}: {}", *k, *v)?;
if let Some(k) = k {
write!(self, "{}: {}", *k, *v)?;
} else {
write!(self, "**{}", *v)?;
}
}
for d in unpacked {
self.p_delim(&mut first, ", ")?;
Expand Down
43 changes: 5 additions & 38 deletions parser/python.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -1136,44 +1136,11 @@ Atom<Goal>: ast::Expr = {
}.into())
},
<location:@L> "{" <e:DictLiteralValues?> "}" <end_location:@R> => {
let pairs = e.unwrap_or_default();

let (keys, values) = match pairs.iter().position(|(k,_)| k.is_none()) {
Some(unpack_idx) => {
let mut pairs = pairs;
let (keys, mut values): (_, Vec<_>) = pairs.drain(..unpack_idx).map(|(k, v)| (*k.unwrap(), v)).unzip();

fn build_map(items: &mut Vec<(ast::Expr, ast::Expr)>) -> ast::Expr {
let location = items[0].0.location;
let end_location = items[0].0.end_location;
let (keys, values) = items.drain(..).unzip();
ast::Expr {
location,
end_location,
custom: (),
node: ast::ExprKind::Dict { keys, values }
}
}

let mut items = Vec::new();
for (key, value) in pairs.into_iter() {
if let Some(key) = key {
items.push((*key, value));
continue;
}
if !items.is_empty() {
values.push(build_map(&mut items));
}
values.push(value);
}
if !items.is_empty() {
values.push(build_map(&mut items));
}
(keys, values)
},
None => pairs.into_iter().map(|(k, v)| (*k.unwrap(), v)).unzip()
};

let (keys, values) = e
.unwrap_or_default()
.into_iter()
.map(|(k, v)| (k.map(|x| *x), v))
.unzip();
ast::Expr {
location,
end_location: Some(end_location),
Expand Down
6 changes: 6 additions & 0 deletions parser/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,4 +309,10 @@ with (0 as a, 1 as b,): pass
assert!(parse_program(source, "<test>").is_err());
}
}

#[test]
fn test_dict_unpacking() {
let parse_ast = parse_expression(r#"{"a": "b", **c, "d": "e"}"#, "<test>").unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
---
source: compiler/parser/src/parser.rs
expression: parse_ast
---
Located {
location: Location {
row: 1,
column: 0,
},
end_location: Some(
Location {
row: 1,
column: 25,
},
),
custom: (),
node: Dict {
keys: [
Some(
Located {
location: Location {
row: 1,
column: 1,
},
end_location: Some(
Location {
row: 1,
column: 4,
},
),
custom: (),
node: Constant {
value: Str(
"a",
),
kind: None,
},
},
),
None,
Some(
Located {
location: Location {
row: 1,
column: 16,
},
end_location: Some(
Location {
row: 1,
column: 19,
},
),
custom: (),
node: Constant {
value: Str(
"d",
),
kind: None,
},
},
),
],
values: [
Located {
location: Location {
row: 1,
column: 6,
},
end_location: Some(
Location {
row: 1,
column: 9,
},
),
custom: (),
node: Constant {
value: Str(
"b",
),
kind: None,
},
},
Located {
location: Location {
row: 1,
column: 13,
},
end_location: Some(
Location {
row: 1,
column: 14,
},
),
custom: (),
node: Name {
id: "c",
ctx: Load,
},
},
Located {
location: Location {
row: 1,
column: 21,
},
end_location: Some(
Location {
row: 1,
column: 24,
},
),
custom: (),
node: Constant {
value: Str(
"e",
),
kind: None,
},
},
],
},
}

0 comments on commit d9df131

Please sign in to comment.