diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 53f59349ae029..742b2c55bcce9 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Frontend/HLSL/RootSignatureMetadata.h" +#include "llvm/BinaryFormat/DXContainer.h" #include "llvm/Frontend/HLSL/RootSignatureValidations.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" @@ -559,11 +560,17 @@ bool MetadataParser::validateRootSignature( assert(dxbc::isValidParameterType(Info.Header.ParameterType) && "Invalid value for ParameterType"); - switch (Info.Header.ParameterType) { + dxbc::RootParameterType PT = + static_cast(Info.Header.ParameterType); - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): { + switch (PT) { + case dxbc::RootParameterType::Constants32Bit: + // ToDo: Add proper validation. + continue; + + case dxbc::RootParameterType::CBV: + case dxbc::RootParameterType::UAV: + case dxbc::RootParameterType::SRV: { const dxbc::RTS0::v2::RootDescriptor &Descriptor = RSD.ParametersContainer.getRootDescriptor(Info.Location); if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) @@ -580,7 +587,7 @@ bool MetadataParser::validateRootSignature( } break; } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + case dxbc::RootParameterType::DescriptorTable: { const mcdxbc::DescriptorTable &Table = RSD.ParametersContainer.getDescriptorTable(Info.Location); for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index f11c7d2033bfb..e55d3ec15c19d 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -53,10 +53,9 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { bool verifyRangeType(uint32_t Type) { switch (Type) { - case llvm::to_underlying(dxbc::DescriptorRangeType::CBV): - case llvm::to_underlying(dxbc::DescriptorRangeType::SRV): - case llvm::to_underlying(dxbc::DescriptorRangeType::UAV): - case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler): +#define DESCRIPTOR_RANGE(Num, Val) \ + case llvm::to_underlying(dxbc::DescriptorRangeType::Val): +#include "llvm/BinaryFormat/DXContainerConstants.def" return true; }; diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp index 482280b5ef289..c94a39f80eeb2 100644 --- a/llvm/lib/MC/DXContainerRootSignature.cpp +++ b/llvm/lib/MC/DXContainerRootSignature.cpp @@ -8,6 +8,7 @@ #include "llvm/MC/DXContainerRootSignature.h" #include "llvm/ADT/SmallString.h" +#include "llvm/BinaryFormat/DXContainer.h" #include "llvm/Support/EndianStream.h" using namespace llvm; @@ -35,20 +36,26 @@ size_t RootSignatureDesc::getSize() const { StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler); for (const RootParameterInfo &I : ParametersContainer) { - switch (I.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): + if (!dxbc::isValidParameterType(I.Header.ParameterType)) + continue; + + dxbc::RootParameterType PT = + static_cast(I.Header.ParameterType); + + switch (PT) { + case dxbc::RootParameterType::Constants32Bit: Size += sizeof(dxbc::RTS0::v1::RootConstants); break; - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): + case dxbc::RootParameterType::CBV: + case dxbc::RootParameterType::SRV: + case dxbc::RootParameterType::UAV: if (Version == 1) Size += sizeof(dxbc::RTS0::v1::RootDescriptor); else Size += sizeof(dxbc::RTS0::v2::RootDescriptor); break; - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): + case dxbc::RootParameterType::DescriptorTable: const DescriptorTable &Table = ParametersContainer.getDescriptorTable(I.Location); @@ -97,8 +104,12 @@ void RootSignatureDesc::write(raw_ostream &OS) const { for (size_t I = 0; I < NumParameters; ++I) { rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]); const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I); - switch (Type) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): { + if (!dxbc::isValidParameterType(Type)) + continue; + dxbc::RootParameterType PT = static_cast(Type); + + switch (PT) { + case dxbc::RootParameterType::Constants32Bit: { const dxbc::RTS0::v1::RootConstants &Constants = ParametersContainer.getConstant(Loc); support::endian::write(BOS, Constants.ShaderRegister, @@ -109,9 +120,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const { llvm::endianness::little); break; } - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): { + case dxbc::RootParameterType::CBV: + case dxbc::RootParameterType::SRV: + case dxbc::RootParameterType::UAV: { const dxbc::RTS0::v2::RootDescriptor &Descriptor = ParametersContainer.getRootDescriptor(Loc); @@ -123,7 +134,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const { support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little); break; } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + case dxbc::RootParameterType::DescriptorTable: { const DescriptorTable &Table = ParametersContainer.getDescriptorTable(Loc); support::endian::write(BOS, (uint32_t)Table.Ranges.size(), diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp index 043b575a43b11..35e1c3e3b5953 100644 --- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp +++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp @@ -278,8 +278,19 @@ void DXContainerWriter::writeParts(raw_ostream &OS) { dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility, L.Header.Offset}; - switch (L.Header.Type) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): { + if (!dxbc::isValidParameterType(L.Header.Type)) { + // Handling invalid parameter type edge case. We intentionally let + // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order + // for that to be used as a testing tool more effectively. + RS.ParametersContainer.addInvalidParameter(Header); + continue; + } + + dxbc::RootParameterType ParameterType = + static_cast(L.Header.Type); + + switch (ParameterType) { + case dxbc::RootParameterType::Constants32Bit: { const DXContainerYAML::RootConstantsYaml &ConstantYaml = P.RootSignature->Parameters.getOrInsertConstants(L); dxbc::RTS0::v1::RootConstants Constants; @@ -289,9 +300,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) { RS.ParametersContainer.addParameter(Header, Constants); break; } - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): { + case dxbc::RootParameterType::CBV: + case dxbc::RootParameterType::SRV: + case dxbc::RootParameterType::UAV: { const DXContainerYAML::RootDescriptorYaml &DescriptorYaml = P.RootSignature->Parameters.getOrInsertDescriptor(L); @@ -303,7 +314,7 @@ void DXContainerWriter::writeParts(raw_ostream &OS) { RS.ParametersContainer.addParameter(Header, Descriptor); break; } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + case dxbc::RootParameterType::DescriptorTable: { const DXContainerYAML::DescriptorTableYaml &TableYaml = P.RootSignature->Parameters.getOrInsertTable(L); mcdxbc::DescriptorTable Table; @@ -323,11 +334,6 @@ void DXContainerWriter::writeParts(raw_ostream &OS) { RS.ParametersContainer.addParameter(Header, Table); break; } - default: - // Handling invalid parameter type edge case. We intentionally let - // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order - // for that to be used as a testing tool more effectively. - RS.ParametersContainer.addInvalidParameter(Header); } } diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp index 263f7bdf37bca..aca605e099535 100644 --- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp +++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp @@ -424,22 +424,29 @@ void MappingContextTraits(L.Header.Type); + switch (PT) { + case dxbc::RootParameterType::Constants32Bit: { DXContainerYAML::RootConstantsYaml &Constants = S.Parameters.getOrInsertConstants(L); IO.mapRequired("Constants", Constants); break; } - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): { + case dxbc::RootParameterType::CBV: + case dxbc::RootParameterType::SRV: + case dxbc::RootParameterType::UAV: { DXContainerYAML::RootDescriptorYaml &Descriptor = S.Parameters.getOrInsertDescriptor(L); IO.mapRequired("Descriptor", Descriptor); break; } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + case dxbc::RootParameterType::DescriptorTable: { DXContainerYAML::DescriptorTableYaml &Table = S.Parameters.getOrInsertTable(L); IO.mapRequired("Table", Table); diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index ebdfcaa566b51..04c7c77953a5b 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -175,8 +175,10 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, OS << "- Parameter Type: " << Type << "\n" << " Shader Visibility: " << Header.ShaderVisibility << "\n"; - switch (Type) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): { + assert(dxbc::isValidParameterType(Type) && "Invalid Parameter Type"); + dxbc::RootParameterType PT = static_cast(Type); + switch (PT) { + case dxbc::RootParameterType::Constants32Bit: { const dxbc::RTS0::v1::RootConstants &Constants = RS.ParametersContainer.getConstant(Loc); OS << " Register Space: " << Constants.RegisterSpace << "\n" @@ -184,9 +186,9 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, << " Num 32 Bit Values: " << Constants.Num32BitValues << "\n"; break; } - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): { + case dxbc::RootParameterType::CBV: + case dxbc::RootParameterType::UAV: + case dxbc::RootParameterType::SRV: { const dxbc::RTS0::v2::RootDescriptor &Descriptor = RS.ParametersContainer.getRootDescriptor(Loc); OS << " Register Space: " << Descriptor.RegisterSpace << "\n" @@ -195,7 +197,7 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, OS << " Flags: " << Descriptor.Flags << "\n"; break; } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + case dxbc::RootParameterType::DescriptorTable: { const mcdxbc::DescriptorTable &Table = RS.ParametersContainer.getDescriptorTable(Loc); OS << " NumRanges: " << Table.Ranges.size() << "\n";