Improve MatMulNBits test (#19378)

### Description
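The test built a new intra-op thread pool on every call to `QuantizeDequantize`,
which added up to millions of threads being created. This change avoids that by
reusing the intra-op thread pool already owned by the global `ort_env` environment.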


### Motivation and Context
This commit is contained in:
Changming Sun 2024-02-01 16:18:14 -08:00 committed by GitHub
parent 8a2646ce60
commit 13ad922e7f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -14,6 +14,8 @@
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/default_providers.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/ort_env.h"
#include "core/util/qmath.h"
#include <chrono>
@ -21,12 +23,13 @@
#include "gtest/gtest.h"
#include "gmock/gmock.h"
extern std::unique_ptr<Ort::Env> ort_env;
namespace onnxruntime {
namespace test {
static constexpr int QBits = 4;
void QuantizeDequantize(std::vector<float>& raw_vals,
std::vector<uint8_t>& quant_vals,
std::vector<float>& scales,
@ -34,9 +37,8 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
int32_t N,
int32_t K,
int32_t block_size) {
OrtThreadPoolParams to;
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
concurrency::ThreadPoolType::INTRA_OP);
auto& ortenv = **ort_env.get();
onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool();
MlasQuantizeBlockwise<float, 4>(
quant_vals.data(),
@ -48,7 +50,7 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
K,
N,
N,
tp.get());
tp);
// Note that input1_f_vals is NxK after dequant
MlasDequantizeBlockwise<float, 4>(
@ -60,7 +62,7 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
true, // columnwise quantization
K, // number of rows
N, // number of columns
tp.get());
tp);
}
void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,