Skip to content

Commit

Permalink
[ET-VK][ez] Empty initialize ShaderInfo and add bool() operator
Browse files Browse the repository at this point in the history
Differential Revision: D61666460

Pull Request resolved: pytorch#4842
  • Loading branch information
SS-JIA authored Aug 22, 2024
1 parent 65473de commit 87b38cf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
8 changes: 6 additions & 2 deletions backends/vulkan/runtime/vk_api/Shader.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class ShaderLayout final {

struct ShaderInfo final {
struct {
const uint32_t* bin;
uint32_t size;
const uint32_t* bin = nullptr;
uint32_t size = 0u;
} src_code;

std::string kernel_name{""};
Expand All @@ -71,6 +71,10 @@ struct ShaderInfo final {
const uint32_t,
std::vector<VkDescriptorType>,
const utils::uvec3 tile_size);

operator bool() const {
return src_code.bin != nullptr;
};
};

bool operator==(const ShaderInfo& _1, const ShaderInfo& _2);
Expand Down
7 changes: 7 additions & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ std::vector<int64_t> get_reference_strides(
return {};
}

TEST_F(VulkanComputeAPITest, empty_init_shader_info_test) {
vkapi::ShaderInfo empty_shader_info;
EXPECT_FALSE(empty_shader_info);
EXPECT_TRUE(empty_shader_info.src_code.bin == nullptr);
EXPECT_TRUE(empty_shader_info.src_code.size == 0u);
}

TEST_F(VulkanComputeAPITest, calculate_tensor_strides_test) {
for (const auto& sizes : standard_sizes_to_test) {
if (sizes.size() < 3) {
Expand Down

0 comments on commit 87b38cf

Please sign in to comment.