From e1788b8d5134dcd6b7bde17b95caca0ed0b24edd Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sat, 20 Mar 2021 04:34:44 +0900
Subject: [PATCH] allocate and bind ubo

When the packed scalar arguments of a kernel take more than 120 bytes,
they are passed through a uniform buffer (UBO) bound at binding index
num_buffer, i.e. right after the storage buffers.

- Move the buffer creation code of AllocDataSpace into the free function
  CreateBuffer(vctx, nbytes, usage) so that the same code can create the
  UBO. The caller owns the returned VulkanBuffer.
- Create the UBO when the pipeline is built, map it once and keep the
  mapped address in VulkanPipeline::ubo.host_buf.
  VulkanWrappedFunc::operator() copies pack_args into that mapping.
- Destroy the UBO with the rest of the pipeline state. Pipelines that
  never allocated one keep ubo.vk_buf == nullptr and are skipped.

---
 src/runtime/vulkan/vulkan.cc | 169 ++++++++++++++++++++---------------
 1 file changed, 98 insertions(+), 71 deletions(-)

diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index 6e050e5b8603..4552c27106e2 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -91,6 +91,11 @@ struct VulkanBuffer {
   VkDeviceMemory memory{VK_NULL_HANDLE};
 };
 
+struct UniformBuffer {
+  VulkanBuffer* vk_buf{nullptr};  // buffer + memory of the UBO; null when no UBO was allocated
+  ArgUnion64* host_buf{nullptr};  // host address obtained by mapping vk_buf->memory
+};
+
 struct VulkanPipeline {
   VulkanContext* vctx_{nullptr};
   VkShaderModule shader{VK_NULL_HANDLE};
@@ -100,11 +105,80 @@ struct VulkanPipeline {
   VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
   VkPipeline pipeline{VK_NULL_HANDLE};
   VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
-  VulkanBuffer ubo;
+  UniformBuffer ubo;
 };
 
 typedef dmlc::ThreadLocalStore VulkanThreadStore;
 
+VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage) {
+  VkBufferCreateInfo info;
+  info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+  info.pNext = nullptr;
+  info.flags = 0;
+  info.size = nbytes;
+  info.queueFamilyIndexCount = 1;
+  info.pQueueFamilyIndices = &(vctx.queue_family_index);
+  info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+  info.usage = usage;
+  // create buffer
+  VkBuffer buffer;
+  VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
+
+  // bind to memory
+  bool dedicated_allocation = false;
+  VkMemoryRequirements2KHR req2;
+
+  if (vctx.get_buffer_memory_requirements_2_functions) {
+    VkBufferMemoryRequirementsInfo2KHR req_info2;
+    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
+    req_info2.pNext = 0;
+    req_info2.buffer = buffer;
+
+    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
+    req2.pNext = 0;
+
+    VkMemoryDedicatedRequirementsKHR dedicated_req;
+    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
+    dedicated_req.pNext = 0;
+    req2.pNext = &dedicated_req;
+
+    vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
+        vctx.device, &req_info2, &req2);
+    dedicated_allocation =
+        dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
+  }
+
+  VkDeviceMemory memory;
+  // TODO: revisit memoryTypeIndex; memory that gets mapped (the UBO) must be host-visible
+  if (!dedicated_allocation) {
+    VkMemoryAllocateInfo minfo;
+    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+    minfo.pNext = nullptr;
+    minfo.allocationSize = nbytes;
+    minfo.memoryTypeIndex = vctx.compute_mtype_index;
+    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+  } else {
+    VkMemoryAllocateInfo minfo;
+    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+    minfo.pNext = nullptr;
+    minfo.allocationSize = req2.memoryRequirements.size;
+    minfo.memoryTypeIndex = vctx.compute_mtype_index;
+
+    VkMemoryDedicatedAllocateInfoKHR mdinfo;
+    mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
+    mdinfo.pNext = 0;
+    mdinfo.image = 0;
+    mdinfo.buffer = buffer;
+    minfo.pNext = &mdinfo;
+    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+  }
+  VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
+  VulkanBuffer* pbuf = new VulkanBuffer();
+  pbuf->memory = memory;
+  pbuf->buffer = buffer;
+  return pbuf;
+}
+
 class VulkanDeviceAPI final : public DeviceAPI {
  public:
   VulkanDeviceAPI();
@@ -125,70 +199,9 @@ class VulkanDeviceAPI final : public DeviceAPI {
       nbytes = 1;
     }
     const auto& vctx = context(dev.device_id);
-    VkBufferCreateInfo info;
-    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
-    info.pNext = nullptr;
-    info.flags = 0;
-    info.size = nbytes;
-    info.queueFamilyIndexCount = 1;
-    info.pQueueFamilyIndices = &(vctx.queue_family_index);
-    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
-    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
+    auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                  VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
-    // create buffer
-    VkBuffer buffer;
-    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
-    // bind to memory
-    VkBufferMemoryRequirementsInfo2KHR req_info2;
-    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
-    req_info2.pNext = 0;
-    req_info2.buffer = buffer;
-
-    VkMemoryRequirements2KHR req2;
-    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
-    req2.pNext = 0;
-
-    VkMemoryDedicatedRequirementsKHR dedicated_req;
-    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
-    dedicated_req.pNext = 0;
-    req2.pNext = &dedicated_req;
-
-    bool dedicated_allocation = false;
-    if (vctx.get_buffer_memory_requirements_2_functions) {
-      vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
-          vctx.device, &req_info2, &req2);
-      dedicated_allocation =
-          dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
-    }
-
-    VkDeviceMemory memory;
-    if (!dedicated_allocation) {
-      VkMemoryAllocateInfo minfo;
-      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-      minfo.pNext = nullptr;
-      minfo.allocationSize = nbytes;
-      minfo.memoryTypeIndex = vctx.compute_mtype_index;
-      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
-    } else {
-      VkMemoryAllocateInfo minfo;
-      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-      minfo.pNext = nullptr;
-      minfo.allocationSize = req2.memoryRequirements.size;
-      minfo.memoryTypeIndex = vctx.compute_mtype_index;
-
-      VkMemoryDedicatedAllocateInfoKHR mdinfo;
-      mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
-      mdinfo.pNext = 0;
-      mdinfo.image = 0;
-      mdinfo.buffer = buffer;
-      minfo.pNext = &mdinfo;
-      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
-    }
-    VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
-    VulkanBuffer* pbuf = new VulkanBuffer();
-    pbuf->memory = memory;
-    pbuf->buffer = buffer;
-    return pbuf;
+    return CreateBuffer(vctx, nbytes, usage);
   }
 
   void FreeDataSpace(Device dev, void* ptr) final {
@@ -784,6 +797,11 @@ class VulkanModuleNode final : public runtime::ModuleNode {
         vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
         vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
         vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
+        if (pe->ubo.vk_buf != nullptr) {  // UBO: only pipelines with many scalar args own one
+          vkDestroyBuffer(vctx.device, pe->ubo.vk_buf->buffer, nullptr);
+          vkFreeMemory(vctx.device, pe->ubo.vk_buf->memory, nullptr);  // also unmaps host_buf
+          delete pe->ubo.vk_buf;
+        }
       }
     }
   }
