Skip to content

Commit

Permalink
[bug] Fix validation erros due to inactive VK_KHR_16bit_storage (taic…
Browse files Browse the repository at this point in the history
…hi-dev#7360)

Issue: #

### Brief Summary
  • Loading branch information
jim19930609 authored and quadpixels committed May 13, 2023
1 parent 7efa2d2 commit 0d3c3ff
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
10 changes: 7 additions & 3 deletions c_api/src/taichi_vulkan_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifdef TI_WITH_VULKAN
#include "taichi_vulkan_impl.h"
#include "taichi/rhi/vulkan/vulkan_loader.h"
#include "taichi/common/utils.h"

#ifdef ANDROID
#define VK_KHR_android_surface 1
Expand Down Expand Up @@ -126,10 +127,10 @@ void VulkanRuntime::free_image(TiImage image) {
// -----------------------------------------------------------------------------

TiRuntime ti_create_vulkan_runtime_ext(uint32_t api_version,
const char **instance_extensions,
uint32_t instance_extension_count,
const char **device_extensions,
uint32_t device_extension_count) {
const char **instance_extensions,
uint32_t device_extension_count,
const char **device_extensions) {
TiRuntime out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
if (api_version < VK_API_VERSION_1_0) {
Expand Down Expand Up @@ -157,6 +158,9 @@ TiRuntime ti_create_vulkan_runtime_ext(uint32_t api_version,
params.additional_device_extensions.push_back(device_extensions[i]);
}
params.surface_creator = nullptr;
if (is_ci()) {
params.enable_validation_layer = true;
}
out = (TiRuntime) static_cast<Runtime *>(new VulkanRuntimeOwned(params));
TI_CAPI_TRY_CATCH_END();
return out;
Expand Down
18 changes: 16 additions & 2 deletions c_api/tests/c_api_numerical_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "taichi/cpp/taichi.hpp"
#include "c_api/tests/gtest_fixture.h"

#ifdef TI_WITH_VULKAN

std::vector<float> read_fp16_ndarray(ti::Runtime &runtime,
const ti::NdArray<uint16_t> &ndarray,
const std::vector<uint32_t> &shape,
Expand Down Expand Up @@ -31,7 +33,12 @@ TEST_F(CapiTest, Float16Fill) {
std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir;

ti::Runtime runtime(arch);
std::vector<const char *> device_extensions = {
VK_KHR_16BIT_STORAGE_EXTENSION_NAME};
TiRuntime ti_runtime = ti_create_vulkan_runtime_ext(
VK_API_VERSION_1_0, 0, {} /*instance extensions*/, 1,
device_extensions.data() /*device extensions*/);
ti::Runtime runtime = ti::Runtime(arch, ti_runtime, true);

ti::AotModule aot_mod = runtime.load_aot_module(aot_mod_ss.str().c_str());
ti::Kernel k_fill_scalar_array_with_fp32 =
Expand Down Expand Up @@ -112,7 +119,12 @@ TEST_F(CapiTest, Float16Compute) {
std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir;

ti::Runtime runtime(arch);
std::vector<const char *> device_extensions = {
VK_KHR_16BIT_STORAGE_EXTENSION_NAME};
TiRuntime ti_runtime = ti_create_vulkan_runtime_ext(
VK_API_VERSION_1_0, 0, {} /*instance extensions*/, 1,
device_extensions.data() /*device extensions*/);
ti::Runtime runtime = ti::Runtime(arch, ti_runtime, true);

ti::AotModule aot_mod = runtime.load_aot_module(aot_mod_ss.str().c_str());
ti::Kernel k_compute = aot_mod.get_kernel("compute_kernel");
Expand Down Expand Up @@ -181,3 +193,5 @@ TEST_F(CapiTest, Float16Compute) {
EXPECT_EQ(data[11], 272);
}
}

#endif
20 changes: 16 additions & 4 deletions taichi/rhi/vulkan/vulkan_device_creator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ bool check_validation_layer_support() {
static const std::unordered_set<std::string> ignored_messages = {
"UNASSIGNED-DEBUG-PRINTF",
"VUID_Undefined",
// FIXME(zhanlue): Fix validation errors with float16 and remove these
// ignores.
"VUID-RuntimeSpirv-uniformAndStorageBuffer16BitAccess-06332",
"VUID-RuntimeSpirv-storageBuffer16BitAccess-06331",
};

[[maybe_unused]] bool vk_ignore_validation_warning(
Expand Down Expand Up @@ -599,6 +595,8 @@ void VulkanDeviceCreator::create_logical_device(bool manual_create) {
// Tracking issue: https://github.com/KhronosGroup/MoltenVK/issues/1214
caps.set(DeviceCapability::spirv_has_non_semantic_info, true);
enabled_extensions.push_back(ext.extensionName);
} else if (name == VK_KHR_16BIT_STORAGE_EXTENSION_NAME) {
enabled_extensions.push_back(ext.extensionName);
} else if (std::find(params_.additional_device_extensions.begin(),
params_.additional_device_extensions.end(),
name) != params_.additional_device_extensions.end()) {
Expand Down Expand Up @@ -684,6 +682,11 @@ void VulkanDeviceCreator::create_logical_device(bool manual_create) {
VkPhysicalDeviceFloat16Int8FeaturesKHR shader_f16_i8_feature{};
shader_f16_i8_feature.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FLOAT16_INT8_FEATURES_KHR;

VkPhysicalDevice16BitStorageFeatures shader_16bit_storage_feature{};
shader_16bit_storage_feature.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES;

VkPhysicalDeviceBufferDeviceAddressFeaturesKHR
buffer_device_address_feature{};
buffer_device_address_feature.sType =
Expand Down Expand Up @@ -778,6 +781,15 @@ void VulkanDeviceCreator::create_logical_device(bool manual_create) {
pNextEnd = &shader_f16_i8_feature.pNext;
}

if (CHECK_VERSION(1, 1) ||
CHECK_EXTENSION(VK_KHR_16BIT_STORAGE_EXTENSION_NAME)) {
features2.pNext = &shader_16bit_storage_feature;
vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2);

*pNextEnd = &shader_16bit_storage_feature;
pNextEnd = &shader_16bit_storage_feature.pNext;
}

// Buffer Device Address
if (CHECK_VERSION(1, 2) ||
CHECK_EXTENSION(VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME)) {
Expand Down

0 comments on commit 0d3c3ff

Please sign in to comment.