From 13ad922e7ffd4eeeff5ca7199c4d5e7bf9703849 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 1 Feb 2024 16:18:14 -0800 Subject: [PATCH] Improve MatMulNBits test (#19378) ### Description The test creates millions of threads. This change is to avoid that by using an existing thread pool. ### Motivation and Context --- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index e0ed32630277e..2ad20eafc2ef1 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -14,6 +14,8 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/ort_env.h" #include "core/util/qmath.h" #include @@ -21,12 +23,13 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" +extern std::unique_ptr ort_env; namespace onnxruntime { + namespace test { static constexpr int QBits = 4; - void QuantizeDequantize(std::vector& raw_vals, std::vector& quant_vals, std::vector& scales, @@ -34,9 +37,8 @@ void QuantizeDequantize(std::vector& raw_vals, int32_t N, int32_t K, int32_t block_size) { - OrtThreadPoolParams to; - auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, - concurrency::ThreadPoolType::INTRA_OP); + auto& ortenv = **ort_env.get(); + onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); MlasQuantizeBlockwise( quant_vals.data(), @@ -48,7 +50,7 @@ void QuantizeDequantize(std::vector& raw_vals, K, N, N, - tp.get()); + tp); // Note that input1_f_vals is NxK after dequant MlasDequantizeBlockwise( @@ -60,7 +62,7 @@ void QuantizeDequantize(std::vector& raw_vals, true, // columnwise quantization K, // number of rows N, // number of columns - tp.get()); + tp); } void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,