DRAFT: Start generating types/commands per feature
MarijnS95 committed Dec 8, 2024
1 parent 97eab2c commit 3229342
Showing 1 changed file with 132 additions and 68 deletions.
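
In short, this draft replaces the generator's flat HashSet<&str> allow-lists of required type and command names with maps whose values record which feature or extension required each item, via the new FeatureDescription. Below is a minimal standalone sketch of that shape, not part of the commit: the struct is simplified (the real one also borrows the feature's children from vk_parse) and the sample entries are invented purely for illustration.

use std::collections::HashMap;

// Simplified stand-in for the diff's `FeatureDescription<'a>`; the real struct
// additionally borrows the feature's `children: &'a Vec<vk_parse::ExtensionChild>`.
#[derive(Clone, Debug)]
struct FeatureDescription<'a> {
    name: &'a str,
    provisional: bool,
}

fn main() {
    // Previously the generator only kept a HashSet<&str> of required names;
    // now each required command name maps to the feature/extension that pulled it in.
    let mut required_commands: HashMap<&str, FeatureDescription<'_>> = HashMap::new();
    required_commands.insert(
        "vkCreateDevice",
        FeatureDescription {
            name: "VK_VERSION_1_0",
            provisional: false,
        },
    );

    // Downstream generation steps can now ask not just *whether* a command is
    // required, but *which* feature requires it (and, e.g., whether it is provisional).
    if let Some(feature) = required_commands.get("vkCreateDevice") {
        println!(
            "vkCreateDevice is required by {} (provisional: {})",
            feature.name, feature.provisional
        );
    }
}
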
200 changes: 132 additions & 68 deletions generator/src/lib.rs
@@ -912,22 +912,33 @@ impl FieldExt for vk_parse::CommandParam {
}
}

pub type CommandMap<'a> = HashMap<vkxml::Identifier, &'a vk_parse::CommandDefinition>;
/// Minimized, generalized description of [`vk_parse::Feature`] and [`vk_parse::Extension`]
#[derive(Clone, Debug)]
pub struct FeatureDescription<'a> {
pub name: &'a str,
pub provisional: bool,
pub children: &'a Vec<vk_parse::ExtensionChild>,
}

// TODO: Why is this a map?
pub type CommandMap<'a> =
HashMap<&'a str, (&'a vk_parse::CommandDefinition, &'a FeatureDescription<'a>)>;

/// Returns (raw bindings, function pointer table)
fn generate_function_pointers<'a>(
ident: Ident,
commands: &[&'a vk_parse::CommandDefinition],
rename_commands: &HashMap<&'a str, &'a str>,
commands: &[(&'a vk_parse::CommandDefinition, &'a FeatureDescription<'_>)],
rename_commands: &HashMap<&str, &str>,
fn_cache: &mut HashSet<&'a str>,
has_lifetimes: &HashSet<Ident>,
doc: &str,
) -> (TokenStream, TokenStream) {
// Commands can have duplicates inside them because they are declared per feature. But we only
// really want to generate one function pointer.
// TODO: Didn't we have a map of this to make them unique already...?
let commands = commands
.iter()
.unique_by(|cmd| cmd.proto.name.as_str())
.unique_by(|cmd| cmd.0.proto.name.as_str())
.collect::<Vec<_>>();

struct Command<'a> {
@@ -945,7 +956,7 @@ fn generate_function_pointers<'a>(
let commands = commands
.iter()
.map(|cmd| {
let name = &cmd.proto.name;
let name = &cmd.0.proto.name;
let pfn_type_name = format_ident!("PFN_{}", name);

// We might need to generate a function pointer for an extension, where we are given the original
@@ -957,6 +968,7 @@ fn generate_function_pointers<'a>(
let type_name = format_ident!("{}", type_name);

let params = cmd
.0
.params
.iter()
.filter(|param| matches!(param.api.as_deref(), None | Some(DESIRED_API)));
@@ -992,6 +1004,7 @@ fn generate_function_pointers<'a>(
.collect();

let ret = cmd
.0
.proto
.type_name
.as_ref()
@@ -1307,6 +1320,7 @@ pub fn generate_extension_commands<'a>(
let mut instance_commands = Vec::new();
let mut device_commands = Vec::new();

// TODO: Is this the only reason why commands are stored in a map by name?
let mut rename_commands = HashMap::new();
let names = extension
.children
@@ -1330,7 +1344,7 @@ pub fn generate_extension_commands<'a>(
}

let command = cmd_map[name];
match command.function_type() {
match command.0.function_type() {
FunctionType::Static | FunctionType::Entry => unreachable!(),
FunctionType::Instance => instance_commands.push(command),
FunctionType::Device => device_commands.push(command),
@@ -1465,7 +1479,7 @@ pub fn generate_extension_commands<'a>(

pub fn generate_define(
define: &vk_parse::Type,
allowed_types: &HashSet<&str>,
allowed_types: &HashMap<&str, FeatureDescription<'_>>,
identifier_renames: &mut BTreeMap<String, Ident>,
) -> TokenStream {
let vk_parse::TypeSpec::Code(spec) = &define.spec else {
@@ -1475,7 +1489,7 @@ pub fn generate_define(
return quote!();
};

if !allowed_types.contains(define_name.as_str()) {
if !allowed_types.contains_key(define_name.as_str()) {
return quote!();
}

@@ -2713,7 +2727,7 @@ fn generate_union(union: &vkxml::Union, has_lifetimes: &HashSet<Ident>) -> Token
/// Root structs are all structs that are extended by other structs.
pub fn root_structs(
definitions: &[&vk_parse::Type],
allowed_types: &HashSet<&str>,
allowed_types: &HashMap<&str, FeatureDescription<'_>>,
) -> HashSet<Ident> {
// Loop over all structs and collect their extends
definitions
@@ -2722,7 +2736,7 @@ pub fn root_structs(
type_
.name
.as_ref()
.map_or(false, |name| allowed_types.contains(name.as_str()))
.map_or(false, |name| allowed_types.contains_key(name.as_str()))
})
.filter_map(|type_| type_.structextends.as_ref())
.flat_map(|e| e.split(','))
@@ -2731,7 +2745,7 @@ pub fn generate_definition_vk_parse(
}
pub fn generate_definition_vk_parse(
definition: &vk_parse::Type,
allowed_types: &HashSet<&str>,
allowed_types: &HashMap<&str, FeatureDescription<'_>>,
identifier_renames: &mut BTreeMap<String, Ident>,
) -> Option<TokenStream> {
if let Some(api) = &definition.api {
@@ -2752,7 +2766,7 @@ pub fn generate_definition_vk_parse(
#[allow(clippy::too_many_arguments)]
pub fn generate_definition(
definition: &vkxml::DefinitionsElement,
allowed_types: &HashSet<&str>,
allowed_types: &HashMap<&str, FeatureDescription<'_>>,
union_types: &HashSet<&str>,
root_structs: &HashSet<Ident>,
has_lifetimes: &HashSet<Ident>,
@@ -2762,12 +2776,12 @@ pub fn generate_definition(
) -> Option<TokenStream> {
match *definition {
vkxml::DefinitionsElement::Typedef(ref typedef)
if allowed_types.contains(typedef.name.as_str()) =>
if allowed_types.contains_key(typedef.name.as_str()) =>
{
Some(generate_typedef(typedef))
}
vkxml::DefinitionsElement::Struct(ref struct_)
if allowed_types.contains(struct_.name.as_str()) =>
if allowed_types.contains_key(struct_.name.as_str()) =>
{
Some(generate_struct(
struct_,
@@ -2778,20 +2792,22 @@ pub fn generate_definition(
))
}
vkxml::DefinitionsElement::Bitmask(ref mask)
if allowed_types.contains(mask.name.as_str()) =>
if allowed_types.contains_key(mask.name.as_str()) =>
{
generate_bitmask(mask, bitflags_cache, const_values)
}
vkxml::DefinitionsElement::Handle(ref handle)
if allowed_types.contains(handle.name.as_str()) =>
if allowed_types.contains_key(handle.name.as_str()) =>
{
generate_handle(handle)
}
vkxml::DefinitionsElement::FuncPtr(ref fp) if allowed_types.contains(fp.name.as_str()) => {
vkxml::DefinitionsElement::FuncPtr(ref fp)
if allowed_types.contains_key(fp.name.as_str()) =>
{
Some(generate_funcptr(fp, has_lifetimes))
}
vkxml::DefinitionsElement::Union(ref union)
if allowed_types.contains(union.name.as_str()) =>
if allowed_types.contains_key(union.name.as_str()) =>
{
Some(generate_union(union, has_lifetimes))
}
@@ -2800,7 +2816,7 @@
}
pub fn generate_feature<'a>(
feature: &vkxml::Feature,
commands: &CommandMap<'a>,
commands: &'a CommandMap<'a>,
fn_cache: &mut HashSet<&'a str>,
has_lifetimes: &HashSet<Ident>,
) -> (TokenStream, TokenStream) {
@@ -2814,11 +2830,11 @@ pub fn generate_feature<'a>(
.filter_map(get_variant!(vkxml::FeatureElement::Require))
.flat_map(|spec| &spec.elements)
.filter_map(get_variant!(vkxml::FeatureReference::CommandReference))
.filter_map(|cmd_ref| commands.get(&cmd_ref.name))
.filter_map(|cmd_ref| commands.get(cmd_ref.name.as_str()))
.fold(
(Vec::new(), Vec::new(), Vec::new(), Vec::new()),
|mut accs, &cmd_ref| {
let acc = match cmd_ref.function_type() {
let acc = match cmd_ref.0.function_type() {
FunctionType::Static => &mut accs.0,
FunctionType::Entry => &mut accs.1,
FunctionType::Device => &mut accs.2,
@@ -3078,37 +3094,23 @@ pub fn extract_native_types(registry: &vk_parse::Registry) -> (Vec<(String, Stri

(header_includes, header_types)
}
pub fn generate_aliases_of_types(
types: &vk_parse::Types,
allowed_types: &HashSet<&str>,
pub fn generate_alias_of_type(
ty: &vk_parse::Type,
has_lifetimes: &HashSet<Ident>,
ty_cache: &mut HashSet<Ident>,
) -> TokenStream {
let aliases = types
.children
.iter()
.filter_map(get_variant!(vk_parse::TypesChild::Type))
.filter_map(|ty| {
let name = ty.name.as_ref()?;
if !allowed_types.contains(name.as_str()) {
return None;
}
let alias = ty.alias.as_ref()?;
let name_ident = name_to_tokens(name);
if !ty_cache.insert(name_ident.clone()) {
return None;
};
let alias_ident = name_to_tokens(alias);
let tokens = if has_lifetimes.contains(&alias_ident) {
quote!(pub type #name_ident<'a> = #alias_ident<'a>;)
} else {
quote!(pub type #name_ident = #alias_ident;)
};
Some(tokens)
});
quote! {
#(#aliases)*
}
) -> Option<TokenStream> {
let name = ty.name.as_ref()?;
let alias = ty.alias.as_ref()?;
let name_ident = name_to_tokens(name);
if !ty_cache.insert(name_ident.clone()) {
return None;
};
let alias_ident = name_to_tokens(alias);
Some(if has_lifetimes.contains(&alias_ident) {
quote!(pub type #name_ident<'a> = #alias_ident<'a>;)
} else {
quote!(pub type #name_ident = #alias_ident;)
})
}
pub fn write_source_code<P: AsRef<Path>>(vk_headers_dir: &Path, src_dir: P) {
let vk_xml = vk_headers_dir.join("registry/vk.xml");
Expand Down Expand Up @@ -3161,30 +3163,80 @@ pub fn write_source_code<P: AsRef<Path>>(vk_headers_dir: &Path, src_dir: P) {
.flat_map(|constants| &constants.elements)
.collect();

// let features_children = spec2
// .0
// .iter()
// .filter_map(get_variant!(vk_parse::RegistryChild::Feature))
// .filter(|feature| contains_desired_api(&feature.api))
// .flat_map(|features| &features.children);

// let extension_children = extensions.iter().flat_map(|extension| &extension.children);

// let (required_types, required_commands) = features_children
// .chain(extension_children)
// .filter_map(get_variant!(vk_parse::FeatureChild::Require { api, items }))
// .filter(|(api, _items)| matches!(api.as_deref(), None | Some(DESIRED_API)))
// .flat_map(|(_api, items)| items)
// .fold((HashSet::new(), HashSet::new()), |mut acc, elem| {
// match elem {
// vk_parse::InterfaceItem::Type { name, .. } => {
// acc.0.insert(name.as_str());
// }
// vk_parse::InterfaceItem::Command { name, .. } => {
// acc.1.insert(name.as_str());
// }
// _ => {}
// };
// acc
// });

let features_children = spec2
.0
.iter()
.filter_map(get_variant!(vk_parse::RegistryChild::Feature))
.filter(|feature| contains_desired_api(&feature.api))
.flat_map(|features| &features.children);
.map(|features| FeatureDescription {
name: &features.name,
provisional: false,
children: &features.children,
});

let extension_children = extensions.iter().flat_map(|extension| &extension.children);
let extension_children = extensions.iter().map(|extension| FeatureDescription {
name: &extension.name,
provisional: extension.provisional,
children: &extension.children,
});

let (required_types, required_commands) = features_children
.chain(extension_children)
.filter_map(get_variant!(vk_parse::FeatureChild::Require { api, items }))
.filter(|(api, _items)| matches!(api.as_deref(), None | Some(DESIRED_API)))
.flat_map(|(_api, items)| items)
.fold((HashSet::new(), HashSet::new()), |mut acc, elem| {
match elem {
vk_parse::InterfaceItem::Type { name, .. } => {
acc.0.insert(name.as_str());
// .filter_map(get_variant!(vk_parse::FeatureChild::Require { api, items }))
// .filter(|(api, _items)| matches!(api.as_deref(), None | Some(DESIRED_API)))
// .flat_map(|(_api, items)| items)
.fold((HashMap::new(), HashMap::new()), |mut acc, feature| {
for child in feature.children {
let vk_parse::FeatureChild::Require { api, items, .. } = child else {
continue;
};
if !matches!(api.as_deref(), None | Some(DESIRED_API)) {
continue;
}
vk_parse::InterfaceItem::Command { name, .. } => {
acc.1.insert(name.as_str());
for elem in items {
let prev = match elem {
vk_parse::InterfaceItem::Type { name, .. } => {
acc.0.insert(name.as_str(), feature.clone())
}
vk_parse::InterfaceItem::Command { name, .. } => {
acc.1.insert(name.as_str(), feature.clone())
}
_ => continue,
};
if let Some(prev) = &prev {
dbg!(feature.name, prev.name);
}
// dbg!(&prev.name);
// assert!(prev.is_none());
}
_ => {}
};
}
acc
});

@@ -3194,8 +3246,11 @@ pub fn write_source_code<P: AsRef<Path>>(vk_headers_dir: &Path, src_dir: P) {
.filter_map(get_variant!(vk_parse::RegistryChild::Commands))
.flat_map(|cmds| &cmds.children)
.filter_map(get_variant!(vk_parse::Command::Definition))
.filter(|cmd| required_commands.contains(&cmd.proto.name.as_str()))
.map(|cmd| (cmd.proto.name.clone(), cmd))
.filter_map(|cmd| {
required_commands
.get(cmd.proto.name.as_str())
.map(|feature| (cmd.proto.name.as_str(), (cmd, feature)))
})
.collect();

let cmd_aliases: HashMap<_, _> = spec2
Expand All @@ -3204,7 +3259,8 @@ pub fn write_source_code<P: AsRef<Path>>(vk_headers_dir: &Path, src_dir: P) {
.filter_map(get_variant!(vk_parse::RegistryChild::Commands))
.flat_map(|cmds| &cmds.children)
.filter_map(get_variant!(vk_parse::Command::Alias { name, alias }))
.filter(|(name, _alias)| required_commands.contains(name.as_str()))
// TODO: Pass through feature info
.filter(|(name, _alias)| required_commands.contains_key(name.as_str()))
.map(|(name, alias)| (name.as_str(), alias.as_str()))
.collect();

@@ -3221,7 +3277,8 @@ pub fn write_source_code<P: AsRef<Path>>(vk_headers_dir: &Path, src_dir: P) {
.filter(|enums| enums.kind.is_some())
.filter(|enums| {
enums.name.as_ref().map_or(true, |n| {
required_types.contains(n.replace("FlagBits", "Flags").as_str())
// TODO: pass through feature info?
required_types.contains_key(n.replace("FlagBits", "Flags").as_str())
})
})
.map(|e| generate_enum(e, &mut const_cache, &mut const_values, &mut bitflags_cache))
@@ -3384,7 +3441,14 @@ pub fn write_source_code<P: AsRef<Path>>(vk_headers_dir: &Path, src_dir: P) {
.0
.iter()
.filter_map(get_variant!(vk_parse::RegistryChild::Types))
.map(|ty| generate_aliases_of_types(ty, &required_types, &has_lifetimes, &mut ty_cache))
.flat_map(|types| &types.children)
.filter_map(get_variant!(vk_parse::TypesChild::Type))
.filter(|ty| {
ty.name
.as_ref()
.is_some_and(|name| required_types.contains_key(name.as_str()))
})
.filter_map(|ty| generate_alias_of_type(ty, &has_lifetimes, &mut ty_cache))
.collect();

let (feature_fp_code, feature_table_code): (Vec<_>, Vec<_>) = features
Expand Down

0 comments on commit 3229342

Please sign in to comment.