Skip to content

Commit

Permalink
Add --derives
Browse files Browse the repository at this point in the history
# Conflicts:
#	crates/libs/bindgen/src/rust/mod.rs
#	crates/libs/bindgen/src/rust/standalone.rs
#	crates/libs/bindgen/src/rust/structs.rs
#	crates/tests/standalone/build.rs
  • Loading branch information
dpaoliello committed Feb 27, 2024
1 parent 503dd9f commit 1b01493
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 23 deletions.
47 changes: 46 additions & 1 deletion crates/libs/bindgen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ enum ArgKind {
Output,
Filter,
Config,
Derives,
}

/// Windows metadata compiler.
Expand All @@ -38,6 +39,7 @@ where
let mut exclude = Vec::<&str>::new();
let mut config = std::collections::BTreeMap::<&str, &str>::new();
let mut format = false;
let mut derives = std::collections::BTreeMap::<(&str, &str), tokens::TokenStream>::new();

for arg in &args {
if arg.starts_with('-') {
Expand All @@ -50,6 +52,7 @@ where
"-o" | "--out" => kind = ArgKind::Output,
"-f" | "--filter" => kind = ArgKind::Filter,
"--config" => kind = ArgKind::Config,
"--derives" => kind = ArgKind::Derives,
"--format" => format = true,
_ => return Err(Error::new(&format!("invalid option `{arg}`"))),
},
Expand All @@ -75,6 +78,22 @@ where
config.insert(arg, "");
}
}
ArgKind::Derives => {
if let Some((ty, traits)) = arg.split_once('=') {
if let Some(last_dot) = ty.rfind('.') {
let name = &ty[last_dot + 1..];
let namespace = &ty[..last_dot];
let traits: tokens::TokenStream = traits.into();
if derives.insert((namespace, name), quote! { #[derive(#traits)] }).is_some() {
return Err(Error::new(&format!("Duplicate entry for type `{ty}` in --derives")));
}
} else {
return Err(Error::new(&format!("The type `{ty}` in --derives must be fully qualified")));
}
} else {
return Err(Error::new(&format!("Invalid format for --derives, expected ty=traits, actual: `{arg}`")));
}
}
}
}

Expand Down Expand Up @@ -113,10 +132,21 @@ where

winmd::verify(reader)?;

let unused_derives = derives.keys().filter(|(namespace, name)| reader.get_type_def(namespace, name).next().is_none()).collect::<Vec<_>>();
if !unused_derives.is_empty() {
let mut message = "unused derives".to_string();

for (namespace, name) in unused_derives {
message.push_str(&format!("\n {namespace}.{name}"));
}

return Err(Error::new(&message));
}

match extension(&output) {
"rdl" => rdl::from_reader(reader, config, &output)?,
"winmd" => winmd::from_reader(reader, config, &output)?,
"rs" => rust::from_reader(reader, config, &output)?,
"rs" => rust::from_reader(reader, config, &derives, &output)?,
_ => return Err(Error::new("output extension must be one of winmd/rdl/rs")),
}

