Skip to content

Commit

Permalink
feat: derive clap::Args for enums
Browse files Browse the repository at this point in the history
  • Loading branch information
ysndr committed Aug 25, 2024
1 parent cdc27b6 commit a0fa7da
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 3 deletions.
121 changes: 119 additions & 2 deletions clap_derive/src/derives/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use syn::{
punctuated::Punctuated, spanned::Spanned, token::Comma, Data, DataStruct, DeriveInput, Field,
Fields, FieldsNamed, Generics,
};
use syn::{DataEnum, Variant};

use crate::item::{Item, Kind, Name};
use crate::utils::{inner_type, sub_type, Sp, Ty};
Expand Down Expand Up @@ -51,10 +52,110 @@ pub(crate) fn derive_args(input: &DeriveInput) -> Result<TokenStream, syn::Error
.collect::<Result<Vec<_>, syn::Error>>()?;
gen_for_struct(&item, ident, &input.generics, &fields)
}
Data::Enum(DataEnum { ref variants, .. }) => {
let name = Name::Derived(ident.clone());
let item = Item::from_args_struct(input, name)?;

let variant_items = variants
.iter()
.map(|variant| {
let item = Item::from_args_enum_variant(variant)?;
Ok((item, variant))
})
.collect::<Result<Vec<_>, syn::Error>>()?;

gen_for_enum(&item, ident, &input.generics, &variant_items)
}
_ => abort_call_site!("`#[derive(Args)]` only supports non-tuple structs"),
}
}

pub(crate) fn gen_for_enum(
_item: &Item,
item_name: &Ident,
generics: &Generics,
variants: &[(Item, &Variant)],
) -> Result<TokenStream, syn::Error> {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let app_var = Ident::new("__clap_app", Span::call_site());
let mut augmentations = TokenStream::default();
let mut augmentations_update = TokenStream::default();

for (item, variant) in variants.iter() {
let Fields::Named(ref fields) = variant.fields else {
abort! { variant.span(),
"`#[derive(Args)]` only supports named enum variants if used on an enum",
}
};

let conflicts = variants
.iter()
.filter_map(|(_, v)| {
if v.ident == variant.ident {
None
} else {
Some(Name::Derived(v.ident.clone()))
}
})
.collect::<Vec<_>>();

let fields = collect_args_fields(&item, fields)?;

let augmentation = gen_augment(&fields, &app_var, item, &conflicts, false)?;
let augmentation = quote! {
let #app_var = #augmentation;
};
let augmentation_update = gen_augment(&fields, &app_var, item, &conflicts, true)?;
let augmentation_update = quote! {
let #app_var = #augmentation_update;
};

augmentations.extend(augmentation);
augmentations_update.extend(augmentation_update);
}

Ok(quote! {

#[allow(
dead_code,
unreachable_code,
unused_variables,
unused_braces,
unused_qualifications,
)]
#[allow(
clippy::style,
clippy::complexity,
clippy::pedantic,
clippy::restriction,
clippy::perf,
clippy::deprecated,
clippy::nursery,
clippy::cargo,
clippy::suspicious_else_formatting,
clippy::almost_swapped,
clippy::redundant_locals,
)]
#[automatically_derived]
impl #impl_generics clap::Args for #item_name #ty_generics #where_clause {
fn group_id() -> Option<clap::Id> {
// todo: how does this interact with nested groups here
None
}
fn augment_args<'b>(#app_var: clap::Command) -> clap::Command {
#augmentations
#app_var
}
fn augment_args_for_update<'b>(#app_var: clap::Command) -> clap::Command {
#augmentations_update
#app_var
}
}

})
}

