Skip to content

Commit

Permalink
Extract a naga pub create_validator function for use in Bevy (#5606)
Browse files Browse the repository at this point in the history
  • Loading branch information
atlv24 authored Apr 26, 2024
1 parent d4f3063 commit 739905e
Showing 1 changed file with 74 additions and 66 deletions.
140 changes: 74 additions & 66 deletions wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1480,9 +1480,78 @@ impl<A: HalApi> Device<A> {
};
}

use naga::valid::Capabilities as Caps;
profiling::scope!("naga::validate");
let debug_source =
if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() {
Some(hal::DebugSource {
file_name: Cow::Owned(
desc.label
.as_ref()
.map_or("shader".to_string(), |l| l.to_string()),
),
source_code: Cow::Owned(source.clone()),
})
} else {
None
};

let info = self
.create_validator(naga::valid::ValidationFlags::all())
.validate(&module)
.map_err(|inner| {
pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError {
source,
label: desc.label.as_ref().map(|l| l.to_string()),
inner: Box::new(inner),
})
})?;

let interface =
validation::Interface::new(&module, &info, self.limits.clone(), self.features);
let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
module,
info,
debug_source,
});
let hal_desc = hal::ShaderModuleDescriptor {
label: desc.label.to_hal(self.instance_flags),
runtime_checks: desc.shader_bound_checks.runtime_checks(),
};
let raw = match unsafe {
self.raw
.as_ref()
.unwrap()
.create_shader_module(&hal_desc, hal_shader)
} {
Ok(raw) => raw,
Err(error) => {
return Err(match error {
hal::ShaderError::Device(error) => {
pipeline::CreateShaderModuleError::Device(error.into())
}
hal::ShaderError::Compilation(ref msg) => {
log::error!("Shader error: {}", msg);
pipeline::CreateShaderModuleError::Generation
}
})
}
};

Ok(pipeline::ShaderModule {
raw: Some(raw),
device: self.clone(),
interface: Some(interface),
info: ResourceInfo::new(desc.label.borrow_or_default(), None),
label: desc.label.borrow_or_default().to_string(),
})
}

/// Create a validator with the given validation flags.
pub fn create_validator(
self: &Arc<Self>,
flags: naga::valid::ValidationFlags,
) -> naga::valid::Validator {
use naga::valid::Capabilities as Caps;
let mut caps = Caps::empty();
caps.set(
Caps::PUSH_CONSTANT,
Expand Down Expand Up @@ -1560,20 +1629,6 @@ impl<A: HalApi> Device<A> {
self.features.intersects(wgt::Features::SUBGROUP_BARRIER),
);

let debug_source =
if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() {
Some(hal::DebugSource {
file_name: Cow::Owned(
desc.label
.as_ref()
.map_or("shader".to_string(), |l| l.to_string()),
),
source_code: Cow::Owned(source.clone()),
})
} else {
None
};

let mut subgroup_stages = naga::valid::ShaderStages::empty();
subgroup_stages.set(
naga::valid::ShaderStages::COMPUTE | naga::valid::ShaderStages::FRAGMENT,
Expand All @@ -1590,57 +1645,10 @@ impl<A: HalApi> Device<A> {
} else {
naga::valid::SubgroupOperationSet::empty()
};

let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps)
.subgroup_stages(subgroup_stages)
.subgroup_operations(subgroup_operations)
.validate(&module)
.map_err(|inner| {
pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError {
source,
label: desc.label.as_ref().map(|l| l.to_string()),
inner: Box::new(inner),
})
})?;

let interface =
validation::Interface::new(&module, &info, self.limits.clone(), self.features);
let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
module,
info,
debug_source,
});
let hal_desc = hal::ShaderModuleDescriptor {
label: desc.label.to_hal(self.instance_flags),
runtime_checks: desc.shader_bound_checks.runtime_checks(),
};
let raw = match unsafe {
self.raw
.as_ref()
.unwrap()
.create_shader_module(&hal_desc, hal_shader)
} {
Ok(raw) => raw,
Err(error) => {
return Err(match error {
hal::ShaderError::Device(error) => {
pipeline::CreateShaderModuleError::Device(error.into())
}
hal::ShaderError::Compilation(ref msg) => {
log::error!("Shader error: {}", msg);
pipeline::CreateShaderModuleError::Generation
}
})
}
};

Ok(pipeline::ShaderModule {
raw: Some(raw),
device: self.clone(),
interface: Some(interface),
info: ResourceInfo::new(desc.label.borrow_or_default(), None),
label: desc.label.borrow_or_default().to_string(),
})
let mut validator = naga::valid::Validator::new(flags, caps);
validator.subgroup_stages(subgroup_stages);
validator.subgroup_operations(subgroup_operations);
validator
}

#[allow(unused_unsafe)]
Expand Down

0 comments on commit 739905e

Please sign in to comment.