Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Optional, Tuple and bool to be used in varying input/output. #5889

Merged
merged 13 commits into from
Dec 18, 2024
13 changes: 10 additions & 3 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1279,8 +1279,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case AddressSpace::Uniform:
return SpvStorageClassUniform;
case AddressSpace::Input:
case AddressSpace::BuiltinInput:
return SpvStorageClassInput;
case AddressSpace::Output:
case AddressSpace::BuiltinOutput:
return SpvStorageClassOutput;
case AddressSpace::TaskPayloadWorkgroup:
return SpvStorageClassTaskPayloadWorkgroupEXT;
Expand Down Expand Up @@ -2688,7 +2690,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
IRBuilder builder(spvAsmBuiltinVar);
builder.setInsertBefore(spvAsmBuiltinVar);
auto varInst = getBuiltinGlobalVar(
builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), AddressSpace::Input),
builder.getPtrType(
kIROp_PtrType,
spvAsmBuiltinVar->getDataType(),
AddressSpace::BuiltinInput),
kind,
spvAsmBuiltinVar);
registerInst(spvAsmBuiltinVar, varInst);
Expand Down Expand Up @@ -4214,7 +4219,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto addrSpace = ptrType->getAddressSpace();
if (addrSpace != AddressSpace::Input &&
addrSpace != AddressSpace::Output)
addrSpace != AddressSpace::Output &&
addrSpace != AddressSpace::BuiltinInput &&
addrSpace != AddressSpace::BuiltinOutput)
continue;
}
}
Expand Down Expand Up @@ -4995,7 +5002,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
if (!ptrType)
return;
auto addrSpace = ptrType->getAddressSpace();
if (addrSpace == AddressSpace::Input)
if (addrSpace == AddressSpace::Input || addrSpace == AddressSpace::BuiltinInput)
{
if (isIntegralScalarOrCompositeType(ptrType->getValueType()))
{
Expand Down
150 changes: 106 additions & 44 deletions source/slang/slang-ir-lower-buffer-element-type.cpp

Large diffs are not rendered by default.

17 changes: 15 additions & 2 deletions source/slang/slang-ir-spirv-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Skip load's for referenced `Input` variables since a ref implies
// passing as is, which needs to be a pointer (pass as is).
if (user->getDataType() && user->getDataType()->getOp() == kIROp_RefType &&
addressSpace == AddressSpace::Input)
(addressSpace == AddressSpace::Input ||
addressSpace == AddressSpace::BuiltinInput))
{
builder.replaceOperand(use, addr);
continue;
Expand Down Expand Up @@ -431,7 +432,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
String semanticName = systemValueAttr->getName();
semanticName = semanticName.toLower();
if (semanticName == "sv_pointsize")
addressSpace = AddressSpace::Input;
addressSpace = AddressSpace::BuiltinInput;
}
}

Expand Down Expand Up @@ -661,6 +662,18 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
"resolve a storage class address space.");
}
}

switch (result)
{
case AddressSpace::Input:
if (varLayout->findSystemValueSemanticAttr())
result = AddressSpace::BuiltinInput;
break;
case AddressSpace::Output:
if (varLayout->findSystemValueSemanticAttr())
result = AddressSpace::BuiltinOutput;
break;
}
return result;
}