pub(crate) fn gen_for_struct(
item: &Item,
item_name: &Ident,
Expand All @@ -75,8 +176,8 @@ pub(crate) fn gen_for_struct(
let raw_deprecated = raw_deprecated();

let app_var = Ident::new("__clap_app", Span::call_site());
let augmentation = gen_augment(fields, &app_var, item, false)?;
let augmentation_update = gen_augment(fields, &app_var, item, true)?;
let augmentation = gen_augment(fields, &app_var, item, &[], false)?;
let augmentation_update = gen_augment(fields, &app_var, item, &[], true)?;

let group_id = if item.skip_group() {
quote!(None)
Expand Down Expand Up @@ -170,6 +271,9 @@ pub(crate) fn gen_augment(
fields: &[(&Field, Item)],
app_var: &Ident,
parent_item: &Item,
// when generating mutably exclusive arguments,
// ids of arguments that should conflict
conflicts: &[Name],
override_required: bool,
) -> Result<TokenStream, syn::Error> {
let mut subcommand_specified = false;
Expand Down Expand Up @@ -420,12 +524,25 @@ pub(crate) fn gen_augment(

let group_methods = parent_item.group_methods();

let conflicts_method = if conflicts.is_empty() {
quote!()
} else {
let conflicts_len = conflicts.len();
quote! {
.conflicts_with_all({
let conflicts: [clap::Id; #conflicts_len] = [#( clap::Id::from(#conflicts) ),* ];
conflicts
})
}
};

quote!(
.group(
clap::ArgGroup::new(#group_id)
.multiple(true)
#group_methods
.args(#literal_group_members)
#conflicts_method
)
)
};
Expand Down
2 changes: 1 addition & 1 deletion clap_derive/src/derives/subcommand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ fn gen_augment(
Named(ref fields) => {
// Defer to `gen_augment` for adding cmd methods
let fields = collect_args_fields(item, fields)?;
args::gen_augment(&fields, &subcommand_var, item, override_required)?
args::gen_augment(&fields, &subcommand_var, item, &[], override_required)?
}
Unit => {
let arg_block = quote!( #subcommand_var );
Expand Down
61 changes: 61 additions & 0 deletions clap_derive/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,67 @@ impl Item {
Ok(res)
}

pub(crate) fn from_args_enum_variant(
variant: &Variant,
// struct_casing: Sp<CasingStyle>,
// env_casing: Sp<CasingStyle>,
) -> Result<Self, syn::Error> {

let name = variant.ident.clone();
let ident = variant.ident.clone();
let span = variant.span();

// todo: pass in as arguments
let argument_casing = Sp::new(DEFAULT_CASING, span);
let env_casing = Sp::new(DEFAULT_ENV_CASING, span);

let ty = match variant.fields {
syn::Fields::Unnamed(syn::FieldsUnnamed { ref unnamed, .. }) if unnamed.len() == 1 => {
Ty::from_syn_ty(&unnamed[0].ty)
}
syn::Fields::Named(_) | syn::Fields::Unnamed(..) | syn::Fields::Unit => {
Sp::new(Ty::Other, span)
}
};
let kind = Sp::new(Kind::Command(ty), span);
let mut res = Self::new(
Name::Derived(name),
ident,
None,
argument_casing,
env_casing,
kind,
);
let parsed_attrs = ClapAttr::parse_all(&variant.attrs)?;
res.infer_kind(&parsed_attrs)?;
res.push_attrs(&parsed_attrs)?;
if matches!(&*res.kind, Kind::Command(_) | Kind::Subcommand(_)) {
res.push_doc_comment(&variant.attrs, "about", Some("long_about"));
}

// TODO: ???
match &*res.kind {
Kind::Flatten(_) => {
if res.has_explicit_methods() {
abort!(
res.kind.span(),
"methods are not allowed for flattened entry"
);
}
}

Kind::Subcommand(_)
| Kind::ExternalSubcommand
| Kind::FromGlobal(_)
| Kind::Skip(_, _)
| Kind::Command(_)
| Kind::Value
| Kind::Arg(_) => (),
}

Ok(res)
}

pub(crate) fn from_subcommand_enum(
input: &DeriveInput,
name: Name,
Expand Down

0 comments on commit a0fa7da

Please sign in to comment.