Skip to content

Commit

Permalink
[Vulkan] Remove dependency on Target from -from_device functionality. (
Browse files Browse the repository at this point in the history
…#8171)

The `tvm.target.Target("vulkan -from_device=0")` functionality was
initially implemented by generating/returning a Target.  This broke
usage of libtvm_runtime.so, since Target is only defined in libtvm.so.
This commit reimplements the functionality without the dependency on
Target, Integer, Bool, or IntImm.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
  • Loading branch information
Lunderberg and Lunderberg authored Jun 2, 2021
1 parent bb3e772 commit 0c83fe8
Show file tree
Hide file tree
Showing 5 changed files with 385 additions and 219 deletions.
171 changes: 171 additions & 0 deletions src/runtime/vulkan/vulkan_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "vulkan_context.h"

#include <algorithm>
#include <unordered_map>

#include "vulkan_common.h"
Expand All @@ -29,6 +30,176 @@ namespace tvm {
namespace runtime {
namespace vulkan {

VulkanDeviceProperties::VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_dev,
const std::vector<const char*> instance_extensions,
const std::vector<const char*> device_extensions) {
auto has_instance_extension = [&](const char* query) {
return std::any_of(instance_extensions.begin(), instance_extensions.end(),
[&](const char* extension) { return std::strcmp(query, extension) == 0; });
};

auto has_device_extension = [&](const char* query) {
return std::any_of(device_extensions.begin(), device_extensions.end(),
[&](const char* extension) { return std::strcmp(query, extension) == 0; });
};

///////////////////////////////////////////////////////////////
// Query properties from Vulkan API //
///////////////////////////////////////////////////////////////

// Declare output locations for properties
VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2};
VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES};
VkPhysicalDeviceSubgroupProperties subgroup = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES};

// Need to do initial query in order to check the apiVersion.
vkGetPhysicalDeviceProperties(phy_dev, &properties.properties);

// Set up linked list for property query
{
void** pp_next = &properties.pNext;
if (has_device_extension("VK_KHR_driver_properties")) {
*pp_next = &driver;
pp_next = &driver.pNext;
}
if (properties.properties.apiVersion >= VK_API_VERSION_1_1) {
*pp_next = &subgroup;
pp_next = &subgroup.pNext;
}
}

// Declare output locations for features
VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
VkPhysicalDevice8BitStorageFeatures storage_8bit = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES};
VkPhysicalDevice16BitStorageFeatures storage_16bit = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES};
VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES};

// Set up linked list for feature query
{
void** pp_next = &features.pNext;
if (has_device_extension("VK_KHR_8bit_storage")) {
*pp_next = &storage_8bit;
pp_next = &storage_8bit.pNext;
}
if (has_device_extension("VK_KHR_16bit_storage")) {
*pp_next = &storage_16bit;
pp_next = &storage_16bit.pNext;
}
if (has_device_extension("VK_KHR_shader_float16_int8")) {
*pp_next = &float16_int8;
pp_next = &float16_int8.pNext;
}
}

