Skip to content

Commit

Permalink
Fix geometry shader related modifier lowering. (#6197)
Browse files Browse the repository at this point in the history
* Fix geometry shader related modifier lowering.

* Cleanup.

* Delete obselete test.

* Enable geometryShader test on windows only.

* Fix test.
  • Loading branch information
csyonghe authored Jan 28, 2025
1 parent cd27fbd commit 4b9a342
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 44 deletions.
68 changes: 24 additions & 44 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2333,6 +2333,30 @@ void addVarDecorations(IRGenContext* context, IRInst* inst, Decl* decl)
inst,
IRIntegerValue(collection->getMemoryQualifierBit()));
}
else if (auto geometryModifier = as<HLSLGeometryShaderInputPrimitiveTypeModifier>(mod))
{
IROp op = kIROp_Invalid;
switch (geometryModifier->astNodeType)
{
case ASTNodeType::HLSLTriangleModifier:
op = kIROp_TriangleInputPrimitiveTypeDecoration;
break;
case ASTNodeType::HLSLPointModifier:
op = kIROp_PointInputPrimitiveTypeDecoration;
break;
case ASTNodeType::HLSLLineModifier:
op = kIROp_LineInputPrimitiveTypeDecoration;
break;
case ASTNodeType::HLSLLineAdjModifier:
op = kIROp_LineAdjInputPrimitiveTypeDecoration;
break;
case ASTNodeType::HLSLTriangleAdjModifier:
op = kIROp_TriangleAdjInputPrimitiveTypeDecoration;
break;
}
if (op != kIROp_Invalid)
builder->addDecoration(inst, op);
}
// TODO: what are other modifiers we need to propagate through?
}
if (auto t =
Expand Down Expand Up @@ -11180,50 +11204,6 @@ static void lowerFrontEndEntryPointToIR(
entryPointName->text.getUnownedSlice(),
moduleName.getUnownedSlice());
}

// Go through the entry point parameters creating decorations from layout as appropriate
// But only if this is a definition not a declaration
if (isDefinition(instToDecorate))
{
FilteredMemberList<ParamDecl> params = entryPointFuncDecl->getParameters();

IRGlobalValueWithParams* valueWithParams = as<IRGlobalValueWithParams>(instToDecorate);
if (valueWithParams)
{
IRParam* irParam = valueWithParams->getFirstParam();

for (auto param : params)
{
if (auto modifier =
param->findModifier<HLSLGeometryShaderInputPrimitiveTypeModifier>())
{
IROp op = kIROp_Invalid;

if (as<HLSLTriangleModifier>(modifier))
op = kIROp_TriangleInputPrimitiveTypeDecoration;
else if (as<HLSLPointModifier>(modifier))
op = kIROp_PointInputPrimitiveTypeDecoration;
else if (as<HLSLLineModifier>(modifier))
op = kIROp_LineInputPrimitiveTypeDecoration;
else if (as<HLSLLineAdjModifier>(modifier))
op = kIROp_LineAdjInputPrimitiveTypeDecoration;
else if (as<HLSLTriangleAdjModifier>(modifier))
op = kIROp_TriangleAdjInputPrimitiveTypeDecoration;

if (op != kIROp_Invalid)
{
builder->addDecoration(irParam, op);
}
else
{
SLANG_UNEXPECTED("unhandled primitive type");
}
}

irParam = irParam->getNextParam();
}
}
}
}

static void lowerProgramEntryPointToIR(
Expand Down
93 changes: 93 additions & 0 deletions tools/slang-unit-test/unit-test-geometry-shader.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// unit-test-geometry-shader.cpp

#include "../../source/core/slang-io.h"
#include "../../source/core/slang-process.h"
#include "slang-com-ptr.h"
#include "slang.h"
#include "unit-test/slang-unit-test.h"

#include <stdio.h>
#include <stdlib.h>

using namespace Slang;

// Test the compilation API for compiling geometry shaders to DXIL.

#if SLANG_WINDOWS_FAMILY

SLANG_UNIT_TEST(geometryShader)
{
const char* userSourceBody = R"(
struct GS_INPUT
{
float4 PosSS : TEXTURE0; // [Screen Space] Position
};
struct PS_INPUT
{
float4 PosSS : SV_POSITION; // [Screen Space] Position
};
[maxvertexcount(3)]
void main(triangle GS_INPUT input[3], inout TriangleStream<PS_INPUT> outStream)
{
PS_INPUT output;
output.PosSS = input[0].PosSS;
outStream.Append(output);
output.PosSS = input[1].PosSS;
outStream.Append(output);
output.PosSS = input[2].PosSS;
outStream.Append(output);
outStream.RestartStrip();
}
)";
ComPtr<slang::IGlobalSession> globalSession;
SlangGlobalSessionDesc globalDesc = {};
globalDesc.enableGLSL = true;
SLANG_CHECK(slang_createGlobalSession2(&globalDesc, globalSession.writeRef()) == SLANG_OK);
slang::TargetDesc targetDesc = {};
targetDesc.format = SLANG_DXIL;
targetDesc.profile = globalSession->findProfile("sm_6_0");
slang::SessionDesc sessionDesc = {};
sessionDesc.targetCount = 1;
sessionDesc.targets = &targetDesc;
ComPtr<slang::ISession> session;
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

ComPtr<slang::IBlob> diagnosticBlob;
auto module = session->loadModuleFromSourceString(
"m",
"m.slang",
userSourceBody,
diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);

ComPtr<slang::IEntryPoint> entryPoint;
module->findAndCheckEntryPoint(
"main",
SLANG_STAGE_GEOMETRY,
entryPoint.writeRef(),
diagnosticBlob.writeRef());

slang::IComponentType* componentTypes[2] = {module, entryPoint.get()};
ComPtr<slang::IComponentType> composedProgram;
session->createCompositeComponentType(
componentTypes,
2,
composedProgram.writeRef(),
diagnosticBlob.writeRef());

ComPtr<slang::IComponentType> linkedProgram;
composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());

ComPtr<slang::IBlob> code;
linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());

SLANG_CHECK(code != nullptr);
}

#endif

0 comments on commit 4b9a342

Please sign in to comment.