Skip to content

Commit

Permalink
Implement Varargs
Browse files Browse the repository at this point in the history
  • Loading branch information
fcard committed Sep 12, 2019
1 parent a692ccc commit 66aa702
Show file tree
Hide file tree
Showing 9 changed files with 379 additions and 60 deletions.
33 changes: 32 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ fn main() {
}
```

Methods that Return references requires the call to be expressed as `(FUNC.rr)(args...)`. Hopefully this won't be necessary in the future. (or I will find a nicer syntax, at least)
Methods that return references requires the call to be expressed as `(FUNC.rr)(args...)`. Hopefully this won't be necessary in the future. (or I will find a nicer syntax, at least)

```rust
multifunction! {
Expand Down Expand Up @@ -287,6 +287,37 @@ fn main() {
}
```

A variadic method can be defined using the special `Vararg![T]` macro. The type of the variadic
argument is `multimethods::types::vararg::Vararg<T>`, which can be iterated through and indexed.

```rust
// Vararg doesn't need to be imported as it's merely a marker for the multifunction! macro
use multimethods::multifunction;

multifunction! {
fn SUM(args: Vararg![i32]) -> i32 {
args.iter().sum()
}
}

// Vararg![] is equivalent to Vararg![Abstract![ANY]]
multifunction! {
fn PRINT_ALL(args: Vararg![]) {
for arg in args {
println!("{}", arg)
}
}
}

fn main() {
println!("{}", SUM(1, 2, 3)); // 6

PRINT_ALL("a", 2); // a
// 2
}
```


## Limitations

* Only up to 12 arguments per method are allowed. This number was chosen as it is the largest size of a tuple that has trait implementations for it in the standard library.
Expand Down
177 changes: 136 additions & 41 deletions multimethods_proc/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
#![feature(decl_macro)]

extern crate proc_macro;
use std::collections::HashMap;

use proc_macro as pm;
use proc_macro2 as pm2;
use syn::*;
use syn::punctuated::Punctuated;
use syn::token::{Paren, Comma};
use quote::*;

macro ident($str: literal$(, $expr: expr)*) {
Ident::new(&format!($str$(, $expr)*), pm2::Span::call_site())
}

const MAX_ARGS: usize = 12;

struct Method {
public: bool,
expr: pm2::TokenStream,
Expand All @@ -19,7 +26,7 @@ struct Keys(Vec<Ident>);

impl syn::parse::Parse for Keys {
fn parse(input: syn::parse::ParseStream) -> Result<Self> {
let punct = <Punctuated<Ident, Comma>>::parse_terminated(input).unwrap();
let punct = <Punctuated<Ident, Token![,]>>::parse_terminated(input).unwrap();
Ok(Keys(punct.into_iter().collect()))
}
}
Expand Down Expand Up @@ -216,18 +223,29 @@ fn method_defs<'a, I: Iterator<Item=&'a ItemFn>>(item_fns: I, fmc: bool) -> Meth
let mut methods = Methods::new();

for item_fn in item_fns {
let root = root(fmc);
let name = item_fn.sig.ident.clone();
let num_args = item_fn.sig.inputs.len();
let is_abstract = has_abstract_type(args(&item_fn.sig));
let types = types(args(&item_fn.sig), &item_fn.sig.output, is_abstract, fmc);
let closure = create_closure(&item_fn, fmc);
let insertion = get_insertion_function(is_abstract);
let variant = get_variant(num_args, &item_fn.sig.output);
let inner_func = get_inner_function(num_args, &item_fn.sig.output, fmc);
let constructor = get_inner_constructor(args(&item_fn.sig), &item_fn.sig.output);
let inner_trait = get_inner_trait(args(&item_fn.sig), &item_fn.sig.output, fmc);
let has_vararg = has_vararg_type(args(&item_fn.sig));

let is_abstract;
let match_value;
let insertion;

if has_vararg {
let funcs = vararg_functions(&item_fn, fmc);
let values = funcs.iter().map(|f| function(f,fmc));
let positionals = item_fn.sig.inputs.len() - 1;

is_abstract = true;
match_value = quote!((#positionals, ::std::vec![#(#values),*]));
insertion = quote!(insert_vararg);

} else {
is_abstract = has_abstract_type(args(&item_fn.sig));
match_value = function(item_fn, fmc);
insertion = get_insertion_function(is_abstract);
}

let types = types(args(&item_fn.sig), &item_fn.sig.output, is_abstract, has_vararg, fmc);

if !methods.contains_key(&name) {
methods.insert(name.clone(), Vec::new());
Expand All @@ -238,28 +256,60 @@ fn method_defs<'a, I: Iterator<Item=&'a ItemFn>>(item_fns: I, fmc: bool) -> Meth
of_func.push(
Method {
public: if let Visibility::Public(_) = item_fn.vis { true } else { false },

expr: if num_args == 0 {
quote! {
table.#insertion(
#types,
#root Function::#variant(#inner_func::new(#closure))
)
}
} else {
quote! {
table.#insertion(
#types,
#root Function::#variant(<#inner_func as #inner_trait>::#constructor(#closure))
)
}
}
expr: quote!(table.#insertion(#types, #match_value))
}
);
}
methods
}


fn vararg_functions(origin: &ItemFn, fmc: bool) -> Vec<ItemFn> {
let root = root(fmc);
let vis = &origin.vis;
let body = &origin.block;
let output = &origin.sig.output;
let num_args = origin.sig.inputs.len() - 1;
let p_args = args(&origin.sig).take(num_args).collect::<Vec<_>>();
let p_names = args(&origin.sig).take(num_args).map(arg_name).collect::<Vec<_>>();
let vararg_arg = origin.sig.inputs.iter().last().unwrap();
let vararg_name = arg_name(vararg_arg.clone());
let vararg_type = vararg(&arg_type(vararg_arg.clone()), fmc);

let mut functions = Vec::new();

for i in (num_args..=MAX_ARGS).rev() {
let vs = (i..MAX_ARGS).map(|j| ident!("__VArg_Multimethods_{}", j)).collect::<Vec<_>>();

functions.push(parse2(quote! {
#vis fn _f(#(#p_args,)* #(#vs: #vararg_type),*) #output {
let __VarargCall = |#(#p_args,)* #vararg_name: #root Vararg<#vararg_type>| #body;
__VarargCall(#(#p_names,)* #root Vararg::new(::std::vec![#(#vs),*]))
}
}).unwrap());
}
functions
}


fn function(item_fn: &ItemFn, fmc: bool) -> pm2::TokenStream {
let root = root(fmc);
let num_args = item_fn.sig.inputs.len();
let closure = create_closure(&item_fn, fmc);
let variant = get_variant(num_args, &item_fn.sig.output);
let inner_func = get_inner_function(num_args, &item_fn.sig.output, fmc);
let constructor = get_inner_constructor(args(&item_fn.sig), &item_fn.sig.output);
let inner_trait = get_inner_trait(args(&item_fn.sig), &item_fn.sig.output, fmc);

if num_args == 0 {
quote!(#root Function::#variant(#inner_func::new(#closure)))

} else {
quote!(#root Function::#variant(<#inner_func as #inner_trait>::#constructor(#closure)))
}
}


fn is_public<I: Iterator<Item=bool>>(vis: I) -> bool {
let mut public = None;

Expand Down Expand Up @@ -301,12 +351,12 @@ fn args(sig: &Signature) -> impl Iterator<Item=FnArg> {
}


fn types<I>(inputs: I, output: &ReturnType, is_abs: bool, fmc: bool) -> pm2::TokenStream
fn types<I>(inputs: I, output: &ReturnType, is_abs: bool, has_var: bool, fmc: bool) -> pm2::TokenStream
where
I: Iterator<Item=FnArg>
{
if is_abs {
type_matches(inputs, output, fmc)
type_matches(inputs, output, has_var, fmc)

} else {
type_ids(inputs, output, fmc)
Expand All @@ -329,7 +379,7 @@ fn type_ids<I>(inputs: I, output: &ReturnType, fmc: bool) -> pm2::TokenStream
types.push(quote!(<#ty as #root TypeOf>::associated_type_of()));
}

let variant = Ident::new(&format!("T{}", types.len()), pm2::Span::call_site());
let variant = ident!("T{}", types.len());
let returns_ref = is_ref_return(output);

quote! {
Expand All @@ -338,7 +388,7 @@ fn type_ids<I>(inputs: I, output: &ReturnType, fmc: bool) -> pm2::TokenStream
}


fn type_matches<I>(inputs: I, output: &ReturnType, fmc: bool) -> pm2::TokenStream
fn type_matches<I>(inputs: I, output: &ReturnType, has_var: bool, fmc: bool) -> pm2::TokenStream
where
I: Iterator<Item=FnArg>
{
Expand All @@ -351,14 +401,17 @@ fn type_matches<I>(inputs: I, output: &ReturnType, fmc: bool) -> pm2::TokenStrea

for input in inputs {
let ty = arg_type(input);
let ty = if let Some(vty) = vararg(&ty, fmc) { vty } else { ty };

if let Some(aty) = abstract_type(&ty) {
types.push(quote!(#type_match::Abstract(#aty)));

} else {
types.push(quote!(#type_match::Concrete(<#ty as #sub_type>::#assoc_type())));
}
}

let variant = Ident::new(&format!("T{}", types.len()), pm2::Span::call_site());
let variant = ident!("{}{}", if has_var {"V"} else {"T"}, types.len());
let returns_ref = is_ref_return(output);

quote! {
Expand Down Expand Up @@ -459,6 +512,7 @@ fn get_inner_trait<I>(inputs: I, output: &ReturnType, fmc: bool) -> pm2::TokenSt
}
}


fn has_abstract_type<I>(inputs: I) -> bool
where
I: Iterator<Item=FnArg>
Expand All @@ -471,6 +525,20 @@ fn has_abstract_type<I>(inputs: I) -> bool
false
}

fn has_vararg_type<I>(inputs: I) -> bool
where
I: Iterator<Item=FnArg>
{
for input in inputs {
if vararg(&arg_type(input), false).is_some() {
return true;
}
}
false
}



fn root(fmc: bool) -> pm2::TokenStream {
if fmc {
quote!()
Expand All @@ -479,15 +547,20 @@ fn root(fmc: bool) -> pm2::TokenStream {
}
}

fn arg_name(arg: FnArg) -> Pat {
if let FnArg::Typed(pat) = arg {
*pat.pat
} else {
panic!("methods cannot have a `self` argument")
}
}

fn arg_type(arg: FnArg) -> Type {
if let FnArg::Typed(pat) = arg {
*pat.ty

} else {
Type::Tuple(TypeTuple {
elems: Punctuated::new(),
paren_token: Paren { span: pm2::Span::call_site() }
})
panic!("methods cannot have a `self` argument")
}
}

Expand Down Expand Up @@ -534,14 +607,13 @@ fn is_ref(ty: &Type) -> bool {
}
}

fn abstract_type(ty: &Type) -> Option<Ident> {
fn abstract_type(ty: &Type) -> Option<Expr> {
match ty {
Type::Paren(t) => abstract_type(&*t.elem),

Type::Macro(m) => {
if path_ends_with(&m.mac.path, "Abstract") {
let tokens = m.mac.tokens.clone();
parse2::<Ident>(tokens).ok()
parse2::<Expr>(m.mac.tokens.clone()).ok()

} else {
None
Expand All @@ -552,6 +624,30 @@ fn abstract_type(ty: &Type) -> Option<Ident> {
}
}

fn vararg(ty: &Type, fmc: bool) -> Option<Type> {
match ty {
Type::Paren(t) => vararg(&*t.elem, fmc),

Type::Macro(m) => {
if path_ends_with(&m.mac.path, "Vararg") {
if m.mac.tokens.is_empty() {
let root = root(fmc);
parse2(quote!(#root Abstract![#root ANY])).ok()

} else {
parse2(m.mac.tokens.clone()).ok()
}

} else {
None
}
}
_ => None
}
}



fn path_ends_with(p: &Path, s: &str) -> bool {
if let Some(segment) = p.segments.iter().last() {
segment.ident.to_string() == s.to_string()
Expand All @@ -561,7 +657,6 @@ fn path_ends_with(p: &Path, s: &str) -> bool {
}
}


#[allow(dead_code)]
fn arg_has_attr(f: &FnArg, attr: &str) -> bool {
match f {
Expand Down
29 changes: 29 additions & 0 deletions multimethods_tests/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,35 @@ mod readme {
assert_eq!(IS_MY_COLLECTION(coll2), true);
}
}

mod vararg {
use multimethods::{multifunction, FromValue};

multifunction! {
fn SUM(args: Vararg![i32]) -> i32 {
args.iter().sum()
}
}

multifunction! {
fn PRINT_ALL(args: Vararg![]) -> Vec<String> {
let mut result = Vec::new();
for arg in args {
result.push(format!("{}", arg))
}
result
}
}

#[test]
fn readme_vararg() {
assert_eq!(SUM(1, 2, 3), 6);
assert_eq!(
<Vec<String>>::from_value(PRINT_ALL("a", 2)),
vec!["a".to_string(), "2".to_string()]
);
}
}
}

fn main() {
Expand Down
Loading

0 comments on commit 66aa702

Please sign in to comment.