Skip to content

[HLSL][RootSignature] Allow for multiple parsing errors in RootSignatureParser #147832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions clang/include/clang/Parse/ParseHLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,21 @@ class RootSignatureParser {
bool tryConsumeExpectedToken(RootSignatureToken::Kind Expected);
bool tryConsumeExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected);

/// Consume tokens until the expected token has been peeked to be next
/// or we have reached the end of the stream. Note that this means the
/// expected token will be the next token not CurToken.
///
/// Returns true if it found a token of the given type.
bool skipUntilExpectedToken(RootSignatureToken::Kind Expected);
bool skipUntilExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected);

/// Consume tokens until we reach a closing right paren, ')', or, until we
/// have reached the end of the stream. This will place the current token
/// to be the end of stream or the right paren.
///
/// Returns true if it is closed before the end of stream.
bool skipUntilClosedParens(uint32_t NumParens = 1);

/// Convert the token's offset in the signature string to its SourceLocation
///
/// This allows to currently retrieve the location for multi-token
Expand Down
103 changes: 88 additions & 15 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ namespace hlsl {

using TokenKind = RootSignatureToken::Kind;

static const TokenKind RootElementKeywords[] = {
TokenKind::kw_RootFlags,
TokenKind::kw_CBV,
TokenKind::kw_UAV,
TokenKind::kw_SRV,
TokenKind::kw_DescriptorTable,
TokenKind::kw_StaticSampler,
};

RootSignatureParser::RootSignatureParser(
llvm::dxbc::RootSignatureVersion Version,
SmallVector<RootSignatureElement> &Elements, StringLiteral *Signature,
Expand All @@ -27,51 +36,76 @@ RootSignatureParser::RootSignatureParser(
bool RootSignatureParser::parse() {
// Iterate as many RootSignatureElements as possible, until we hit the
// end of the stream
bool HadError = false;
while (!peekExpectedToken(TokenKind::end_of_stream)) {
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Flags = parseRootFlags();
if (!Flags.has_value())
return true;
if (!Flags.has_value()) {
HadError = true;
skipUntilExpectedToken(RootElementKeywords);
continue;
}

Elements.emplace_back(ElementLoc, *Flags);
} else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Constants = parseRootConstants();
if (!Constants.has_value())
return true;
if (!Constants.has_value()) {
HadError = true;
skipUntilExpectedToken(RootElementKeywords);
continue;
}
Elements.emplace_back(ElementLoc, *Constants);
} else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Table = parseDescriptorTable();
if (!Table.has_value())
return true;
if (!Table.has_value()) {
HadError = true;
// We are within a DescriptorTable, we will do our best to recover
// by skipping until we encounter the expected closing ')'.
skipUntilClosedParens();
consumeNextToken();
skipUntilExpectedToken(RootElementKeywords);
continue;
}
Elements.emplace_back(ElementLoc, *Table);
} else if (tryConsumeExpectedToken(
{TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Descriptor = parseRootDescriptor();
if (!Descriptor.has_value())
return true;
if (!Descriptor.has_value()) {
HadError = true;
skipUntilExpectedToken(RootElementKeywords);
continue;
}
Elements.emplace_back(ElementLoc, *Descriptor);
} else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Sampler = parseStaticSampler();
if (!Sampler.has_value())
return true;
if (!Sampler.has_value()) {
HadError = true;
skipUntilExpectedToken(RootElementKeywords);
continue;
}
Elements.emplace_back(ElementLoc, *Sampler);
} else {
HadError = true;
consumeNextToken(); // let diagnostic be at the start of invalid token
reportDiag(diag::err_hlsl_invalid_token)
<< /*parameter=*/0 << /*param of*/ TokenKind::kw_RootSignature;
return true;
skipUntilExpectedToken(RootElementKeywords);
continue;
}

// ',' denotes another element, otherwise, expected to be at end of stream
if (!tryConsumeExpectedToken(TokenKind::pu_comma))
if (!tryConsumeExpectedToken(TokenKind::pu_comma)) {
// ',' denotes another element, otherwise, expected to be at end of stream
break;
}
}

return consumeExpectedToken(TokenKind::end_of_stream,
return HadError ||
consumeExpectedToken(TokenKind::end_of_stream,
diag::err_expected_either, TokenKind::pu_comma);
}

Expand Down Expand Up @@ -262,8 +296,13 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
// DescriptorTableClause - CBV, SRV, UAV, or Sampler
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Clause = parseDescriptorTableClause();
if (!Clause.has_value())
if (!Clause.has_value()) {
// We are within a DescriptorTableClause, we will do our best to recover
// by skipping until we encounter the expected closing ')'
skipUntilExpectedToken(TokenKind::pu_r_paren);
consumeNextToken();
return std::nullopt;
}
Elements.emplace_back(ElementLoc, *Clause);
Table.NumClauses++;
} else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
Expand Down Expand Up @@ -1371,6 +1410,40 @@ bool RootSignatureParser::tryConsumeExpectedToken(
return true;
}

bool RootSignatureParser::skipUntilExpectedToken(TokenKind Expected) {
return skipUntilExpectedToken(ArrayRef{Expected});
}

bool RootSignatureParser::skipUntilExpectedToken(
ArrayRef<TokenKind> AnyExpected) {

while (!peekExpectedToken(AnyExpected)) {
if (peekExpectedToken(TokenKind::end_of_stream))
return false;
consumeNextToken();
}

return true;
}

bool RootSignatureParser::skipUntilClosedParens(uint32_t NumParens) {
TokenKind ParenKinds[] = {
TokenKind::pu_l_paren,
TokenKind::pu_r_paren,
};
while (skipUntilExpectedToken(ParenKinds)) {
consumeNextToken();
if (CurToken.TokKind == TokenKind::pu_r_paren)
NumParens--;
else
NumParens++;
if (NumParens == 0)
return true;
}

return false;
}

SourceLocation RootSignatureParser::getTokenLocation(RootSignatureToken Tok) {
return Signature->getLocationOfByte(Tok.LocOffset, PP.getSourceManager(),
PP.getLangOpts(), PP.getTargetInfo());
Expand Down
53 changes: 52 additions & 1 deletion clang/test/SemaHLSL/RootSignature-err.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,57 @@ void bad_root_signature_22() {}
[RootSignature("RootFlags(local_root_signature | root_flag_typo)")]
void bad_root_signature_23() {}

#define DemoMultipleErrorsRootSignature \
"CBV(b0, space = invalid)," \
"StaticSampler()" \
"DescriptorTable(" \
" visibility = SHADER_VISIBILITY_ALL," \
" visibility = SHADER_VISIBILITY_DOMAIN," \
")," \
"SRV(t0, space = 28947298374912374098172)" \
"UAV(u0, flags = 3)" \
"DescriptorTable(Sampler(s0 flags = DATA_VOLATILE))," \
"CBV(b0),,"

// expected-error@+7 {{expected integer literal after '='}}
// expected-error@+6 {{did not specify mandatory parameter 's register'}}
// expected-error@+5 {{specified the same parameter 'visibility' multiple times}}
// expected-error@+4 {{integer literal is too large to be represented as a 32-bit signed integer type}}
// expected-error@+3 {{flag value is neither a literal 0 nor a named value}}
// expected-error@+2 {{expected ')' or ','}}
// expected-error@+1 {{invalid parameter of RootSignature}}
[RootSignature(DemoMultipleErrorsRootSignature)]
void multiple_errors() {}

#define DemoGranularityRootSignature \
"CBV(b0, reported_diag, flags = skipped_diag)," \
"DescriptorTable( " \
" UAV(u0, reported_diag), " \
" SRV(t0, skipped_diag), " \
")," \
"StaticSampler(s0, reported_diag, SRV(t0, reported_diag)" \
""

// expected-error@+4 {{invalid parameter of CBV}}
// expected-error@+3 {{invalid parameter of UAV}}
// expected-error@+2 {{invalid parameter of StaticSampler}}
// expected-error@+1 {{invalid parameter of SRV}}
[RootSignature(DemoGranularityRootSignature)]
void granularity_errors() {}

#define TestTableScope \
"DescriptorTable( " \
" UAV(u0, reported_diag), " \
" SRV(t0, skipped_diag), " \
" Sampler(s0, skipped_diag), " \
")," \
"CBV(s0, reported_diag)"

// expected-error@+2 {{invalid parameter of UAV}}
// expected-error@+1 {{invalid parameter of CBV}}
[RootSignature(TestTableScope)]
void recover_scope_errors() {}

// Basic validation of register value and space

// expected-error@+2 {{value must be in the range [0, 4294967294]}}
Expand Down Expand Up @@ -138,4 +189,4 @@ void basic_validation_5() {}

// expected-error@+1 {{value must be in the range [-16.00, 15.99]}}
[RootSignature("StaticSampler(s0, mipLODBias = 15.990001)")]
void basic_validation_6() {}
void basic_validation_6() {}