@@ -846,14 +864,14 @@ class VulkanModuleNode final : public runtime::ModuleNode {
       }
     }
 
-    if (num_pod != 0 && num_pod * 8 > 120) {
+    size_t nbytes_scalars = num_pod * sizeof(ArgUnion64);
+    if (nbytes_scalars > 120) {
       ICHECK(num_pod == num_pack_args);
       // UBO
-      // TODO: allocate ubo
       {
         VkDescriptorSetLayoutBinding bd;
         bd.binding = num_buffer;
-        bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+        bd.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
         bd.descriptorCount = 1;
         bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
         bd.pImmutableSamplers = nullptr;
@@ -864,7 +882,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
         tpl.dstBinding = num_buffer;
         tpl.dstArrayElement = 0;
         tpl.descriptorCount = 1;
-        tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+        tpl.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
         tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
         tpl.stride = sizeof(VkDescriptorBufferInfo);
         arg_template.push_back(tpl);
@@ -951,6 +969,15 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
                                          &(pe->pipeline)));
 
+    if (nbytes_scalars > 120) {
+      // Allocate the UBO and keep its mapped host address in pe->ubo (by reference, not a copy)
+      UniformBuffer& ubo = pe->ubo;
+      ubo.vk_buf = CreateBuffer(vctx, nbytes_scalars, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT);
+      void* host_ptr = nullptr;
+      VULKAN_CALL(vkMapMemory(vctx.device, ubo.vk_buf->memory, 0, nbytes_scalars, 0, &host_ptr));
+      ubo.host_buf = static_cast<ArgUnion64*>(host_ptr);
+    }
+
     if (vctx.UseImmediate()) {
       VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
       descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
@@ -1104,11 +1131,11 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
     binfo.range = VK_WHOLE_SIZE;
     descriptor_buffers[i] = binfo;
   }
-  if (num_pack_args_ != 0 && num_pack_args_ * 8 > 120) {
+  if (num_pack_args_ != 0 && num_pack_args_ * sizeof(ArgUnion64) > 120) {
     // UBO
-    // TODO: copy pack_args
+    memcpy(pipeline->ubo.host_buf, pack_args, num_pack_args_ * sizeof(ArgUnion64));
     VkDescriptorBufferInfo binfo;
-    binfo.buffer = pipeline->ubo.buffer;
+    binfo.buffer = pipeline->ubo.vk_buf->buffer;
     binfo.offset = 0;
     binfo.range = VK_WHOLE_SIZE;
     descriptor_buffers.push_back(binfo);