diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index 5cd4812f41c4f..2fe5b30330aaa 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -45,6 +45,42 @@ static constexpr const int kVulkanMaxNumDevice = 8;
 /*! \brief TVM Vulkan binary pack magic number */
 static constexpr const int kVulkanModuleMagic = 0x02700027;
 
+struct VulkanBuffer {
+  VkBuffer buffer{VK_NULL_HANDLE};
+  VkDeviceMemory memory{VK_NULL_HANDLE};
+};
+
+/*! \brief A struct to represent Vulkan buffers backed by host visible memory */
+struct VulkanHostVisibleBuffer {
+  // A device where the buffer is allocated
+  VkDevice device{nullptr};
+  // Vulkan buffer and memory
+  VulkanBuffer* vk_buf{nullptr};
+  // The corresponding pointer to the host memory
+  void* host_addr{nullptr};
+  // The size of the buffer in bytes
+  size_t size{0};
+};
+
+using VulkanStagingBuffer = VulkanHostVisibleBuffer;
+using VulkanUniformBuffer = VulkanHostVisibleBuffer;
+
+void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf) {
+  if (buf && buf->vk_buf) {
+    if (buf->host_addr != nullptr) {
+      vkUnmapMemory(buf->device, buf->vk_buf->memory);
+    }
+    if (buf->vk_buf->memory != VK_NULL_HANDLE) {
+      vkFreeMemory(buf->device, buf->vk_buf->memory, nullptr);
+    }
+    if (buf->vk_buf->buffer != VK_NULL_HANDLE) {
+      vkDestroyBuffer(buf->device, buf->vk_buf->buffer, nullptr);
+    }
+    buf->host_addr = nullptr;
+    delete buf->vk_buf;
+  }
+}
+
 class VulkanThreadEntry {
  public:
   VulkanThreadEntry();
@@ -60,19 +96,7 @@ class VulkanThreadEntry {
     pool.reset();
     streams_.clear();
     for (const auto& kv : staging_buffers_) {
-      if (!kv.second) {
-        continue;
-      }
-      auto& buf = *(kv.second);
-      if (buf.host_addr != nullptr) {
-        vkUnmapMemory(buf.device, buf.memory);
-      }
-      if (buf.memory != VK_NULL_HANDLE) {
-        vkFreeMemory(buf.device, buf.memory, nullptr);
-      }
-      if (buf.buffer != VK_NULL_HANDLE) {
-        vkDestroyBuffer(buf.device, buf.buffer, nullptr);
-      }
+      DeleteHostVisibleBuffer(kv.second.get());
     }
   }
 
@@ -80,15 +104,13 @@ class VulkanThreadEntry {
   std::unique_ptr<WorkspacePool> pool;
   VulkanStream* Stream(size_t device_id);
   VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
+  void AllocateUniformBuffer(int device_id, size_t size);
+  VulkanUniformBuffer* GetUniformBuffer(int device_id, size_t size);
 
  private:
   std::unordered_map<size_t, std::unique_ptr<VulkanStream>> streams_;
   std::unordered_map<size_t, std::unique_ptr<VulkanStagingBuffer>> staging_buffers_;
-};
-
-struct VulkanBuffer {
-  VkBuffer buffer{VK_NULL_HANDLE};
-  VkDeviceMemory memory{VK_NULL_HANDLE};
+  std::unordered_map<size_t, std::unique_ptr<VulkanUniformBuffer>> uniform_buffers_;
 };
 
 struct VulkanPipeline {
@@ -100,10 +122,107 @@ struct VulkanPipeline {
   VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
   VkPipeline pipeline{VK_NULL_HANDLE};
   VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
+  bool use_ubo{false};
 };
 
 typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
 
+uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo info,
+                        VkMemoryPropertyFlags req_prop) {
+  VkBuffer buffer;
+  VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
+
+  VkMemoryRequirements mem_reqs;
+  vkGetBufferMemoryRequirements(vctx.device, buffer, &mem_reqs);
+  uint32_t type_bits = mem_reqs.memoryTypeBits;
+  VkPhysicalDeviceMemoryProperties phy_mem_prop;
+  vkGetPhysicalDeviceMemoryProperties(vctx.phy_device, &phy_mem_prop);
+  for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) {
+    if ((type_bits & 1) == 1 &&
+        (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) {
+      return i;
+    }
+    type_bits >>= 1;
+  }
+  LOG(FATAL) << "Requested memory type not found";
+  return 0;
+}
+
+VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes,
+                                        VkBufferUsageFlags usage) {
+  VkBufferCreateInfo info;
+  info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+  info.pNext = nullptr;
+  info.flags = 0;
+  info.size = nbytes;
+  info.queueFamilyIndexCount = 1;
+  info.pQueueFamilyIndices = &(vctx.queue_family_index);
+  info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+  info.usage = usage;
+  return info;
+}
+
+VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage,
+                           uint32_t mem_type_index) {
+  auto info = MakeBufferCreateInfo(vctx, nbytes, usage);
+  // create buffer
+  VkBuffer buffer;
+  VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
+
+  // bind to memory
+  bool dedicated_allocation = false;
+  VkMemoryRequirements2KHR req2;
+
+  if (vctx.get_buffer_memory_requirements_2_functions) {
+    VkBufferMemoryRequirementsInfo2KHR req_info2;
+    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
+    req_info2.pNext = 0;
+    req_info2.buffer = buffer;
+
+    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
+    req2.pNext = 0;
+
+    VkMemoryDedicatedRequirementsKHR dedicated_req;
+    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
+    dedicated_req.pNext = 0;
+    req2.pNext = &dedicated_req;
+
+    vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
+        vctx.device, &req_info2, &req2);
+    dedicated_allocation =
+        dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
+  }
+
+  VkDeviceMemory memory;
+  if (!dedicated_allocation) {
+    VkMemoryAllocateInfo minfo;
+    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+    minfo.pNext = nullptr;
+    minfo.allocationSize = info.size;
+    minfo.memoryTypeIndex = mem_type_index;
+    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+  } else {
+    VkMemoryAllocateInfo minfo;
+    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+    minfo.pNext = nullptr;
+    minfo.allocationSize = req2.memoryRequirements.size;
+    minfo.memoryTypeIndex = mem_type_index;
+
+    VkMemoryDedicatedAllocateInfoKHR mdinfo;
+    mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
+    mdinfo.pNext = 0;
+    mdinfo.image = 0;
+    mdinfo.buffer = buffer;
+    minfo.pNext = &mdinfo;
+    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+  }
+  VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
+  VulkanBuffer* pbuf = new VulkanBuffer();
+  pbuf->memory = memory;
+  pbuf->buffer = buffer;
+  return pbuf;
+}
+
 class VulkanDeviceAPI final : public DeviceAPI {
  public:
   VulkanDeviceAPI();
@@ -124,70 +243,9 @@ class VulkanDeviceAPI final : public DeviceAPI {
       nbytes = 1;
     }
     const auto& vctx = context(dev.device_id);
-    VkBufferCreateInfo info;
-    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
-    info.pNext = nullptr;
-    info.flags = 0;
-    info.size = nbytes;
-    info.queueFamilyIndexCount = 1;
-    info.pQueueFamilyIndices = &(vctx.queue_family_index);
-    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
-    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
+    auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                  VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
-    // create buffer
-    VkBuffer buffer;
-    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
-    // bind to memory
-    VkBufferMemoryRequirementsInfo2KHR req_info2;
-    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
-    req_info2.pNext = 0;
-    req_info2.buffer = buffer;
-
-    VkMemoryRequirements2KHR req2;
-    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
-    req2.pNext = 0;
-
-    VkMemoryDedicatedRequirementsKHR dedicated_req;
-    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
-    dedicated_req.pNext = 0;
-    req2.pNext = &dedicated_req;
-
-    bool dedicated_allocation = false;
-    if (vctx.get_buffer_memory_requirements_2_functions) {
-      vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
-          vctx.device, &req_info2, &req2);
-      dedicated_allocation =
-          dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
-    }
-
-    VkDeviceMemory memory;
-    if (!dedicated_allocation) {
-      VkMemoryAllocateInfo minfo;
-      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-      minfo.pNext = nullptr;
-      minfo.allocationSize = nbytes;
-      minfo.memoryTypeIndex = vctx.compute_mtype_index;
-      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
-    } else {
-      VkMemoryAllocateInfo minfo;
-      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-      minfo.pNext = nullptr;
-      minfo.allocationSize = req2.memoryRequirements.size;
-      minfo.memoryTypeIndex = vctx.compute_mtype_index;
-
-      VkMemoryDedicatedAllocateInfoKHR mdinfo;
-      mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
-      mdinfo.pNext = 0;
-      mdinfo.image = 0;
-      mdinfo.buffer = buffer;
-      minfo.pNext = &mdinfo;
-      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
-    }
-    VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
-    VulkanBuffer* pbuf = new VulkanBuffer();
-    pbuf->memory = memory;
-    pbuf->buffer = buffer;
-    return pbuf;
+    return CreateBuffer(vctx, nbytes, usage, vctx.compute_mtype_index);
   }
 
   void FreeDataSpace(Device dev, void* ptr) final {
@@ -252,14 +310,15 @@ class VulkanDeviceAPI final : public DeviceAPI {
             copy_info.srcOffset = from_offset;
             copy_info.dstOffset = 0;
             copy_info.size = size;
-            vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->buffer, 1, &copy_info);
+            vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->vk_buf->buffer, 1,
+                            &copy_info);
           });
       VulkanThreadEntry::ThreadLocal()->Stream(dev_from.device_id)->Synchronize();
       if (!vctx.coherent_staging) {
         VkMappedMemoryRange mrange;
         mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
         mrange.pNext = nullptr;
-        mrange.memory = temp->memory;
+        mrange.memory = temp->vk_buf->memory;
         mrange.offset = 0;
         mrange.size = VK_WHOLE_SIZE;  // size;
         VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange));
@@ -277,14 +336,14 @@ class VulkanDeviceAPI final : public DeviceAPI {
         VkMappedMemoryRange mrange;
         mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
         mrange.pNext = nullptr;
-        mrange.memory = temp->memory;
+        mrange.memory = temp->vk_buf->memory;
         mrange.offset = 0;
         mrange.size = VK_WHOLE_SIZE;  // size;
         VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
       }
 
       VulkanThreadEntry::ThreadLocal()
-          ->Stream(dev_from.device_id)
+          ->Stream(dev_to.device_id)
          ->Launch([&](VulkanStreamState* state) {
             // 0: barrier(host->transfer)
             VkMemoryBarrier barrier_info;
@@ -300,11 +359,12 @@ class VulkanDeviceAPI final : public DeviceAPI {
             copy_info.srcOffset = 0;
             copy_info.dstOffset = to_offset;
             copy_info.size = size;
-            vkCmdCopyBuffer(state->cmd_buffer_, temp->buffer, to_buf->buffer, 1, &copy_info);
+            vkCmdCopyBuffer(state->cmd_buffer_, temp->vk_buf->buffer, to_buf->buffer, 1,
+                            &copy_info);
           });
       // TODO(tulloch): should we instead make the staging buffer a property of the
       // Stream? This would allow us to elide synchronizations here.
-      VulkanThreadEntry::ThreadLocal()->Stream(dev_from.device_id)->Synchronize();
+      VulkanThreadEntry::ThreadLocal()->Stream(dev_to.device_id)->Synchronize();
     } else {
       LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan"
                  << ", from=" << from_dev_type << ", to=" << to_dev_type;
     }
   }
@@ -794,11 +854,12 @@ class VulkanModuleNode final : public runtime::ModuleNode {
       return cp;
     }
     // Create new pipeline
-    auto pe = std::shared_ptr<VulkanPipeline>(new VulkanPipeline());
+    auto pe = std::make_shared<VulkanPipeline>();
     {
       // create shader
       auto sit = smap_.find(func_name);
       ICHECK(sit != smap_.end());
+      pe->use_ubo = sit->second.flag & (1 << ShaderMetaDataFlagMask::kUseUBO);
       const std::vector<uint32_t>& data = sit->second.data;
       VkShaderModuleCreateInfo shader_cinfo;
       shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
@@ -812,30 +873,35 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
     uint32_t num_pod = 0, num_buffer = 0;
 
+    auto push_arg_info = [&arg_binding, &arg_template](uint32_t binding,
+                                                       VkDescriptorType desc_type) {
+      {
+        VkDescriptorSetLayoutBinding bd;
+        bd.binding = binding;
+        bd.descriptorType = desc_type;
+        bd.descriptorCount = 1;
+        bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+        bd.pImmutableSamplers = nullptr;
+        arg_binding.push_back(bd);
+      }
+      {
+        VkDescriptorUpdateTemplateEntryKHR tpl;
+        tpl.dstBinding = binding;
+        tpl.dstArrayElement = 0;
+        tpl.descriptorCount = 1;
+        tpl.descriptorType = desc_type;
+        tpl.offset = binding * sizeof(VkDescriptorBufferInfo);
+        tpl.stride = sizeof(VkDescriptorBufferInfo);
+        arg_template.push_back(tpl);
+      }
+    };
+
     {
       auto fit = fmap_.find(func_name);
       ICHECK(fit != fmap_.end());
       for (DLDataType arg_type : fit->second.arg_types) {
         if (arg_type.code == kTVMOpaqueHandle) {
-          {
-            VkDescriptorSetLayoutBinding bd;
-            bd.binding = num_buffer;
-            bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-            bd.descriptorCount = 1;
-            bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
-            bd.pImmutableSamplers = nullptr;
-            arg_binding.push_back(bd);
-          }
-          {
-            VkDescriptorUpdateTemplateEntryKHR tpl;
-            tpl.dstBinding = num_buffer;
-            tpl.dstArrayElement = 0;
-            tpl.descriptorCount = 1;
-            tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-            tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
-            tpl.stride = sizeof(VkDescriptorBufferInfo);
-            arg_template.push_back(tpl);
-          }
+          push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
           ++num_buffer;
         } else {
           ++num_pod;
@@ -843,6 +909,13 @@ class VulkanModuleNode final : public runtime::ModuleNode {
       }
     }
 
+    size_t nbytes_scalars = num_pod * sizeof(ArgUnion64);
+    if (pe->use_ubo) {
+      // Use UBO instead of push constants
+      push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
+      VulkanThreadEntry::ThreadLocal()->AllocateUniformBuffer(device_id, nbytes_scalars);
+    }
+
     {
       VkDescriptorSetLayoutCreateInfo descrip_cinfo;
       descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
@@ -894,7 +967,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     playout_cinfo.setLayoutCount = 1;
     playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
 
-    if (num_pack_args != 0) {
+    if (0 < nbytes_scalars && !pe->use_ubo) {
       playout_cinfo.pushConstantRangeCount = 1;
       playout_cinfo.pPushConstantRanges = &crange;
       ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
@@ -990,57 +1063,64 @@ Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
 
 VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); }
-VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
-  if (!staging_buffers_[device_id]) {
-    staging_buffers_[device_id] = std::unique_ptr<VulkanStagingBuffer>(new VulkanStagingBuffer());
+VulkanHostVisibleBuffer* GetOrAllocate(
+    int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index,
+    std::unordered_map<size_t, std::unique_ptr<VulkanHostVisibleBuffer>>* buffers_ptr,
+    bool sync_before_realloc = false) {
+  auto& buffers = *buffers_ptr;
+  if (!buffers[device_id]) {
+    buffers[device_id] = std::make_unique<VulkanHostVisibleBuffer>();
   }
-  auto& buf = *(staging_buffers_[device_id]);
+
+  auto& buf = *(buffers[device_id]);
   if (buf.device != nullptr && buf.size < size) {
     // free previous buffer
-    if (buf.host_addr != nullptr) {
-      vkUnmapMemory(buf.device, buf.memory);
-    }
-    if (buf.memory != VK_NULL_HANDLE) {
-      vkFreeMemory(buf.device, buf.memory, nullptr);
-    }
-    if (buf.buffer != VK_NULL_HANDLE) {
-      vkDestroyBuffer(buf.device, buf.buffer, nullptr);
+    if (sync_before_realloc) {
+      // For the deferred execution mode, we need to make sure that old tasks that use
+      // the older, smaller buffer get finished
+      // Synchronization on staging buffers is done after host to device memory copy
+      // For UBO, we sync here before we reallocate a larger buffer, to minimize synchronization
+      // points
+      VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Synchronize();
     }
-    buf.host_addr = nullptr;
-    buf.memory = VK_NULL_HANDLE;
-    buf.buffer = VK_NULL_HANDLE;
+    DeleteHostVisibleBuffer(&buf);
   }
+
   const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
   if (buf.device == nullptr) {
     buf.device = vctx.device;
   }
-  if (buf.memory == VK_NULL_HANDLE) {
-    // allocate the stagging buffer memory if necessary
-    VkBufferCreateInfo info;
-    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
-    info.pNext = nullptr;
-    info.flags = 0;
-    info.size = size;
-    info.queueFamilyIndexCount = 1;
-    info.pQueueFamilyIndices = &(vctx.queue_family_index);
-    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
-    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
-    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
-    VkMemoryAllocateInfo minfo;
-    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-    minfo.pNext = nullptr;
-    minfo.allocationSize = size;
-    minfo.memoryTypeIndex = vctx.staging_mtype_index;
-    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
-    VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
-    VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
+  if (buf.host_addr == nullptr) {
+    buf.vk_buf = CreateBuffer(vctx, size, usage, mem_type_index);
+    VULKAN_CALL(vkMapMemory(vctx.device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr)));
     buf.size = size;
   }
-  memset(buf.host_addr, 0, size);
   return &buf;
 }
 
+VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
+  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+  auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
+  return GetOrAllocate(device_id, size, usage, vctx.staging_mtype_index, &staging_buffers_);
+}
+
+void VulkanThreadEntry::AllocateUniformBuffer(int device_id, size_t size) {
+  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+  auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
+  auto info = MakeBufferCreateInfo(vctx, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT);
+  auto mem_type_index = FindMemoryType(vctx, info, prop);
+  GetOrAllocate(device_id, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, mem_type_index,
+                &uniform_buffers_, true);
+}
+
+VulkanUniformBuffer* VulkanThreadEntry::GetUniformBuffer(int device_id, size_t size) {
+  auto& buf = uniform_buffers_[device_id];
+  ICHECK(buf);
+  ICHECK_GE(buf->size, size);
+  return buf.get();
+}
+
 VulkanThreadEntry::VulkanThreadEntry()
     : pool(std::make_unique<WorkspacePool>(static_cast<DLDeviceType>(kDLVulkan),
                                            VulkanDeviceAPI::Global())) {
@@ -1076,6 +1156,16 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
     binfo.range = VK_WHOLE_SIZE;
     descriptor_buffers[i] = binfo;
   }
+  const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64);
+  if (pipeline->use_ubo) {
+    auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars);
+    CHECK(ubo->host_addr) << "The UBO host buffer is not allocated";
+    VkDescriptorBufferInfo binfo;
+    binfo.buffer = ubo->vk_buf->buffer;
+    binfo.offset = 0;
+    binfo.range = VK_WHOLE_SIZE;
+    descriptor_buffers.push_back(binfo);
+  }
   if (vctx.UseImmediate()) {
     // Can safely capture by reference as this lambda is immediately executed on the calling thread.
     VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
@@ -1084,11 +1174,16 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
       vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
           state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
           descriptor_buffers.data());
-      if (num_pack_args_ != 0) {
+
+      if (pipeline->use_ubo) {
+        auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars);
+        memcpy(ubo->host_addr, pack_args, nbytes_scalars);
+      } else if (num_pack_args_ > 0) {
         vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
                            VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
                            pack_args);
       }
+
       vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
       VkMemoryBarrier barrier_info;
       barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
@@ -1115,24 +1210,36 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
       write_descriptor_sets[i].dstBinding = i;
       write_descriptor_sets[i].dstArrayElement = 0;
       write_descriptor_sets[i].descriptorCount = 1;
-      write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
       write_descriptor_sets[i].pImageInfo = 0;
       write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
       write_descriptor_sets[i].pTexelBufferView = 0;
+
+      if (pipeline->use_ubo && i == write_descriptor_sets.size() - 1) {
+        // The last binding is for UBO
+        write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
+      } else {
+        write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+      }
     }
     vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
                            0, 0);
   };
-  const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
+  const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars,
+                                 device_id](VulkanStreamState* state) {
     vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
     vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
                             pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
                             nullptr);
-    if (pack_args_storage.size() != 0) {
+
+    if (pipeline->use_ubo) {
+      auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars);
+      memcpy(ubo->host_addr, pack_args_storage.data(), nbytes_scalars);
+    } else if (num_pack_args_ > 0) {
       vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
                          VK_SHADER_STAGE_COMPUTE_BIT, 0,
                          pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data());
     }
+
     vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
     VkMemoryBarrier barrier_info;
     barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h
index 3083ba6f9ce4d..2ef879a487a65 100644
--- a/src/runtime/vulkan/vulkan_common.h
+++ b/src/runtime/vulkan/vulkan_common.h
@@ -35,6 +35,11 @@ namespace tvm {
 namespace runtime {
 namespace vulkan {
 
+const int kMaxPushConstantsBytes = 128;
+
+/*! \brief A mask used when we attach additional information to shaders */
+enum ShaderMetaDataFlagMask { kUseUBO = 0 };
+
 inline const char* VKGetErrorString(VkResult error) {
   switch (error) {
     case VK_SUCCESS:
@@ -105,14 +110,6 @@ struct VulkanGetBufferMemoryRequirements2Functions {
   PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr};
 };
 
-struct VulkanStagingBuffer {
-  VkDevice device{nullptr};
-  VkBuffer buffer{VK_NULL_HANDLE};
-  VkDeviceMemory memory{VK_NULL_HANDLE};
-  void* host_addr{nullptr};
-  size_t size{0};
-};
-
 struct VulkanContext {
   // physical device
   VkPhysicalDevice phy_device{nullptr};
diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc
index a0f0b76eefbd0..9f9718bef18e4 100644
--- a/src/target/spirv/build_vulkan.cc
+++ b/src/target/spirv/build_vulkan.cc
@@ -88,10 +88,9 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction)
         << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
     std::string f_name = global_symbol.value();
 
-    VulkanShader shader;
     std::string entry = webgpu_restriction ? "main" : f_name;
"main" : f_name; - shader.data = cg.BuildFunction(f, entry); + + VulkanShader shader = cg.BuildFunction(f, entry); if (webgpu_restriction) { for (auto param : f->params) { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 24608ebc93f42..5b26e9acf5a2f 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -30,10 +30,14 @@ #include +#include "../../runtime/pack_args.h" +#include "../../runtime/vulkan/vulkan_common.h" +#include "../../runtime/vulkan/vulkan_shader.h" + namespace tvm { namespace codegen { -std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { +runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; @@ -66,16 +70,28 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: spirv::Value func_ptr = builder_->NewFunction(); builder_->StartFunction(func_ptr); - // All the POD arguments are passed in through PushConstant + runtime::VulkanShader shader; + if (pod_args.size() != 0) { std::vector value_types; for (size_t i = 0; i < pod_args.size(); ++i) { value_types.push_back(builder_->GetSType(pod_args[i].dtype())); } - spirv::Value ptr = builder_->DeclarePushConstant(value_types); - for (size_t i = 0; i < pod_args.size(); ++i) { - spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); - var_map_[pod_args[i].get()] = value; + if (pod_args.size() * sizeof(runtime::ArgUnion64) <= runtime::vulkan::kMaxPushConstantsBytes) { + spirv::Value ptr = builder_->DeclarePushConstant(value_types); + for (size_t i = 0; i < pod_args.size(); ++i) { + spirv::Value value = + builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); + var_map_[pod_args[i].get()] = value; + } + } else { + shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO; + // If we need to pass more arguments than push constants could handle, we use UBO. + spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer); + for (size_t i = 0; i < pod_args.size(); ++i) { + spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast(i)); + var_map_[pod_args[i].get()] = value; + } } } this->VisitStmt(f->body); @@ -85,7 +101,8 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: builder_->CommitKernelFunction(func_ptr, name); - return builder_->Finalize(); + shader.data = builder_->Finalize(); + return shader; } void CodeGenSPIRV::InitFuncState() { diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 1e80fcc4a9318..e3d6c153d06fd 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -36,6 +36,7 @@ #include #include "../../runtime/thread_storage_scope.h" +#include "../../runtime/vulkan/vulkan_shader.h" #include "ir_builder.h" namespace tvm { @@ -55,7 +56,7 @@ class CodeGenSPIRV : public ExprFunctor, * \param name The name of the target function. * \return The final spirv module. */ - virtual std::vector BuildFunction(const PrimFunc& f, const std::string& name); + virtual runtime::VulkanShader BuildFunction(const PrimFunc& f, const std::string& name); /*! * \brief Create Value for expression e * \param e The expression to be created value for. 
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index 5a1457387ae58..cd48c93530ec2 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -205,8 +205,8 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set
   return val;
 }
 
-Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
-  ICHECK_EQ(push_const_.id, 0);
+Value IRBuilder::DeclareStorageVariable(const std::vector<SType>& value_types,
+                                        spv::StorageClass storage_class, ValueKind kind) {
   SType struct_type;
   struct_type.id = id_counter_++;
   struct_type.type = DataType::Handle();
@@ -226,22 +226,26 @@ Value IRBuilder::DeclareStorageVariable(const std::vector<SType>& value_types,
     ICHECK_EQ(nbits % 8, 0);
     uint32_t bytes = (nbits / 8);
     if (t.bits() == 32) {
-      // In our Vulkan runtime, each push constant always occupies 64 bit.
+      // In our Vulkan runtime, each scalar argument always occupies 64 bit.
       offset += bytes * 2;
     } else {
       ICHECK_EQ(t.bits(), 64);
       offset += bytes;
    }
   }
-  // Decorate push constants as UBO
   this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
-  SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant);
-  Value val = NewValue(ptr_type, kPushConstantPtr);
-  ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_);
+  SType ptr_type = GetPointerType(struct_type, storage_class);
+  Value val = NewValue(ptr_type, kind);
+  ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_);
   return val;
 }
 
+Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
+  ICHECK_EQ(push_const_.id, 0);
+  return DeclareStorageVariable(value_types, spv::StorageClassPushConstant, kPushConstantPtr);
+}
+
 Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) {
   SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant);
   Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const,
@@ -249,6 +253,19 @@ Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint
   return this->MakeValue(spv::OpLoad, v_type, ptr);
 }
 
+Value IRBuilder::DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding) {
+  Value val = DeclareStorageVariable(value_types, spv::StorageClassUniform, kUniformPtr);
+  this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
+  return val;
+}
+
+Value IRBuilder::GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index) {
+  SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassUniform);
+  Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_ubo,
+                              IntImm(t_int32_, static_cast<int64_t>(index)));
+  return this->MakeValue(spv::OpLoad, v_type, ptr);
+}
+
 Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); }
 
 void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) {
diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h
index 8a08048e1955d..05a2bc631743f 100644
--- a/src/target/spirv/ir_builder.h
+++ b/src/target/spirv/ir_builder.h
@@ -60,7 +60,8 @@ enum ValueKind {
   kStructArrayPtr,
   kPushConstantPtr,
   kFunction,
-  kExtInst
+  kExtInst,
+  kUniformPtr
 };
 
 /*! \brief Represent the SPIRV Value */
@@ -473,6 +474,7 @@ class IRBuilder {
    * \param The argument type.
    */
   Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding);
+
   /*!
    * \brief Declare POD arguments through push constants.
    *
@@ -488,6 +490,23 @@
    * \return the value of push constant
    */
   Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index);
+
+  /*!
+   * \brief Declare POD arguments through uniform buffer.
+   *
+   * \note Only call this function once!
+   * \param value_types The values in the uniform buffer
+   * \param binding The binding location in descriptor set
+   * \return the value pointing to the declared uniform buffer
+   */
+  Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding);
+  /*!
+   * \brief Get i-th uniform constant
+   * \param v_type The value type
+   * \param index The uniform index
+   * \return the value of uniform constant
+   */
+  Value GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index);
   /*!
    * \brief Declare a new function
    * \return The created function ID.
@@ -555,6 +574,17 @@
     val.flag = flag;
     return val;
   }
+
+  /*!
+   * \brief The common function to declare push constants or uniform buffer
+   * \param value_types The values in the push constants or uniform buffer
+   * \param storage_class An enum defined by SPIR-V indicating push constant or uniform
+   * \param kind An enum indicating push constant or uniform
+   * \return The created variable value
+   */
+  Value DeclareStorageVariable(const std::vector<SType>& value_types,
+                               spv::StorageClass storage_class, ValueKind kind);
+
   // get constant given value encoded in uint64_t
   Value GetConst_(const SType& dtype, const uint64_t* pvalue);
   // declare type
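Note: the layout rule DeclareStorageVariable applies above is what keeps the shader-side struct and the host-side argument packing in agreement. Below is a small standalone sketch of that rule (illustrative only, not TVM source; `MemberOffsets` is a hypothetical name): every member slot advances the offset by 8 bytes, a 32-bit value being padded to a full 64-bit slot, so member i always starts at byte 8 * i. That is exactly the shape of the array of 8-byte ArgUnion64 slots that VulkanWrappedFunc memcpy's into the UBO, which is why a plain memcpy is sufficient on the runtime side.

    // scalar_layout.cc -- hypothetical illustration, not TVM source.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Byte offset of each scalar member, following DeclareStorageVariable:
    // a 32-bit member advances the offset by bytes * 2 (= 8), a 64-bit member
    // by bytes (= 8), so every member lands on an 8-byte boundary.
    std::vector<uint32_t> MemberOffsets(const std::vector<int>& bit_widths) {
      std::vector<uint32_t> offsets;
      uint32_t offset = 0;
      for (int bits : bit_widths) {
        offsets.push_back(offset);
        uint32_t bytes = bits / 8;
        offset += (bits == 32) ? bytes * 2 : bytes;  // both branches add 8
      }
      return offsets;
    }

    int main() {
      for (uint32_t off : MemberOffsets({32, 64, 32})) std::printf("%u ", off);
      std::printf("\n");  // prints: 0 8 16
    }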