Skip to content

Commit

Permalink
Auto merge of #15667 - rmehri01:bool_to_enum_top_level, r=Veykril
Browse files Browse the repository at this point in the history
fix: make bool_to_enum assist create enum at top-level

This pr makes the `bool_to_enum` assist create the `enum` at the next closest module block or at top-level, which fixes a few tricky cases such as with an associated `const` in a trait or module:

```rust
trait Foo {
    const $0BOOL: bool;
}

impl Foo for usize {
    const BOOL: bool = true;
}

fn main() {
    if <usize as Foo>::BOOL {
        println!("foo");
    }
}
```

Which now properly produces:

```rust
#[derive(PartialEq, Eq)]
enum Bool { True, False }

trait Foo {
    const BOOL: Bool;
}

impl Foo for usize {
    const BOOL: Bool = Bool::True;
}

fn main() {
    if <usize as Foo>::BOOL == Bool::True {
        println!("foo");
    }
}
```

I also think it's a bit nicer, especially for local variables, but didn't really know to do it in the first PR :)
  • Loading branch information
bors committed Sep 29, 2023
2 parents f19479a + 1b3e5b2 commit 87e2c31
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 36 deletions.
224 changes: 191 additions & 33 deletions crates/ide-assists/src/handlers/bool_to_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use syntax::{
edit_in_place::{AttrsOwnerEdit, Indent},
make, HasName,
},
ted, AstNode, NodeOrToken, SyntaxNode, T,
ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
};
use text_edit::TextRange;

Expand All @@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
// ```
// ->
// ```
// fn main() {
// #[derive(PartialEq, Eq)]
// enum Bool { True, False }
// #[derive(PartialEq, Eq)]
// enum Bool { True, False }
//
// fn main() {
// let bool = Bool::True;
//
// if bool == Bool::True {
Expand Down Expand Up @@ -270,6 +270,15 @@ fn replace_usages(
}
_ => (),
}
} else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name)
{
edit.replace(ty_annotation.syntax().text_range(), "Bool");
replace_bool_expr(edit, initializer);
} else if let Some(receiver) = find_method_call_expr_usage(&new_name) {
edit.replace(
receiver.syntax().text_range(),
format!("({} == Bool::True)", receiver),
);
} else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
// for any other usage in an expression, replace it with a check that it is the true variant
if let Some((record_field, expr)) = new_name
Expand Down Expand Up @@ -413,6 +422,26 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
}
}

fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() {
return None;
}

Some((const_.ty()?, const_.body()?))
}

fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
let receiver = method_call.receiver()?;

if !receiver.syntax().descendants().contains(name.syntax()) {
return None;
}

Some(receiver)
}

/// Adds the definition of the new enum before the target node.
fn add_enum_def(
edit: &mut SourceChangeBuilder,
Expand All @@ -430,18 +459,31 @@ fn add_enum_def(
.any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
let enum_def = make_bool_enum(make_enum_pub);

let indent = IndentLevel::from_node(&target_node);
let insert_before = node_to_insert_before(target_node);
let indent = IndentLevel::from_node(&insert_before);
enum_def.reindent_to(indent);

ted::insert_all(
ted::Position::before(&edit.make_syntax_mut(target_node)),
ted::Position::before(&edit.make_syntax_mut(insert_before)),
vec![
enum_def.syntax().clone().into(),
make::tokens::whitespace(&format!("\n\n{indent}")).into(),
],
);
}

/// Finds where to put the new enum definition.
/// Tries to find the ast node at the nearest module or at top-level, otherwise just
/// returns the input node.
fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode {
target_node
.ancestors()
.take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE))
.filter(|it| ast::Item::can_cast(it.kind()))
.last()
.unwrap_or(target_node)
}

fn make_bool_enum(make_pub: bool) -> ast::Enum {
let enum_def = make::enum_(
if make_pub { Some(make::visibility_pub()) } else { None },
Expand Down Expand Up @@ -491,10 +533,10 @@ fn main() {
}
"#,
r#"
fn main() {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let foo = Bool::True;
if foo == Bool::True {
Expand All @@ -520,10 +562,10 @@ fn main() {
}
"#,
r#"
fn main() {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let foo = Bool::True;
if foo == Bool::False {
Expand All @@ -545,10 +587,10 @@ fn main() {
}
"#,
r#"
fn main() {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let foo: Bool = Bool::False;
}
"#,
Expand All @@ -565,10 +607,10 @@ fn main() {
}
"#,
r#"
fn main() {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let foo = if 1 == 2 { Bool::True } else { Bool::False };
}
"#,
Expand All @@ -590,10 +632,10 @@ fn main() {
}
"#,
r#"
fn main() {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let foo = Bool::False;
let bar = true;
Expand All @@ -619,10 +661,10 @@ fn main() {
}
"#,
r#"
fn main() {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let foo = Bool::True;
if *&foo == Bool::True {
Expand All @@ -645,10 +687,10 @@ fn main() {
}
"#,
r#"
fn main() {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let foo: Bool;
foo = Bool::True;
}
Expand All @@ -671,10 +713,10 @@ fn main() {
}
"#,
r#"
fn main() {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let foo = Bool::True;
let bar = foo == Bool::False;
Expand Down Expand Up @@ -702,11 +744,11 @@ fn main() {
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
if !"foo".chars().any(|c| {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
let foo = Bool::True;
foo == Bool::True
}) {
Expand Down Expand Up @@ -1244,6 +1286,38 @@ fn main() {
)
}

#[test]
fn field_method_chain_usage() {
check_assist(
bool_to_enum,
r#"
struct Foo {
$0bool: bool,
}
fn main() {
let foo = Foo { bool: true };
foo.bool.then(|| 2);
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
struct Foo {
bool: Bool,
}
fn main() {
let foo = Foo { bool: Bool::True };
(foo.bool == Bool::True).then(|| 2);
}
"#,
)
}

#[test]
fn field_non_bool() {
cov_mark::check!(not_applicable_non_bool_field);
Expand Down Expand Up @@ -1445,6 +1519,90 @@ pub mod bar {
)
}

#[test]
fn const_in_impl_cross_file() {
check_assist(
bool_to_enum,
r#"
//- /main.rs
mod foo;
struct Foo;
impl Foo {
pub const $0BOOL: bool = true;
}
//- /foo.rs
use crate::Foo;
fn foo() -> bool {
Foo::BOOL
}
"#,
r#"
//- /main.rs
mod foo;
struct Foo;
#[derive(PartialEq, Eq)]
pub enum Bool { True, False }
impl Foo {
pub const BOOL: Bool = Bool::True;
}
//- /foo.rs
use crate::{Foo, Bool};
fn foo() -> bool {
Foo::BOOL == Bool::True
}
"#,
)
}

#[test]
fn const_in_trait() {
check_assist(
bool_to_enum,
r#"
trait Foo {
const $0BOOL: bool;
}
impl Foo for usize {
const BOOL: bool = true;
}
fn main() {
if <usize as Foo>::BOOL {
println!("foo");
}
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
trait Foo {
const BOOL: Bool;
}
impl Foo for usize {
const BOOL: Bool = Bool::True;
}
fn main() {
if <usize as Foo>::BOOL == Bool::True {
println!("foo");
}
}
"#,
)
}

#[test]
fn const_non_bool() {
cov_mark::check!(not_applicable_non_bool_const);
Expand Down
6 changes: 3 additions & 3 deletions crates/ide-assists/src/tests/generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,10 @@ fn main() {
}
"#####,
r#####"
fn main() {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let bool = Bool::True;
if bool == Bool::True {
Expand Down

0 comments on commit 87e2c31

Please sign in to comment.