Expand Down Expand Up @@ -262,3 +292,18 @@ fn extension(path: &str) -> &str {
fn directory(path: &str) -> &str {
path.rsplit_once(&['/', '\\']).map_or("", |(directory, _)| directory)
}

#[test]
fn bad_derive_args() {
let result = bindgen(&["--derives", "Foo"]).unwrap_err().to_string();
assert_eq!(result, "error: Invalid format for --derives, expected ty=traits, actual: `Foo`\n");

let result = bindgen(&["--derives", "Foo=bar"]).unwrap_err().to_string();
assert_eq!(result, "error: The type `Foo` in --derives must be fully qualified\n");

let result = bindgen(&["--derives", "Foo.Bar=bar", "Foo.Bar=baz"]).unwrap_err().to_string();
assert_eq!(result, "error: Duplicate entry for type `Foo.Bar` in --derives\n");

let result = bindgen(&["--out", "test.rs", "--filter", "Windows.Win32.System.Com.CoInitialize", "--derives", "Foo.Bar=bar"]).unwrap_err().to_string();
assert_eq!(result, "error: unused derives\n Foo.Bar\n");
}
22 changes: 11 additions & 11 deletions crates/libs/bindgen/src/rust/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use index::*;
use rayon::prelude::*;
use writer::*;

pub fn from_reader(reader: &'static metadata::Reader, mut config: std::collections::BTreeMap<&str, &str>, output: &str) -> Result<()> {
pub fn from_reader(reader: &'static metadata::Reader, mut config: std::collections::BTreeMap<&str, &str>, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>, output: &str) -> Result<()> {
let mut writer = Writer::new(reader, output);
writer.package = config.remove("package").is_some();
writer.flatten = config.remove("flatten").is_some();
Expand All @@ -47,32 +47,32 @@ pub fn from_reader(reader: &'static metadata::Reader, mut config: std::collectio
}

if writer.package {
gen_package(&writer)
gen_package(&writer, derives)
} else {
gen_file(&writer)
gen_file(&writer, derives)
}
}

fn gen_file(writer: &Writer) -> Result<()> {
fn gen_file(writer: &Writer, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>) -> Result<()> {
// TODO: harmonize this output code so we don't need these two wildly differnt code paths
// there should be a simple way to generate the with or without namespaces.

if writer.flatten {
let tokens = standalone::standalone_imp(writer);
let tokens = standalone::standalone_imp(writer, derives);
write_to_file(&writer.output, try_format(writer, &tokens))
} else {
let mut tokens = String::new();
let root = Tree::new(writer.reader);

for tree in root.nested.values() {
tokens.push_str(&namespace(writer, tree));
tokens.push_str(&namespace(writer, tree, derives));
}

write_to_file(&writer.output, try_format(writer, &tokens))
}
}

fn gen_package(writer: &Writer) -> Result<()> {
fn gen_package(writer: &Writer, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>) -> Result<()> {
let directory = directory(&writer.output);
let root = Tree::new(writer.reader);
let mut root_len = 0;
Expand All @@ -86,7 +86,7 @@ fn gen_package(writer: &Writer) -> Result<()> {

trees.par_iter().try_for_each(|tree| {
let directory = format!("{directory}/{}", tree.namespace.replace('.', "/"));
let mut tokens = namespace(writer, tree);
let mut tokens = namespace(writer, tree, derives);

let tokens_impl = if !writer.sys { namespace_impl(writer, tree) } else { String::new() };

Expand Down Expand Up @@ -143,7 +143,7 @@ use std::fmt::Write;
use tokens::*;
use try_format::*;

fn namespace(writer: &Writer, tree: &Tree) -> String {
fn namespace(writer: &Writer, tree: &Tree, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>) -> String {
let writer = &mut writer.clone();
writer.namespace = tree.namespace;
let mut tokens = TokenStream::new();
Expand All @@ -159,7 +159,7 @@ fn namespace(writer: &Writer, tree: &Tree) -> String {
} else {
tokens.combine(&quote! { pub mod #name });
tokens.push_str("{");
tokens.push_str(&namespace(writer, tree));
tokens.push_str(&namespace(writer, tree, derives));
tokens.push_str("}");
}
}
Expand Down Expand Up @@ -200,7 +200,7 @@ fn namespace(writer: &Writer, tree: &Tree) -> String {
continue;
}
}
types.entry(kind).or_default().entry(name).or_default().combine(&structs::writer(writer, def));
types.entry(kind).or_default().entry(name).or_default().combine(&structs::writer(writer, def, derives.get(&(def.namespace(), name))));
}
metadata::TypeKind::Delegate => types.entry(kind).or_default().entry(name).or_default().combine(&delegates::writer(writer, def)),
}
Expand Down
4 changes: 2 additions & 2 deletions crates/libs/bindgen/src/rust/standalone.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::*;
use metadata::AsRow;

pub fn standalone_imp(writer: &Writer) -> String {
pub fn standalone_imp(writer: &Writer, derives: &std::collections::BTreeMap<(&str, &str), TokenStream>) -> String {
let mut types = std::collections::BTreeSet::new();
let mut functions = std::collections::BTreeSet::new();
let mut constants = std::collections::BTreeSet::new();
Expand Down Expand Up @@ -112,7 +112,7 @@ pub fn standalone_imp(writer: &Writer) -> String {
continue;
}
}
sorted.insert(name, structs::writer(writer, def));
sorted.insert(name, structs::writer(writer, def, derives.get(&(def.namespace(), name))));
}
metadata::TypeKind::Delegate => {
sorted.insert(def.name(), delegates::writer(writer, def));
Expand Down
9 changes: 5 additions & 4 deletions crates/libs/bindgen/src/rust/structs.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::*;
use metadata::HasAttributes;

pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream {
pub fn writer(writer: &Writer, def: metadata::TypeDef, derives: Option<&TokenStream>) -> TokenStream {
if def.has_attribute("ApiContractAttribute") {
return quote! {};
}
Expand All @@ -10,10 +10,10 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream {
return handles::writer(writer, def);
}

gen_struct_with_name(writer, def, def.name(), &cfg::Cfg::default())
gen_struct_with_name(writer, def, def.name(), &cfg::Cfg::default(), derives)
}

fn gen_struct_with_name(writer: &Writer, def: metadata::TypeDef, struct_name: &str, cfg: &cfg::Cfg) -> TokenStream {
fn gen_struct_with_name(writer: &Writer, def: metadata::TypeDef, struct_name: &str, cfg: &cfg::Cfg, derives: Option<&TokenStream>) -> TokenStream {
let name = to_ident(struct_name);

if def.fields().next().is_none() {
Expand Down Expand Up @@ -81,6 +81,7 @@ fn gen_struct_with_name(writer: &Writer, def: metadata::TypeDef, struct_name: &s
let mut tokens = quote! {
#repr
#features
#derives
pub #struct_or_union #name {#(#fields)*}
};

Expand All @@ -103,7 +104,7 @@ fn gen_struct_with_name(writer: &Writer, def: metadata::TypeDef, struct_name: &s

for (index, nested_type) in writer.reader.nested_types(def).enumerate() {
let nested_name = format!("{struct_name}_{index}");
tokens.combine(&gen_struct_with_name(writer, nested_type, &nested_name, &cfg));
tokens.combine(&gen_struct_with_name(writer, nested_type, &nested_name, &cfg, None));
}

tokens
Expand Down
29 changes: 24 additions & 5 deletions crates/tests/standalone/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,33 +165,47 @@ fn main() {
"src/b_vtbl_4.rs",
&["Windows.Win32.System.Com.IPersistFile"],
);

// Ensure that derives adds the #[derive(...)] attribute.
write_derives(
"src/b_derives.rs",
&["Windows.Foundation.DateTime"],
&[
"Windows.Foundation.DateTime=::core::cmp::PartialOrd,::core::cmp::Ord",
],
);
}

fn write_sys(output: &str, filter: &[&str]) {
riddle(output, filter, &["flatten", "sys", "minimal"]);
riddle(output, filter, &["flatten", "sys", "minimal"], None);
}

fn write_win(output: &str, filter: &[&str]) {
riddle(output, filter, &["flatten", "minimal"]);
riddle(output, filter, &["flatten", "minimal"], None);
}

fn write_std(output: &str, filter: &[&str]) {
riddle(output, filter, &["flatten", "std", "minimal"]);
riddle(output, filter, &["flatten", "std", "minimal"], None);
}

fn write_no_inner_attr(output: &str, filter: &[&str]) {
riddle(
output,
filter,
&["flatten", "no-inner-attributes", "minimal"],
None,
);
}

fn write_vtbl(output: &str, filter: &[&str]) {
riddle(output, filter, &["flatten", "sys", "minimal", "vtbl"]);
riddle(output, filter, &["flatten", "sys", "minimal", "vtbl"], None);
}

fn riddle(output: &str, filter: &[&str], config: &[&str]) {
fn write_derives(output: &str, filter: &[&str], derives: &[&str]) {
riddle(output, filter, &["flatten", "minimal"], Some(derives));
}

fn riddle(output: &str, filter: &[&str], config: &[&str], derives: Option<&[&str]>) {
// Rust-analyzer may re-run build scripts whenever a source file is deleted
// which causes an endless loop if the file is deleted from a build script.
// To workaround this, we truncate the file instead of deleting it.
Expand Down Expand Up @@ -221,6 +235,11 @@ fn riddle(output: &str, filter: &[&str], config: &[&str]) {
command.arg("--config");
command.args(config);

if let Some(derives) = derives {
command.arg("--derives");
command.args(derives);
}

if !command.status().unwrap().success() {
panic!("Failed to run riddle");
}
Expand Down
45 changes: 45 additions & 0 deletions crates/tests/standalone/src/b_derives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Bindings generated by `windows-bindgen` 0.54.0

#![allow(
non_snake_case,
non_upper_case_globals,
non_camel_case_types,
dead_code,
clippy::all
)]
#[repr(C)]
#[derive(::core::cmp::PartialOrd, ::core::cmp::Ord)]
pub struct DateTime {
pub UniversalTime: i64,
}
impl ::core::marker::Copy for DateTime {}
impl ::core::clone::Clone for DateTime {
fn clone(&self) -> Self {
*self
}
}
impl ::core::fmt::Debug for DateTime {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
f.debug_struct("DateTime")
.field("UniversalTime", &self.UniversalTime)
.finish()
}
}
impl ::windows_core::TypeKind for DateTime {
type TypeKind = ::windows_core::CopyType;
}
impl ::windows_core::RuntimeType for DateTime {
const SIGNATURE: ::windows_core::imp::ConstBuffer =
::windows_core::imp::ConstBuffer::from_slice(b"struct(Windows.Foundation.DateTime;i8)");
}
impl ::core::cmp::PartialEq for DateTime {
fn eq(&self, other: &Self) -> bool {
self.UniversalTime == other.UniversalTime
}
}
impl ::core::cmp::Eq for DateTime {}
impl ::core::default::Default for DateTime {
fn default() -> Self {
unsafe { ::core::mem::zeroed() }
}
}
18 changes: 18 additions & 0 deletions crates/tests/standalone/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod b_bstr;
mod b_calendar;
mod b_constant_types;
mod b_depends;
mod b_derives;
mod b_enumeration;
mod b_enumerator;
mod b_guid;
Expand Down Expand Up @@ -184,3 +185,20 @@ fn from_included() {
included::GetVersion();
}
}

#[test]
fn derive_ord() {
use b_derives::*;
let mut dates = [
DateTime { UniversalTime: 123 },
DateTime { UniversalTime: 42 },
];
dates.sort();
assert_eq!(
&dates,
&[
DateTime { UniversalTime: 42 },
DateTime { UniversalTime: 123 }
]
);
}
1 change: 1 addition & 0 deletions crates/tools/riddle/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Options:
--config <key=value> Override a configuration value
--format Format .rdl files only
--etc <path> File containing command line options
--derives <ty=traits> Emit a derive attribute for a type
"#
);
} else {
Expand Down

0 comments on commit 1b01493

Please sign in to comment.