Skip to content

Commit

Permalink
Improve MatMulNBits test (#19378)
Browse files Browse the repository at this point in the history
### Description
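The test creates millions of threads because every call to QuantizeDequantize
builds its own intra-op thread pool. This change avoids that by reusing the
existing intra-op thread pool owned by the global ORT environment.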


### Motivation and Context
  • Loading branch information
snnn authored Feb 2, 2024
1 parent 8a2646c commit 13ad922
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,31 @@
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/default_providers.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/ort_env.h"
#include "core/util/qmath.h"

#include <chrono>
#include <random>

#include "gtest/gtest.h"
#include "gmock/gmock.h"
extern std::unique_ptr<Ort::Env> ort_env;

namespace onnxruntime {

namespace test {

static constexpr int QBits = 4;

void QuantizeDequantize(std::vector<float>& raw_vals,
std::vector<uint8_t>& quant_vals,
std::vector<float>& scales,
std::vector<uint8_t>* zp,
int32_t N,
int32_t K,
int32_t block_size) {
OrtThreadPoolParams to;
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
concurrency::ThreadPoolType::INTRA_OP);
auto& ortenv = **ort_env.get();
onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool();

MlasQuantizeBlockwise<float, 4>(
quant_vals.data(),
Expand All @@ -48,7 +50,7 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
K,
N,
N,
tp.get());
tp);

// Note that input1_f_vals is NxK after dequant
MlasDequantizeBlockwise<float, 4>(
Expand All @@ -60,7 +62,7 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
true, // columnwise quantization
K, // number of rows
N, // number of columns
tp.get());
tp);
}

void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,
Expand Down

0 comments on commit 13ad922

Please sign in to comment.