Expand Down
53 changes: 52 additions & 1 deletion source/slang/slang-parameter-binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2262,12 +2262,63 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(

return ptrTypeLayout;
}
else if (auto optionalType = as<OptionalType>(type))
{
Array<Type*, 2> types =
makeArray(optionalType->getValueType(), context->getASTBuilder()->getBoolType());
auto tupleType = context->getASTBuilder()->getTupleType(types.getView());
return processEntryPointVaryingParameter(context, tupleType, state, varLayout);
}
else if (auto tupleType = as<TupleType>(type))
{
RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
structLayout->type = type;
for (Index i = 0; i < tupleType->getMemberCount(); i++)
{
auto fieldType = tupleType->getMember(i);
RefPtr<VarLayout> fieldVarLayout = new VarLayout();

// We don't really have a "field" decl, so just use the tuple-typed decl
// itself as the varDecl of the elements.
auto fieldDecl = (VarDeclBase*)varLayout->varDecl.getDecl();
fieldVarLayout->varDecl = fieldDecl;

structLayout->fields.add(fieldVarLayout);

auto fieldTypeLayout = processEntryPointVaryingParameterDecl(
context,
fieldDecl,
fieldType,
state,
fieldVarLayout);

if (!fieldTypeLayout)
{
getSink(context)->diagnose(
varLayout->varDecl,
Diagnostics::notValidVaryingParameter,
fieldType);
continue;
}
fieldVarLayout->typeLayout = fieldTypeLayout;

// Assign offsets in var layout for each resource kind of the type.
for (auto fieldTypeResInfo : fieldTypeLayout->resourceInfos)
{
auto kind = fieldTypeResInfo.kind;
auto structTypeResInfo = structLayout->findOrAddResourceInfo(kind);
auto fieldResInfo = fieldVarLayout->findOrAddResourceInfo(kind);
fieldResInfo->index = structTypeResInfo->count.getFiniteValue();
structTypeResInfo->count += fieldTypeResInfo.count;
}
}
return structLayout;
}
// Catch declaration-reference types late in the sequence, since
// otherwise they will include all of the above cases...
else if (auto declRefType = as<DeclRefType>(type))
{
auto declRef = declRefType->getDeclRef();

if (auto structDeclRef = declRef.as<StructDecl>())
{
RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
Expand Down
10 changes: 9 additions & 1 deletion source/slang/slang-type-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4785,10 +4785,18 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type
type,
rules);
}
else if (auto optionalType = as<OptionalType>(type))
{
// OptionalType should be laid out the same way as Tuple<T, bool>.
Array<Type*, 2> types =
makeArray(optionalType->getValueType(), context.astBuilder->getBoolType());
auto tupleType = context.astBuilder->getTupleType(types.getView());
return _createTypeLayout(context, tupleType);
}
else if (auto tupleType = as<TupleType>(type))
{
// A `Tuple` type is laid out exactly the same way as a `struct` type,
// except that we want have a declref to the field.
// except that we won't have a declref to the field.

StructTypeLayoutBuilder typeLayoutBuilder;
StructTypeLayoutBuilder pendingDataTypeLayoutBuilder;
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-type-system-shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,12 @@ enum class AddressSpace : uint64_t
MetalObjectData,
// Corresponds to SPIR-V's SpvStorageClassInput
Input,
// Same as `Input`, but used for builtin input variables.
BuiltinInput,
// Corresponds to SPIR-V's SpvStorageClassOutput
Output,
// Same as `Output`, but used for builtin output variables.
BuiltinOutput,
// Corresponds to SPIR-V's SpvStorageClassTaskPayloadWorkgroupEXT
TaskPayloadWorkgroup,
// Corresponds to SPIR-V's SpvStorageClassFunction
Expand Down
23 changes: 23 additions & 0 deletions tests/spirv/matrix-vertex-input.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv
// CHECK: OpVectorTimesMatrix

struct Vertex
{
float4x4 m;
float4 pos;
}

struct VertexOut
{
float4 pos : SV_Position;
float4 color;
}

[shader("vertex")]
VertexOut vertMain(Vertex v)
{
VertexOut o;
o.pos = mul(v.m, v.pos);
o.color = v.pos;
return o;
}
26 changes: 26 additions & 0 deletions tests/spirv/optional-vertex-output.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv

// Test that we can use Optional<T> or bool types in varying input or outputs.

// CHECK: OpDecorate %i_inA_value Location 0
// CHECK: OpDecorate %i_inA_hasValue Location 1
// CHECK: OpDecorate %entryPointParam_vertMain_a_value Location 0
// CHECK: OpDecorate %entryPointParam_vertMain_a_hasValue Location 1

struct VIn {
Optional<float> inA;
}

struct VSOut {
Optional<float> a;
bool outputValues[3];
};

[shader("vertex")]
VSOut vertMain(VIn i)
{
VSOut o;
o.a = i.inA;
o.outputValues = { true, false, true };
return o;
}
Loading