Skip to content

Commit

Permalink
feat(derive): Add "required" option for groups
Browse files Browse the repository at this point in the history
This adds the "required" derive option for the group creation.
Needed for #4574.
  • Loading branch information
klnusbaum committed Feb 1, 2023
1 parent 956dc6a commit 2d3acf1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 0 deletions.
2 changes: 2 additions & 0 deletions clap_derive/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ impl Parse for ClapAttr {
let magic = match name_str.as_str() {
"rename_all" => Some(MagicAttrName::RenameAll),
"rename_all_env" => Some(MagicAttrName::RenameAllEnv),
"required" => Some(MagicAttrName::Required),
"skip" => Some(MagicAttrName::Skip),
"next_display_order" => Some(MagicAttrName::NextDisplayOrder),
"next_help_heading" => Some(MagicAttrName::NextHelpHeading),
Expand Down Expand Up @@ -168,6 +169,7 @@ pub enum MagicAttrName {
Version,
RenameAllEnv,
RenameAll,
Required,
Skip,
DefaultValueT,
DefaultValuesT,
Expand Down
2 changes: 2 additions & 0 deletions clap_derive/src/derives/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ pub fn gen_augment(
quote!()
} else {
let group_id = parent_item.ident().unraw().to_string();
let required = parent_item.required_group();
let literal_group_members = fields
.iter()
.filter_map(|(_field, item)| {
Expand Down Expand Up @@ -404,6 +405,7 @@ pub fn gen_augment(
.group(
clap::ArgGroup::new(#group_id)
.multiple(true)
.required(#required)
.args(#literal_group_members)
)
)
Expand Down
11 changes: 11 additions & 0 deletions clap_derive/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub struct Item {
next_help_heading: Option<Method>,
is_enum: bool,
is_positional: bool,
required_group: bool,
skip_group: bool,
kind: Sp<Kind>,
}
Expand Down Expand Up @@ -272,6 +273,7 @@ impl Item {
next_help_heading: None,
is_enum: false,
is_positional: true,
required_group: false,
skip_group: false,
kind,
}
Expand Down Expand Up @@ -824,6 +826,10 @@ impl Item {
self.env_casing = CasingStyle::from_lit(lit);
}

Some(MagicAttrName::Required) if actual_attr_kind == AttrKind::Group => {
self.required_group = true;
}

Some(MagicAttrName::Skip) if actual_attr_kind == AttrKind::Group => {
self.skip_group = true;
}
Expand All @@ -838,6 +844,7 @@ impl Item {
| Some(MagicAttrName::LongHelp)
| Some(MagicAttrName::Author)
| Some(MagicAttrName::Version)
| Some(MagicAttrName::Required)
=> {
let expr = attr.value_or_abort();
self.push_method(*attr.kind.get(), attr.name.clone(), expr);
Expand Down Expand Up @@ -1059,6 +1066,10 @@ impl Item {
.any(|m| m.name != "help" && m.name != "long_help")
}

pub fn required_group(&self) -> bool {
self.required_group
}

pub fn skip_group(&self) -> bool {
self.skip_group
}
Expand Down
56 changes: 56 additions & 0 deletions tests/derive/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,59 @@ fn helpful_panic_on_duplicate_groups() {
use clap::CommandFactory;
Opt::command().debug_assert();
}

#[test]
fn required_group() {
#[derive(Parser, Debug)]
struct Opt {
#[command(flatten)]
source: Source,
}

#[derive(clap::Args, Debug)]
#[group(required)]
struct Source {
crates: Vec<String>,
#[arg(long)]
path: Option<std::path::PathBuf>,
#[arg(long)]
git: Option<String>,
}

use clap::CommandFactory;
let target_id = clap::Id::from("Source");
let opt_command = Opt::command();
let source_group = opt_command
.get_groups()
.find(|g| g.get_id() == &target_id)
.unwrap();
assert!(source_group.is_required_set())
}

#[test]
fn required_group_with_arg() {
#[derive(Parser, Debug)]
struct Opt {
#[command(flatten)]
source: Source,
}

#[derive(clap::Args, Debug)]
#[group(required = true)]
struct Source {
crates: Vec<String>,
#[arg(long)]
path: Option<std::path::PathBuf>,
#[arg(long)]
git: Option<String>,
}

use clap::CommandFactory;
let target_id = clap::Id::from("Source");
let opt_command = Opt::command();
let source_group = opt_command
.get_groups()
.find(|g| g.get_id() == &target_id)
.unwrap();
assert!(source_group.is_required_set())
}

0 comments on commit 2d3acf1

Please sign in to comment.