From f4c1c6625d78f08630ff2fc77a0a770e90700681 Mon Sep 17 00:00:00 2001 From: Evan Harvey Date: Fri, 7 Jan 2022 12:51:09 -0700 Subject: [PATCH] unit_test/batched: Update BatchedGemm epsilon --- unit_test/batched/dense/Test_Batched_BatchedGemm.hpp | 7 ++++++- unit_test/batched/dense/Test_Batched_SerialGemm.hpp | 6 ++++-- unit_test/batched/dense/Test_Batched_TeamGemm.hpp | 6 ++++-- unit_test/batched/dense/Test_Batched_TeamVectorGemm.hpp | 6 ++++-- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/unit_test/batched/dense/Test_Batched_BatchedGemm.hpp b/unit_test/batched/dense/Test_Batched_BatchedGemm.hpp index d230b39d14..52dc66091c 100644 --- a/unit_test/batched/dense/Test_Batched_BatchedGemm.hpp +++ b/unit_test/batched/dense/Test_Batched_BatchedGemm.hpp @@ -135,7 +135,12 @@ void impl_test_batched_gemm_with_handle(BatchedGemmHandle* batchedGemmHandle, using mag_type = float; mag_type sum(1), diff(0); - mag_type eps = (mag_type)(1 << 1) * KOKKOSKERNELS_IMPL_FP16_EPSILON; + auto eps = static_cast(ats::epsilon()); + + eps *= std::is_same::value || + std::is_same::value + ? 4 + : 1e3; for (int k = 0; k < N; ++k) { for (int i = 0; i < matCdim1; ++i) { diff --git a/unit_test/batched/dense/Test_Batched_SerialGemm.hpp b/unit_test/batched/dense/Test_Batched_SerialGemm.hpp index 6a8be3fc54..a7ec3db6a9 100644 --- a/unit_test/batched/dense/Test_Batched_SerialGemm.hpp +++ b/unit_test/batched/dense/Test_Batched_SerialGemm.hpp @@ -125,8 +125,10 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, mag_type eps = ats::epsilon(); - eps *= - std::is_same::value ? 4 : 1e3; + eps *= std::is_same::value || + std::is_same::value + ? 4 + : 1e3; for (int k = 0; k < N; ++k) for (int i = 0; i < matCdim1; ++i) diff --git a/unit_test/batched/dense/Test_Batched_TeamGemm.hpp b/unit_test/batched/dense/Test_Batched_TeamGemm.hpp index 00fb2f4d49..d5aa853482 100644 --- a/unit_test/batched/dense/Test_Batched_TeamGemm.hpp +++ b/unit_test/batched/dense/Test_Batched_TeamGemm.hpp @@ -132,8 +132,10 @@ void impl_test_batched_teamgemm(const int N, const int matAdim1, mag_type sum(1), diff(0); mag_type eps = ats::epsilon(); - eps *= - std::is_same::value ? 4 : 1e3; + eps *= std::is_same::value || + std::is_same::value + ? 4 + : 1e3; for (int k = 0; k < N; ++k) for (int i = 0; i < matCdim1; ++i) diff --git a/unit_test/batched/dense/Test_Batched_TeamVectorGemm.hpp b/unit_test/batched/dense/Test_Batched_TeamVectorGemm.hpp index d104df2b06..8d10440bc2 100644 --- a/unit_test/batched/dense/Test_Batched_TeamVectorGemm.hpp +++ b/unit_test/batched/dense/Test_Batched_TeamVectorGemm.hpp @@ -131,8 +131,10 @@ void impl_test_batched_teamvectorgemm(const int N, const int matAdim1, mag_type eps = ats::epsilon(); - eps *= - std::is_same::value ? 4 : 1e3; + eps *= std::is_same::value || + std::is_same::value + ? 4 + : 1e3; for (int k = 0; k < N; ++k) for (int i = 0; i < matCdim1; ++i)