if (has_instance_extension("VK_KHR_get_physical_device_properties2")) {
// Preferred method, call to get all properties that can be queried.
auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL(
vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR"));
vkGetPhysicalDeviceProperties2KHR(phy_dev, &properties);

auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL(
vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR"));
vkGetPhysicalDeviceFeatures2KHR(phy_dev, &features);
} else {
// Fallback, get as many features as we can from the Vulkan1.0
// API. Corresponding vkGetPhysicalDeviceProperties was already done earlier.
vkGetPhysicalDeviceFeatures(phy_dev, &features.features);
}

///////////////////////////////////////////////////////////////
// Fill member variables from Vulkan structures //
///////////////////////////////////////////////////////////////

supports_float16 = float16_int8.shaderFloat16;
supports_float32 = true;
supports_float64 = features.features.shaderFloat64;
supports_int8 = float16_int8.shaderInt8;
supports_int16 = features.features.shaderInt16;
supports_int32 = true;
supports_int64 = features.features.shaderInt64;
supports_8bit_buffer = storage_8bit.storageBuffer8BitAccess;
supports_16bit_buffer = storage_16bit.storageBuffer16BitAccess;
supports_storage_buffer_storage_class =
has_device_extension("VK_KHR_storage_buffer_storage_class");

// Support is available based on these extensions, but allow it to
// be disabled based on an environment variable.
supports_push_descriptor = has_device_extension("VK_KHR_push_descriptor") &&
has_device_extension("VK_KHR_descriptor_update_template");
{
const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR");
if (disable && *disable) {
supports_push_descriptor = false;
}
}

// Support is available based on these extensions, but allow it to
// be disabled based on an environment variable.
supports_dedicated_allocation = has_device_extension("VK_KHR_get_memory_requirements2") &&
has_device_extension("VK_KHR_dedicated_allocation");
{
const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION");
if (disable && *disable) {
supports_dedicated_allocation = false;
}
}

// The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically
// needed, since it will be set so long at least one queue has
// VK_QUEUE_COMPUTE_BIT. Including it to avoid potential future
// confusion..
supported_subgroup_operations =
(subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0;

max_num_threads = properties.properties.limits.maxComputeWorkGroupInvocations;

// Even if we can't query it, warp size must be at least 1.
thread_warp_size = std::max(subgroup.subgroupSize, 1U);

max_block_size_x = properties.properties.limits.maxComputeWorkGroupSize[0];
max_block_size_y = properties.properties.limits.maxComputeWorkGroupSize[1];
max_block_size_z = properties.properties.limits.maxComputeWorkGroupSize[2];
max_push_constants_size = properties.properties.limits.maxPushConstantsSize;
max_uniform_buffer_range = properties.properties.limits.maxUniformBufferRange;
max_storage_buffer_range = properties.properties.limits.maxStorageBufferRange;
max_per_stage_descriptor_storage_buffer =
properties.properties.limits.maxPerStageDescriptorStorageBuffers;
max_shared_memory_per_block = properties.properties.limits.maxComputeSharedMemorySize;
device_name = properties.properties.deviceName;
driver_version = properties.properties.driverVersion;

// By default, use the maximum API version that the driver allows,
// so that any supported features can be used by TVM shaders.
// However, if we can query the conformance version, then limit to
// only using the api version that passes the vulkan conformance
// tests.
vulkan_api_version = properties.properties.apiVersion;
if (has_device_extension("VK_KHR_driver_properties")) {
auto api_major = VK_VERSION_MAJOR(vulkan_api_version);
auto api_minor = VK_VERSION_MINOR(vulkan_api_version);
if ((api_major > driver.conformanceVersion.major) ||
((api_major == driver.conformanceVersion.major) &&
(api_minor > driver.conformanceVersion.minor))) {
vulkan_api_version =
VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0);
}
}

// From "Versions and Formats" section of Vulkan spec.
max_spirv_version = 0x10000;
if (vulkan_api_version >= VK_API_VERSION_1_2) {
max_spirv_version = 0x10500;
} else if (has_device_extension("VK_KHR_spirv_1_4")) {
max_spirv_version = 0x10400;
} else if (vulkan_api_version >= VK_API_VERSION_1_1) {
max_spirv_version = 0x10300;
}
}

VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) {
vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL(
vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR"));
Expand Down
50 changes: 48 additions & 2 deletions src/runtime/vulkan/vulkan_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <tvm/target/target.h>

#include <memory>
#include <string>
#include <vector>

#include "vulkan/vulkan_core.h"
#include "vulkan_buffer.h"
Expand All @@ -47,14 +49,58 @@ struct VulkanGetBufferMemoryRequirements2Functions {
PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr};
};

/*!
* \brief Stores the capabilities/limits queried from the physical device.
*
* The member variables here have a 1-1 mapping to Target parameters,
* if target->kind->device_type==kDLVulkan. A separate struct is used
* to maintain the boundary between the Vulkan runtime in
* libtvm_runtime.so, and the Target object in libtvm.so.
*/
struct VulkanDeviceProperties {
VulkanDeviceProperties() {}
VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_device,
const std::vector<const char*> instance_extensions,
const std::vector<const char*> device_extensions);

bool supports_float16{false};
bool supports_float32{true};
bool supports_float64{false};
bool supports_int8{false};
bool supports_int16{false};
bool supports_int32{true};
bool supports_int64{false};
bool supports_8bit_buffer{false};
bool supports_16bit_buffer{false};
bool supports_storage_buffer_storage_class{false};
bool supports_push_descriptor{false};
bool supports_dedicated_allocation{false};
uint32_t supported_subgroup_operations{0};
uint32_t max_num_threads{1};
uint32_t thread_warp_size{1};
uint32_t max_block_size_x{1};
uint32_t max_block_size_y{1};
uint32_t max_block_size_z{1};
uint32_t max_push_constants_size{128};
uint32_t max_uniform_buffer_range{16384};
uint32_t max_storage_buffer_range{1 << 27};
uint32_t max_per_stage_descriptor_storage_buffer{4};
uint32_t max_shared_memory_per_block{16384};
std::string device_name{"unknown device name"};
uint32_t driver_version{0};
uint32_t vulkan_api_version{VK_API_VERSION_1_0};
uint32_t max_spirv_version{0x10000};
};

struct VulkanContext {
// physical device
VkPhysicalDevice phy_device{nullptr};

// Cached device properties, queried through Vulkan API.
VulkanDeviceProperties device_properties;

// Phyiscal device property
VkPhysicalDeviceProperties phy_device_prop;
// Target that best represents this physical device
Target target;
// Memory type index for staging.
uint32_t staging_mtype_index{0};
// whether staging is coherent
Expand Down
Loading

0 comments on commit 0c83fe8

Please sign in to comment.