Skip to content

Commit

Permalink
[ET-VK][ez] Introduce check_close function in compute_api_test to…
Browse files Browse the repository at this point in the history
… account for small numerical differences

Differential Revision: D61666459

Pull Request resolved: pytorch#4841
  • Loading branch information
SS-JIA authored Aug 22, 2024
1 parent 0a21102 commit 65473de
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
6 changes: 6 additions & 0 deletions backends/vulkan/test/utils/test_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,9 @@ void execute_graph_and_check_output(
}
}
}

bool check_close(float a, float b, float atol, float rtol) {
float max = std::max(std::abs(a), std::abs(b));
float diff = std::abs(a - b);
return diff <= (atol + rtol * max);
}
6 changes: 6 additions & 0 deletions backends/vulkan/test/utils/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,9 @@ void print_vector(
}
std::cout << std::endl;
}

//
// Misc. Utilities
//

bool check_close(float a, float b, float atol = 1e-4, float rtol = 1e-5);
2 changes: 1 addition & 1 deletion backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ TEST_F(VulkanComputeAPITest, tensor_no_copy_transpose_test) {
EXPECT_TRUE(data_out.size() == ref_out.size());

for (size_t i = 0; i < data_out.size(); ++i) {
EXPECT_TRUE(data_out[i] == ref_out[i]);
EXPECT_TRUE(check_close(data_out[i], ref_out[i]));
}
}

Expand Down

0 comments on commit 65473de

Please sign in to comment.