Skip to content

Commit

Permalink
chore: dx12
Browse files Browse the repository at this point in the history
  • Loading branch information
FL33TW00D committed Jun 14, 2024
1 parent e1376ac commit 733fd15
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 52 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ impl crate::Scalar {
match self {
Self {
kind: Sk::Float,
width: 4 | 8,
width: 4,
} => "float",
Self {
kind: Sk::Float,
Expand Down Expand Up @@ -1304,9 +1304,9 @@ impl<W: Write> Writer<W> {
write!(self.out, "NAN")?;
} else {
let suffix = if value.fract() == f16::from_f32(0.0) {
".0"
".0h"
} else {
""
"h"
};
write!(self.out, "{value}{suffix}")?;
}
Expand Down
9 changes: 2 additions & 7 deletions naga/src/front/wgsl/parse/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,9 @@ impl PartialEq for Dependency<'_> {

impl Eq for Dependency<'_> {}

/// A module-scope declaration.
#[derive(Debug)]
pub struct GlobalDirective {
pub kind: GlobalDirectiveKind,
}

//A directive modifies how a WGSL program is processed by a WebGPU implementation.
#[derive(Debug)]
pub enum GlobalDirectiveKind {
pub enum GlobalDirective {
Enable(EnableDirective),
}

Expand Down
46 changes: 29 additions & 17 deletions naga/src/front/wgsl/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2167,6 +2167,16 @@ impl Parser {
Ok(fun)
}

fn enable_extension<'a>(
&mut self,
lexer: &mut Lexer<'a>,
) -> Result<crate::Extension, Error<'a>> {
let (ext, ext_span) = lexer.next_extension_with_span()?;
let extension = conv::map_extension(ext, ext_span)?;
lexer.add_extension(extension.clone());
Ok(extension)
}

fn global_directive<'a>(
&mut self,
lexer: &mut Lexer<'a>,
Expand All @@ -2176,27 +2186,29 @@ impl Parser {
let (_, enable_span) = lexer.next_ident_with_span()?;

let mut enable_extension_list = Vec::with_capacity(4);
let mut ready = true;
while !lexer.skip(Token::Separator(';')) {
if !ready {
return Err(Error::Unexpected(
lexer.next().1,
ExpectedToken::Token(Token::Separator(',')),
));
}
let (ext, ext_span) = lexer.next_extension_with_span()?;
let extension = conv::map_extension(ext, ext_span)?;
lexer.add_extension(extension.clone());

// Parse the first extension
let extension = self.enable_extension(lexer)?;
enable_extension_list.push(extension);

// Parse additional extensions separated by commas
while lexer.skip(Token::Separator(',')) {
let extension = self.enable_extension(lexer)?;
enable_extension_list.push(extension);
ready = lexer.skip(Token::Separator(','));
}

// Require a semicolon at the end
if !lexer.skip(Token::Separator(';')) {
return Err(Error::Unexpected(
lexer.next().1,
ExpectedToken::Token(Token::Separator(';')),
));
}

out.directives.append(
ast::GlobalDirective {
kind: ast::GlobalDirectiveKind::Enable(ast::EnableDirective {
enable_extension_list,
}),
},
ast::GlobalDirective::Enable(ast::EnableDirective {
enable_extension_list,
}),
enable_span,
);
}
Expand Down
9 changes: 0 additions & 9 deletions naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -981,15 +981,6 @@ pub enum Extension {
F16,
}

/// Enable directive
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct EnableDirective {
extension: Extension,
}

/// Variable defined at module level.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
Expand Down
20 changes: 7 additions & 13 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1491,8 +1491,9 @@ impl<'a> ConstantEvaluator<'a> {
Literal::I32(v) => v,
Literal::U32(v) => v as i32,
Literal::F32(v) => v as i32,
Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
Literal::Bool(v) => v as i32,
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) | Literal::F16(_) => {
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
return make_error();
}
Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
Expand All @@ -1502,8 +1503,9 @@ impl<'a> ConstantEvaluator<'a> {
Literal::I32(v) => v as u32,
Literal::U32(v) => v,
Literal::F32(v) => v as u32,
Literal::F16(v) => f16::to_u32(&v).unwrap(), //Only None on NaN or Inf
Literal::Bool(v) => v as u32,
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) | Literal::F16(_) => {
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
return make_error();
}
Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
Expand All @@ -1517,11 +1519,7 @@ impl<'a> ConstantEvaluator<'a> {
Literal::F64(v) => v as i64,
Literal::I64(v) => v,
Literal::U64(v) => v as i64,
Literal::F16(v) => f16::to_i64(&v).ok_or(
ConstantEvaluatorError::AutomaticConversionFloatToInt {
to_type: "i64",
},
)?,
Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
}),
Expand All @@ -1533,11 +1531,7 @@ impl<'a> ConstantEvaluator<'a> {
Literal::F64(v) => v as u64,
Literal::I64(v) => v as u64,
Literal::U64(v) => v,
Literal::F16(v) => f16::to_u64(&v).ok_or(
ConstantEvaluatorError::AutomaticConversionFloatToInt {
to_type: "u64",
},
)?,
Literal::F16(v) => f16::to_u64(&v).unwrap(), //Only None on NaN or Inf
Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
}),
Expand Down Expand Up @@ -1583,11 +1577,11 @@ impl<'a> ConstantEvaluator<'a> {
Literal::I32(v) => v != 0,
Literal::U32(v) => v != 0,
Literal::F32(v) => v != 0.0,
Literal::F16(v) => v != f16::zero(),
Literal::Bool(v) => v,
Literal::F64(_)
| Literal::I64(_)
| Literal::U64(_)
| Literal::F16(_)
| Literal::AbstractInt(_)
| Literal::AbstractFloat(_) => {
return make_error();
Expand Down
2 changes: 2 additions & 0 deletions naga/src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ bitflags::bitflags! {
const SHADER_INT64_ATOMIC_MIN_MAX = 0x40000;
/// Support for all atomic operations on 64-bit integers.
const SHADER_INT64_ATOMIC_ALL_OPS = 0x80000;
/// Support for 16-bit floating-point types.
const SHADER_FLOAT16 = 0x100000;
}
}

Expand Down
2 changes: 1 addition & 1 deletion naga/src/valid/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ impl super::Validator {
true
}
2 => {
if !self.capabilities.contains(Capabilities::FLOAT16) {
if !self.capabilities.contains(Capabilities::SHADER_FLOAT16) {
return Err(WidthError::MissingCapability {
name: "f16",
flag: "FLOAT16",
Expand Down
5 changes: 4 additions & 1 deletion wgpu-core/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,10 @@ pub fn create_validator(
features.contains(wgt::Features::PUSH_CONSTANTS),
);
caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64));
caps.set(Caps::FLOAT16, features.contains(wgt::Features::SHADER_F16));
caps.set(
Caps::SHADER_FLOAT16,
features.contains(wgt::Features::SHADER_F16),
);
caps.set(
Caps::PRIMITIVE_INDEX,
features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX),
Expand Down
15 changes: 14 additions & 1 deletion wgpu-hal/src/dx12/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,22 @@ impl super::Adapter {
&& features1.Int64ShaderOps != 0,
);

let float16_supported = {
let mut features4: crate::dx12::types::D3D12_FEATURE_DATA_D3D12_OPTIONS4 =
unsafe { mem::zeroed() };
let hr = unsafe {
device.CheckFeatureSupport(
23, // D3D12_FEATURE_D3D12_OPTIONS4: https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_feature#syntax
&mut features4 as *mut _ as *mut _,
mem::size_of::<d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS4>() as _,
)
};
hr == 0 && features4.Native16BitShaderOpsSupported != 0
};

features.set(
wgt::Features::SHADER_F16,
shader_model >= naga::back::hlsl::ShaderModel::V6_2 && hr == 0,
shader_model >= naga::back::hlsl::ShaderModel::V6_2 && float16_supported,
);

features.set(
Expand Down
16 changes: 16 additions & 0 deletions wgpu-hal/src/dx12/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ winapi::ENUM! {
}
}

winapi::ENUM! {
enum D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER {
D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER_0 = 0,
D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER_1,
D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER_2
}
}

winapi::STRUCT! {
struct D3D12_FEATURE_DATA_D3D12_OPTIONS4 {
MSAA64KBAlignedTextureSupported: winapi::shared::minwindef::BOOL,
SharedResourceCompatibilityTier: D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER,
Native16BitShaderOpsSupported: winapi::shared::minwindef::BOOL,
}
}

winapi::STRUCT! {
struct D3D12_FEATURE_DATA_D3D12_OPTIONS9 {
MeshShaderPipelineStatsSupported: winapi::shared::minwindef::BOOL,
Expand Down

0 comments on commit 733fd15

Please sign in to comment.