From 4f5ca8cda0701b71a1d89369fc3798f92cef4857 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Mar 2021 11:05:55 +0900 Subject: [PATCH] ubo codegen first cut --- src/target/spirv/codegen_spirv.cc | 24 +++++++++++++++++++----- src/target/spirv/ir_builder.cc | 13 +++++++------ src/target/spirv/ir_builder.h | 3 ++- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 24608ebc93f4..633a0558f26e 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -66,16 +66,30 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: spirv::Value func_ptr = builder_->NewFunction(); builder_->StartFunction(func_ptr); - // All the POD arguments are passed in through PushConstant if (pod_args.size() != 0) { std::vector value_types; for (size_t i = 0; i < pod_args.size(); ++i) { value_types.push_back(builder_->GetSType(pod_args[i].dtype())); } - spirv::Value ptr = builder_->DeclarePushConstant(value_types); - for (size_t i = 0; i < pod_args.size(); ++i) { - spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); - var_map_[pod_args[i].get()] = value; + // All the POD arguments are passed in through PushConstant + if (pod_args.size() * 8 <= 128) { + spirv::Value ptr = builder_->DeclarePushConstant(value_types); + for (size_t i = 0; i < pod_args.size(); ++i) { + spirv::Value value = + builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); + var_map_[pod_args[i].get()] = value; + } + } else { + DataType value_storage_type = DataType::Int(64); + spirv::Value ptr_ubo = + builder_->BufferArgument(builder_->GetSType(value_storage_type), 0, num_buffer, true); + for (size_t i = 0; i < pod_args.size(); ++i) { + spirv::SType ptr_type = builder_->GetPointerType(value_types[i], spv::StorageClassUniform); + spirv::Value ptr = builder_->StructArrayAccess( + ptr_type, ptr_ubo, MakeValue(PrimExpr(static_cast(i * 8)))); + var_map_[pod_args[i].get()] = + builder_->MakeValue(spv::OpLoad, value_types[i], ptr, spv::MemoryAccessMaskNone); + } } } this->VisitStmt(f->body); diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 5a1457387ae5..300355100bd7 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -185,14 +185,15 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { } Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set, - uint32_t binding) { + uint32_t binding, bool uniform) { // NOTE: BufferBlock was deprecated in SPIRV 1.3 // use StorageClassStorageBuffer instead. -#if SPV_VERSION >= 0x10300 - spv::StorageClass storage_class = spv::StorageClassStorageBuffer; -#else - spv::StorageClass storage_class = spv::StorageClassUniform; -#endif + spv::StorageClass storage_class; + if (uniform) { + storage_class = spv::StorageClassUniform; + } else { + storage_class = spv::StorageClassStorageBuffer; + } SType sarr_type = GetStructArrayType(value_type, 0); SType ptr_type = GetPointerType(sarr_type, storage_class); diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 8a08048e1955..ffda3333b57d 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -472,7 +472,8 @@ class IRBuilder { * \param binding The binding locaiton in descriptor set. * \param The argument type. */ - Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding); + Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding, bool uniform=false); + /*! * \brief Declare POD arguments through push constants. *