mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Improve MatMulNBits test (#19378)
### Description

The test creates millions of threads. This change avoids that by reusing an existing thread pool.

### Motivation and Context
This commit is contained in:
parent
8a2646ce60
commit
13ad922e7f
1 changed file with 8 additions and 6 deletions
|
|
@@ -14,6 +14,8 @@
|
|||
#include "test/optimizer/graph_transform_test_builder.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/session/ort_env.h"
|
||||
#include "core/util/qmath.h"
|
||||
|
||||
#include <chrono>
|
||||
|
|
@@ -21,12 +23,13 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
extern std::unique_ptr<Ort::Env> ort_env;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace test {
|
||||
|
||||
static constexpr int QBits = 4;
|
||||
|
||||
void QuantizeDequantize(std::vector<float>& raw_vals,
|
||||
std::vector<uint8_t>& quant_vals,
|
||||
std::vector<float>& scales,
|
||||
|
|
@@ -34,9 +37,8 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
|
|||
int32_t N,
|
||||
int32_t K,
|
||||
int32_t block_size) {
|
||||
OrtThreadPoolParams to;
|
||||
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
|
||||
concurrency::ThreadPoolType::INTRA_OP);
|
||||
auto& ortenv = **ort_env.get();
|
||||
onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool();
|
||||
|
||||
MlasQuantizeBlockwise<float, 4>(
|
||||
quant_vals.data(),
|
||||
|
|
@@ -48,7 +50,7 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
|
|||
K,
|
||||
N,
|
||||
N,
|
||||
tp.get());
|
||||
tp);
|
||||
|
||||
// Note that input1_f_vals is NxK after dequant
|
||||
MlasDequantizeBlockwise<float, 4>(
|
||||
|
|
@@ -60,7 +62,7 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
|
|||
true, // columnwise quantization
|
||||
K, // number of rows
|
||||
N, // number of columns
|
||||
tp.get());
|
||||
tp);
|
||||
}
|
||||
|
||||
void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,
|
||||
|
|
|
|||
Loading…
Reference in a new issue