diff --git a/src/runtime/vulkan/vulkan_context.cc b/src/runtime/vulkan/vulkan_context.cc index 659e6bd225f6..7e59c9da47b5 100644 --- a/src/runtime/vulkan/vulkan_context.cc +++ b/src/runtime/vulkan/vulkan_context.cc @@ -19,6 +19,7 @@ #include "vulkan_context.h" +#include #include #include "vulkan_common.h" @@ -29,6 +30,190 @@ namespace tvm { namespace runtime { namespace vulkan {
+// Queries the Vulkan API for the capabilities and limits of phy_dev
+// and caches them in the member variables.
+//
+// instance: used to look up the vkGetPhysicalDevice*2KHR entry points.
+// phy_dev: the physical device to describe.
+// instance_extensions / device_extensions: extension names, matched
+// by exact string comparison to decide which queries can be made and
+// which optional features are reported as supported.
+VulkanDeviceProperties::VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_dev,
+                                               const std::vector<const char*> instance_extensions,
+                                               const std::vector<const char*> device_extensions) {
+  auto has_instance_extension = [&](const char* query) {
+    return std::any_of(instance_extensions.begin(), instance_extensions.end(),
+                       [&](const char* extension) { return std::strcmp(query, extension) == 0; });
+  };
+
+  auto has_device_extension = [&](const char* query) {
+    return std::any_of(device_extensions.begin(), device_extensions.end(),
+                       [&](const char* extension) { return std::strcmp(query, extension) == 0; });
+  };
+
+  ///////////////////////////////////////////////////////////////
+  //           Query properties from Vulkan API                //
+  ///////////////////////////////////////////////////////////////
+
+  // Declare output locations for properties
+  VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2};
+  VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES};
+  VkPhysicalDeviceSubgroupProperties subgroup = {
+      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES};
+
+  // Need to do initial query in order to check the apiVersion.
+  vkGetPhysicalDeviceProperties(phy_dev, &properties.properties);
+
+  // Set up linked list for property query
+  {
+    void** pp_next = &properties.pNext;
+    if (has_device_extension("VK_KHR_driver_properties")) {
+      *pp_next = &driver;
+      pp_next = &driver.pNext;
+    }
+    if (properties.properties.apiVersion >= VK_API_VERSION_1_1) {
+      *pp_next = &subgroup;
+      pp_next = &subgroup.pNext;
+    }
+  }
+
+  // Declare output locations for features
+  VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
+  VkPhysicalDevice8BitStorageFeatures storage_8bit = {
+      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES};
+  VkPhysicalDevice16BitStorageFeatures storage_16bit = {
+      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES};
+  VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = {
+      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES};
+
+  // Set up linked list for feature query
+  {
+    void** pp_next = &features.pNext;
+    if (has_device_extension("VK_KHR_8bit_storage")) {
+      *pp_next = &storage_8bit;
+      pp_next = &storage_8bit.pNext;
+    }
+    if (has_device_extension("VK_KHR_16bit_storage")) {
+      *pp_next = &storage_16bit;
+      pp_next = &storage_16bit.pNext;
+    }
+    if (has_device_extension("VK_KHR_shader_float16_int8")) {
+      *pp_next = &float16_int8;
+      pp_next = &float16_int8.pNext;
+    }
+  }
+
+  if (has_instance_extension("VK_KHR_get_physical_device_properties2")) {
+    // Preferred method, call to get all properties that can be queried.
+    auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL(
+        vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR"));
+    vkGetPhysicalDeviceProperties2KHR(phy_dev, &properties);
+
+    auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL(
+        vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR"));
+    vkGetPhysicalDeviceFeatures2KHR(phy_dev, &features);
+  } else {
+    // Fallback, get as many features as we can from the Vulkan1.0
+    // API.  Corresponding vkGetPhysicalDeviceProperties was already done earlier.
+    vkGetPhysicalDeviceFeatures(phy_dev, &features.features);
+  }
+
+  ///////////////////////////////////////////////////////////////
+  //       Fill member variables from Vulkan structures        //
+  ///////////////////////////////////////////////////////////////
+
+  supports_float16 = float16_int8.shaderFloat16;
+  supports_float32 = true;
+  supports_float64 = features.features.shaderFloat64;
+  supports_int8 = float16_int8.shaderInt8;
+  supports_int16 = features.features.shaderInt16;
+  supports_int32 = true;
+  supports_int64 = features.features.shaderInt64;
+  supports_8bit_buffer = storage_8bit.storageBuffer8BitAccess;
+  supports_16bit_buffer = storage_16bit.storageBuffer16BitAccess;
+  supports_storage_buffer_storage_class =
+      has_device_extension("VK_KHR_storage_buffer_storage_class");
+
+  // Support is available based on these extensions, but allow it to
+  // be disabled based on an environment variable.
+  supports_push_descriptor = has_device_extension("VK_KHR_push_descriptor") &&
+                             has_device_extension("VK_KHR_descriptor_update_template");
+  {
+    const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR");
+    if (disable && *disable) {
+      supports_push_descriptor = false;
+    }
+  }
+
+  // Support is available based on these extensions, but allow it to
+  // be disabled based on an environment variable.
+  supports_dedicated_allocation = has_device_extension("VK_KHR_get_memory_requirements2") &&
+                                  has_device_extension("VK_KHR_dedicated_allocation");
+  {
+    const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION");
+    if (disable && *disable) {
+      supports_dedicated_allocation = false;
+    }
+  }
+
+  // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically
+  // needed, since it will be set so long at least one queue has
+  // VK_QUEUE_COMPUTE_BIT.  Including it to avoid potential future
+  // confusion..
+  supported_subgroup_operations =
+      (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0;
+
+  max_num_threads = properties.properties.limits.maxComputeWorkGroupInvocations;
+
+  // Even if we can't query it, warp size must be at least 1.
+  thread_warp_size = std::max(subgroup.subgroupSize, 1U);
+
+  max_block_size_x = properties.properties.limits.maxComputeWorkGroupSize[0];
+  max_block_size_y = properties.properties.limits.maxComputeWorkGroupSize[1];
+  max_block_size_z = properties.properties.limits.maxComputeWorkGroupSize[2];
+  max_push_constants_size = properties.properties.limits.maxPushConstantsSize;
+  max_uniform_buffer_range = properties.properties.limits.maxUniformBufferRange;
+  max_storage_buffer_range = properties.properties.limits.maxStorageBufferRange;
+  max_per_stage_descriptor_storage_buffer =
+      properties.properties.limits.maxPerStageDescriptorStorageBuffers;
+  max_shared_memory_per_block = properties.properties.limits.maxComputeSharedMemorySize;
+  device_name = properties.properties.deviceName;
+  driver_version = properties.properties.driverVersion;
+
+  // By default, use the maximum API version that the driver allows,
+  // so that any supported features can be used by TVM shaders.
+  // However, if we can query the conformance version, then limit to
+  // only using the api version that passes the vulkan conformance
+  // tests.
+  //
+  // The driver struct is only filled in by the
+  // vkGetPhysicalDeviceProperties2KHR call above.  Without that call
+  // conformanceVersion stays zero-initialized, and clamping to it
+  // would report Vulkan 0.0, so the clamp requires both extensions.
+  vulkan_api_version = properties.properties.apiVersion;
+  if (has_device_extension("VK_KHR_driver_properties") &&
+      has_instance_extension("VK_KHR_get_physical_device_properties2")) {
+    auto api_major = VK_VERSION_MAJOR(vulkan_api_version);
+    auto api_minor = VK_VERSION_MINOR(vulkan_api_version);
+    if ((api_major > driver.conformanceVersion.major) ||
+        ((api_major == driver.conformanceVersion.major) &&
+         (api_minor > driver.conformanceVersion.minor))) {
+      vulkan_api_version =
+          VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0);
+    }
+  }
+
+  // From "Versions and Formats" section of Vulkan spec.
+  max_spirv_version = 0x10000;
+  if (vulkan_api_version >= VK_API_VERSION_1_2) {
+    max_spirv_version = 0x10500;
+  } else if (has_device_extension("VK_KHR_spirv_1_4")) {
+    max_spirv_version = 0x10400;
+  } else if (vulkan_api_version >= VK_API_VERSION_1_1) {
+    max_spirv_version = 0x10300;
+  }
+}
+
 VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) { vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR")); diff --git a/src/runtime/vulkan/vulkan_context.h b/src/runtime/vulkan/vulkan_context.h index 08f0f97def14..158a53043c7b 100644 --- a/src/runtime/vulkan/vulkan_context.h +++ b/src/runtime/vulkan/vulkan_context.h @@ -24,6 +24,8 @@ #include #include +#include +#include #include "vulkan/vulkan_core.h" #include "vulkan_buffer.h" @@ -47,14 +49,58 @@ struct VulkanGetBufferMemoryRequirements2Functions { PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr}; }; +/*! + * \brief Stores the capabilities/limits queried from the physical device. + * + * The member variables here have a 1-1 mapping to Target parameters, + * if target->kind->device_type==kDLVulkan. A separate struct is used + * to maintain the boundary between the Vulkan runtime in + * libtvm_runtime.so, and the Target object in libtvm.so.
+ */
+struct VulkanDeviceProperties {
+  VulkanDeviceProperties() {}  // leaves every member at the in-class default below
+  VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_device,
+                         const std::vector instance_extensions,
+                         const std::vector device_extensions);
+
+  bool supports_float16{false};  // VkPhysicalDeviceShaderFloat16Int8Features::shaderFloat16
+  bool supports_float32{true};   // set to true unconditionally by the querying constructor
+  bool supports_float64{false};  // VkPhysicalDeviceFeatures::shaderFloat64
+  bool supports_int8{false};     // VkPhysicalDeviceShaderFloat16Int8Features::shaderInt8
+  bool supports_int16{false};    // VkPhysicalDeviceFeatures::shaderInt16
+  bool supports_int32{true};     // set to true unconditionally by the querying constructor
+  bool supports_int64{false};    // VkPhysicalDeviceFeatures::shaderInt64
+  bool supports_8bit_buffer{false};   // VkPhysicalDevice8BitStorageFeatures::storageBuffer8BitAccess
+  bool supports_16bit_buffer{false};  // VkPhysicalDevice16BitStorageFeatures::storageBuffer16BitAccess
+  bool supports_storage_buffer_storage_class{false};  // VK_KHR_storage_buffer_storage_class listed
+  bool supports_push_descriptor{false};  // needs two device extensions; can be disabled via env var
+  bool supports_dedicated_allocation{false};  // needs two device extensions; can be disabled via env var
+  uint32_t supported_subgroup_operations{0};  // subgroup supportedOperations; 0 if compute stage unset
+  uint32_t max_num_threads{1};   // limits.maxComputeWorkGroupInvocations
+  uint32_t thread_warp_size{1};  // subgroupSize, clamped to at least 1
+  uint32_t max_block_size_x{1};  // limits.maxComputeWorkGroupSize[0]
+  uint32_t max_block_size_y{1};  // limits.maxComputeWorkGroupSize[1]
+  uint32_t max_block_size_z{1};  // limits.maxComputeWorkGroupSize[2]
+  uint32_t max_push_constants_size{128};     // limits.maxPushConstantsSize
+  uint32_t max_uniform_buffer_range{16384};  // limits.maxUniformBufferRange
+  uint32_t max_storage_buffer_range{1 << 27};  // limits.maxStorageBufferRange
+  uint32_t max_per_stage_descriptor_storage_buffer{4};  // limits.maxPerStageDescriptorStorageBuffers
+  uint32_t max_shared_memory_per_block{16384};  // limits.maxComputeSharedMemorySize
+  std::string device_name{"unknown device name"};  // VkPhysicalDeviceProperties::deviceName
+  uint32_t driver_version{0};  // VkPhysicalDeviceProperties::driverVersion
+  uint32_t vulkan_api_version{VK_API_VERSION_1_0};  // apiVersion, may be lowered to conformance version
+  uint32_t max_spirv_version{0x10000};  // one of 0x10000, 0x10300, 0x10400, 0x10500
+};
+
 struct VulkanContext { // physical device VkPhysicalDevice phy_device{nullptr}; + // Cached device properties, queried through Vulkan API. + VulkanDeviceProperties device_properties; + // Phyiscal device property VkPhysicalDeviceProperties phy_device_prop; - // Target that best represents this physical device - Target target; // Memory type index for staging.
uint32_t staging_mtype_index{0}; // whether staging is coherent diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index d318204ce2c1..7cea2489cb1b 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -173,7 +173,8 @@ VulkanDeviceAPI::VulkanDeviceAPI() { return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions); }(); - ctx.target = GetDeviceDescription(instance_, phy_dev, instance_extensions, device_extensions); + ctx.device_properties = + VulkanDeviceProperties(instance_, phy_dev, instance_extensions, device_extensions); { // Enable all features we may use that a device supports. @@ -188,30 +189,29 @@ VulkanDeviceAPI::VulkanDeviceAPI() { void** pp_next = &enabled_features.pNext; bool needs_float16_int8 = false; - auto has_support = [&](const char* name) { return ctx.target->GetAttr(name).value(); }; - if (has_support("supports_float16")) { + if (ctx.device_properties.supports_float16) { float16_int8.shaderFloat16 = true; needs_float16_int8 = true; } - if (has_support("supports_float64")) { + if (ctx.device_properties.supports_float64) { enabled_features.features.shaderFloat64 = true; } - if (has_support("supports_int8")) { + if (ctx.device_properties.supports_int8) { float16_int8.shaderInt8 = true; needs_float16_int8 = true; } - if (has_support("supports_int16")) { + if (ctx.device_properties.supports_int16) { enabled_features.features.shaderInt16 = true; } - if (has_support("supports_int64")) { + if (ctx.device_properties.supports_int64) { enabled_features.features.shaderInt64 = true; } - if (has_support("supports_8bit_buffer")) { + if (ctx.device_properties.supports_8bit_buffer) { storage_8bit.storageBuffer8BitAccess = true; *pp_next = &storage_8bit; pp_next = &storage_8bit.pNext; } - if (has_support("supports_16bit_buffer")) { + if (ctx.device_properties.supports_16bit_buffer) { storage_16bit.storageBuffer16BitAccess = true; *pp_next 
= &storage_16bit; pp_next = &storage_16bit.pNext; @@ -314,12 +314,12 @@ VulkanDeviceAPI::VulkanDeviceAPI() { } ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; - if (ctx.target->GetAttr("supports_push_descriptor").value()) { + if (ctx.device_properties.supports_push_descriptor) { ctx.descriptor_template_khr_functions = std::make_unique(ctx.device); } - if (ctx.target->GetAttr("supports_dedicated_allocation").value()) { + if (ctx.device_properties.supports_dedicated_allocation) { ctx.get_buffer_memory_requirements_2_functions = std::make_unique(ctx.device); } @@ -352,25 +352,24 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) *rv = static_cast(index < context_.size()); return; } - ICHECK_LT(index, context_.size()) << "Invalid device id " << index; - const auto& target = context(index).target; + const auto& prop = context(index).device_properties; switch (kind) { case kMaxThreadsPerBlock: { - *rv = target->GetAttr("max_num_threads").value(); + *rv = int64_t(prop.max_num_threads); break; } case kMaxSharedMemoryPerBlock: { - *rv = target->GetAttr("max_shared_memory_per_block"); + *rv = int64_t(prop.max_shared_memory_per_block); break; } case kWarpSize: { - *rv = target->GetAttr("thread_warp_size").value(); + *rv = int64_t(prop.thread_warp_size); break; } case kComputeVersion: { - int64_t value = target->GetAttr("vulkan_api_version").value(); + int64_t value = prop.vulkan_api_version; std::ostringstream os; os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." 
<< VK_VERSION_PATCH(value); @@ -378,7 +377,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) break; } case kDeviceName: - *rv = target->GetAttr("device_name").value(); + *rv = prop.device_name; break; case kMaxClockRate: @@ -392,9 +391,8 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) case kMaxThreadDimensions: { std::stringstream ss; // use json string to return multiple int values; - ss << "[" << target->GetAttr("max_block_size_x").value() << ", " - << target->GetAttr("max_block_size_y").value() << ", " - << target->GetAttr("max_block_size_z").value() << "]"; + ss << "[" << prop.max_block_size_x << ", " << prop.max_block_size_y << ", " + << prop.max_block_size_z << "]"; *rv = ss.str(); break; } @@ -410,7 +408,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) break; case kDriverVersion: { - int64_t value = target->GetAttr("driver_version").value(); + int64_t value = prop.driver_version; std::ostringstream os; os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." 
<< VK_VERSION_PATCH(value); @@ -420,6 +418,98 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) } }
+// Returns one cached device property, selected by the name of the
+// corresponding Target attribute.  Boolean properties are returned as
+// bool, numeric properties are widened to int64_t, and device_name is
+// returned as a string.  If the name matches none of the entries
+// below, *rv is left unmodified.
+void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) {
+  size_t index = static_cast<size_t>(dev.device_id);
+  const auto& prop = context(index).device_properties;
+
+  if (property == "supports_float16") {
+    *rv = prop.supports_float16;
+  }
+  if (property == "supports_float32") {
+    *rv = prop.supports_float32;
+  }
+  if (property == "supports_float64") {
+    *rv = prop.supports_float64;
+  }
+  if (property == "supports_int8") {
+    *rv = prop.supports_int8;
+  }
+  if (property == "supports_int16") {
+    *rv = prop.supports_int16;
+  }
+  if (property == "supports_int32") {
+    *rv = prop.supports_int32;
+  }
+  if (property == "supports_int64") {
+    *rv = prop.supports_int64;
+  }
+  if (property == "supports_8bit_buffer") {
+    *rv = prop.supports_8bit_buffer;
+  }
+  if (property == "supports_16bit_buffer") {
+    *rv = prop.supports_16bit_buffer;
+  }
+  if (property == "supports_storage_buffer_storage_class") {
+    *rv = prop.supports_storage_buffer_storage_class;
+  }
+  if (property == "supports_push_descriptor") {
+    *rv = prop.supports_push_descriptor;
+  }
+  if (property == "supports_dedicated_allocation") {
+    *rv = prop.supports_dedicated_allocation;
+  }
+  if (property == "supported_subgroup_operations") {
+    *rv = int64_t(prop.supported_subgroup_operations);
+  }
+  if (property == "max_num_threads") {
+    *rv = int64_t(prop.max_num_threads);
+  }
+  if (property == "thread_warp_size") {
+    *rv = int64_t(prop.thread_warp_size);
+  }
+  if (property == "max_block_size_x") {
+    *rv = int64_t(prop.max_block_size_x);
+  }
+  if (property == "max_block_size_y") {
+    *rv = int64_t(prop.max_block_size_y);
+  }
+  if (property == "max_block_size_z") {
+    *rv = int64_t(prop.max_block_size_z);
+  }
+  if (property == "max_push_constants_size") {
+    *rv = int64_t(prop.max_push_constants_size);
+  }
+  if (property == "max_uniform_buffer_range") {
+    *rv = int64_t(prop.max_uniform_buffer_range);
+  }
+  if (property == "max_storage_buffer_range") {
+    *rv = int64_t(prop.max_storage_buffer_range);
+  }
+  if (property == "max_per_stage_descriptor_storage_buffer") {
+    *rv = int64_t(prop.max_per_stage_descriptor_storage_buffer);
+  }
+  if (property == "max_shared_memory_per_block") {
+    *rv = int64_t(prop.max_shared_memory_per_block);
+  }
+  if (property == "device_name") {
+    *rv = prop.device_name;
+  }
+  if (property == "driver_version") {
+    *rv = int64_t(prop.driver_version);
+  }
+  if (property == "vulkan_api_version") {
+    *rv = int64_t(prop.vulkan_api_version);
+  }
+  if (property == "max_spirv_version") {
+    *rv = int64_t(prop.max_spirv_version);
+  }
+}
+
 void* VulkanDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) { if (nbytes == 0) { @@ -610,12 +700,11 @@ std::vector VulkanDeviceAPI::FindEnabledExtensions( } const VulkanContext& VulkanDeviceAPI::context(size_t device_id) const { - ICHECK_LT(device_id, context_.size()); + ICHECK_LT(device_id, context_.size()) << "Requested Vulkan device_id=" << device_id + << ", but only " << context_.size() << " devices present"; return context_[device_id]; } -Target VulkanDeviceAPI::GenerateTarget(size_t device_id) const { return context(device_id).target; } - std::vector VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice phy_dev) { uint32_t queue_prop_count = 0; vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr); @@ -641,189 +730,17 @@ std::vector VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice return result; } -Target VulkanDeviceAPI::GetDeviceDescription(VkInstance instance, VkPhysicalDevice dev, - const std::vector& instance_extensions, - const std::vector& device_extensions) { - auto has_extension = [&](const char* query) { - return std::any_of(device_extensions.begin(), device_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }) || -
std::any_of(instance_extensions.begin(), instance_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; - - // Declare output locations for properties - VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2}; - VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES}; - VkPhysicalDeviceSubgroupProperties subgroup = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES}; - - // Need to do initial query in order to check the apiVersion. - vkGetPhysicalDeviceProperties(dev, &properties.properties); - - // Set up linked list for property query - { - void** pp_next = &properties.pNext; - if (has_extension("VK_KHR_driver_properties")) { - *pp_next = &driver; - pp_next = &driver.pNext; - } - if (properties.properties.apiVersion >= VK_API_VERSION_1_1) { - *pp_next = &subgroup; - pp_next = &subgroup.pNext; - } - } - - // Declare output locations for features - VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; - VkPhysicalDevice8BitStorageFeatures storage_8bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; - VkPhysicalDevice16BitStorageFeatures storage_16bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; - VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; - - // Set up linked list for feature query - { - void** pp_next = &features.pNext; - if (has_extension("VK_KHR_8bit_storage")) { - *pp_next = &storage_8bit; - pp_next = &storage_8bit.pNext; - } - if (has_extension("VK_KHR_16bit_storage")) { - *pp_next = &storage_16bit; - pp_next = &storage_16bit.pNext; - } - if (has_extension("VK_KHR_shader_float16_int8")) { - *pp_next = &float16_int8; - pp_next = &float16_int8.pNext; - } - } - - if (has_extension("VK_KHR_get_physical_device_properties2")) { - // Preferred method, call to get all properties that can be 
queried. - auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( - vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); - vkGetPhysicalDeviceProperties2KHR(dev, &properties); - - auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL( - vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); - vkGetPhysicalDeviceFeatures2KHR(dev, &features); - } else { - // Fallback, get as many features as we can from the Vulkan1.0 - // API. Corresponding vkGetPhysicalDeviceProperties was already done earlier. - vkGetPhysicalDeviceFeatures(dev, &features.features); - } - - //// Now, extracting all the information from the vulkan query. - - // Not technically needed, because VK_SHADER_STAGE_COMPUTE_BIT will - // be set so long at least one queue has VK_QUEUE_COMPUTE_BIT, but - // preferring the explicit check. - uint32_t supported_subgroup_operations = - (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0; - - // Even if we can't query it, warp size must be at least 1. Must - // also be defined, as `transpose` operation requires it. - uint32_t thread_warp_size = std::max(subgroup.subgroupSize, 1U); - - // By default, use the maximum API version that the driver allows, - // so that any supported features can be used by TVM shaders. - // However, if we can query the conformance version, then limit to - // only using the api version that passes the vulkan conformance - // tests. 
- uint32_t vulkan_api_version = properties.properties.apiVersion; - if (has_extension("VK_KHR_driver_properties")) { - auto api_major = VK_VERSION_MAJOR(vulkan_api_version); - auto api_minor = VK_VERSION_MINOR(vulkan_api_version); - if ((api_major > driver.conformanceVersion.major) || - ((api_major == driver.conformanceVersion.major) && - (api_minor > driver.conformanceVersion.minor))) { - vulkan_api_version = - VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0); - } - } - - // From "Versions and Formats" section of Vulkan spec. - uint32_t max_spirv_version = 0x10000; - if (vulkan_api_version >= VK_API_VERSION_1_2) { - max_spirv_version = 0x10500; - } else if (has_extension("VK_KHR_spirv_1_4")) { - max_spirv_version = 0x10400; - } else if (vulkan_api_version >= VK_API_VERSION_1_1) { - max_spirv_version = 0x10300; - } - - // Support is available based on these extensions, but allow it to - // be disabled based on an environment variable. - bool supports_push_descriptor = - has_extension("VK_KHR_push_descriptor") && has_extension("VK_KHR_descriptor_update_template"); - { - const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR"); - if (disable && *disable) { - supports_push_descriptor = false; - } - } - - // Support is available based on these extensions, but allow it to - // be disabled based on an environment variable. 
- bool supports_dedicated_allocation = has_extension("VK_KHR_get_memory_requirements2") && - has_extension("VK_KHR_dedicated_allocation"); - { - const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION"); - if (disable && *disable) { - supports_dedicated_allocation = false; - } - } - - Map config = { - {"kind", String("vulkan")}, - // Feature support - {"supports_float16", Bool(float16_int8.shaderFloat16)}, - {"supports_float32", Bool(true)}, - {"supports_float64", Bool(features.features.shaderFloat64)}, - {"supports_int8", Bool(float16_int8.shaderInt8)}, - {"supports_int16", Bool(features.features.shaderInt16)}, - {"supports_int32", Bool(true)}, - {"supports_int64", Bool(features.features.shaderInt64)}, - {"supports_8bit_buffer", Bool(storage_8bit.storageBuffer8BitAccess)}, - {"supports_16bit_buffer", Bool(storage_16bit.storageBuffer16BitAccess)}, - {"supports_storage_buffer_storage_class", - Bool(has_extension("VK_KHR_storage_buffer_storage_class"))}, - {"supports_push_descriptor", Bool(supports_push_descriptor)}, - {"supports_dedicated_allocation", Bool(supports_dedicated_allocation)}, - {"supported_subgroup_operations", Integer(supported_subgroup_operations)}, - // Physical device limits - {"max_num_threads", Integer(properties.properties.limits.maxComputeWorkGroupInvocations)}, - {"thread_warp_size", Integer(thread_warp_size)}, - {"max_block_size_x", Integer(properties.properties.limits.maxComputeWorkGroupSize[0])}, - {"max_block_size_y", Integer(properties.properties.limits.maxComputeWorkGroupSize[1])}, - {"max_block_size_z", Integer(properties.properties.limits.maxComputeWorkGroupSize[2])}, - {"max_push_constants_size", Integer(properties.properties.limits.maxPushConstantsSize)}, - {"max_uniform_buffer_range", Integer(properties.properties.limits.maxUniformBufferRange)}, - {"max_storage_buffer_range", - Integer(IntImm(DataType::UInt(32), properties.properties.limits.maxStorageBufferRange))}, - {"max_per_stage_descriptor_storage_buffer", - 
Integer(properties.properties.limits.maxPerStageDescriptorStorageBuffers)}, - {"max_shared_memory_per_block", - Integer(properties.properties.limits.maxComputeSharedMemorySize)}, - // Other device properties - {"device_name", String(properties.properties.deviceName)}, - {"driver_version", Integer(properties.properties.driverVersion)}, - {"vulkan_api_version", Integer(vulkan_api_version)}, - {"max_spirv_version", Integer(max_spirv_version)}, - }; - - return Target(config); -} - TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = VulkanDeviceAPI::Global(); *rv = static_cast(ptr); }); -TVM_REGISTER_GLOBAL("device_api.vulkan.generate_target").set_body_typed([](int device_id) { - return VulkanDeviceAPI::Global()->GenerateTarget(device_id); -}); +TVM_REGISTER_GLOBAL("device_api.vulkan.get_target_property") + .set_body_typed([](Device dev, const std::string& property) { + TVMRetValue rv; + VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); + return rv; + }); } // namespace vulkan } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index d31af8945efd..71c73afb0d61 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -22,6 +22,7 @@ #include +#include #include #include "vulkan/vulkan_core.h" @@ -74,20 +75,16 @@ class VulkanDeviceAPI final : public DeviceAPI { */ const VulkanContext& context(size_t device_id) const; - /*! \brief Get a Target that best describes a particular device. + /*! \brief Returns a property to be stored in a target. * * Returns the results of feature/property queries done during the * device initialization. 
*/ - Target GenerateTarget(size_t device_id) const; + void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv); private: std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); - Target GetDeviceDescription(VkInstance instance, VkPhysicalDevice dev, - const std::vector& instance_extensions, - const std::vector& device_extensions); - std::vector FindEnabledExtensions( const std::vector& ext_prop, const std::vector& required_extensions, diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 08e998e0f035..cd62432d5ada 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -217,13 +217,53 @@ Map UpdateROCmAttrs(Map attrs) { Map UpdateVulkanAttrs(Map attrs) { if (attrs.count("from_device")) { int device_id = Downcast(attrs.at("from_device")); - const PackedFunc* generate_target = runtime::Registry::Get("device_api.vulkan.generate_target"); - ICHECK(generate_target) + Device device{kDLVulkan, device_id}; + const PackedFunc* get_target_property = + runtime::Registry::Get("device_api.vulkan.get_target_property"); + ICHECK(get_target_property) << "Requested to read Vulkan parameters from device, but no Vulkan runtime available"; - Target target = (*generate_target)(device_id).AsObjectRef(); - for (auto& kv : target->Export()) { - if (!attrs.count(kv.first)) { - attrs.Set(kv.first, kv.second); + + // Current vulkan implementation is partially a proof-of-concept, + // with long-term goal to move the -from_device functionality to + // TargetInternal::FromConfig, and to be usable by all targets. + // The duplicate list of parameters is needed until then, since + // TargetKind::Get("vulkan")->key2vtype_ is private. 
+ std::vector bool_opts = { + "supports_float16", "supports_float32", + "supports_float64", "supports_int8", + "supports_int16", "supports_int32", + "supports_int64", "supports_8bit_buffer", + "supports_16bit_buffer", "supports_storage_buffer_storage_class", + "supports_push_descriptor", "supports_dedicated_allocation"}; + std::vector int_opts = {"supported_subgroup_operations", + "max_num_threads", + "thread_warp_size", + "max_block_size_x", + "max_block_size_y", + "max_block_size_z", + "max_push_constants_size", + "max_uniform_buffer_range", + "max_storage_buffer_range", + "max_per_stage_descriptor_storage_buffer", + "max_shared_memory_per_block", + "driver_version", + "vulkan_api_version", + "max_spirv_version"}; + std::vector str_opts = {"device_name"}; + + for (auto& key : bool_opts) { + if (!attrs.count(key)) { + attrs.Set(key, Bool(static_cast((*get_target_property)(device, key)))); + } + } + for (auto& key : int_opts) { + if (!attrs.count(key)) { + attrs.Set(key, Integer(static_cast((*get_target_property)(device, key)))); + } + } + for (auto& key : str_opts) { + if (!attrs.count(key)) { + attrs.Set(key, (*get_target_property)(device, key)); } } @@ -234,8 +274,8 @@ Map UpdateVulkanAttrs(Map attrs) { // The priority should be user-specified > device-query > default, // but defaults defined in .add_attr_option() are already applied by // this point. Longer-term, would be good to add a - // `DeviceAPI::GenerateTarget` function and extend "from_device" to - // work for all runtimes. + // `DeviceAPI::GetTargetProperty` function and extend "from_device" + // to work for all runtimes. std::unordered_map defaults = {{"supports_float32", Bool(true)}, {"supports_int32", Bool(true)}, {"max_num_threads", Integer(256)},