Skip to content

Commit

Permalink
ubo codegen first cut
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 10, 2021
1 parent 461d06e commit 4f5ca8c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
24 changes: 19 additions & 5 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,30 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
spirv::Value func_ptr = builder_->NewFunction();
builder_->StartFunction(func_ptr);

// All the POD arguments are passed in through PushConstant
if (pod_args.size() != 0) {
std::vector<spirv::SType> value_types;
for (size_t i = 0; i < pod_args.size(); ++i) {
value_types.push_back(builder_->GetSType(pod_args[i].dtype()));
}
spirv::Value ptr = builder_->DeclarePushConstant(value_types);
for (size_t i = 0; i < pod_args.size(); ++i) {
spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
var_map_[pod_args[i].get()] = value;
// All the POD arguments are passed in through PushConstant
if (pod_args.size() * 8 <= 128) {
spirv::Value ptr = builder_->DeclarePushConstant(value_types);
for (size_t i = 0; i < pod_args.size(); ++i) {
spirv::Value value =
builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
var_map_[pod_args[i].get()] = value;
}
} else {
DataType value_storage_type = DataType::Int(64);
spirv::Value ptr_ubo =
builder_->BufferArgument(builder_->GetSType(value_storage_type), 0, num_buffer, true);
for (size_t i = 0; i < pod_args.size(); ++i) {
spirv::SType ptr_type = builder_->GetPointerType(value_types[i], spv::StorageClassUniform);
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, ptr_ubo, MakeValue(PrimExpr(static_cast<int32_t>(i * 8))));
var_map_[pod_args[i].get()] =
builder_->MakeValue(spv::OpLoad, value_types[i], ptr, spv::MemoryAccessMaskNone);
}
}
}
this->VisitStmt(f->body);
Expand Down
13 changes: 7 additions & 6 deletions src/target/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,15 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) {
}

Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set,
uint32_t binding) {
uint32_t binding, bool uniform) {
// NOTE: BufferBlock was deprecated in SPIRV 1.3
// use StorageClassStorageBuffer instead.
#if SPV_VERSION >= 0x10300
spv::StorageClass storage_class = spv::StorageClassStorageBuffer;
#else
spv::StorageClass storage_class = spv::StorageClassUniform;
#endif
spv::StorageClass storage_class;
if (uniform) {
storage_class = spv::StorageClassUniform;
} else {
storage_class = spv::StorageClassStorageBuffer;
}

SType sarr_type = GetStructArrayType(value_type, 0);
SType ptr_type = GetPointerType(sarr_type, storage_class);
Expand Down
3 changes: 2 additions & 1 deletion src/target/spirv/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ class IRBuilder {
* \param binding The binding locaiton in descriptor set.
* \param The argument type.
*/
Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding);
Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding, bool uniform=false);

/*!
* \brief Declare POD arguments through push constants.
*
Expand Down

0 comments on commit 4f5ca8c

Please sign in to comment.