mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Add mlas_bench tools. Starting with sconv bench and sgemm bench. (#7139)
* Add mlas_bench tools, starting with the sconv and sgemm benchmarks. * Some build-related updates.
This commit is contained in:
parent
56b22c1c6b
commit
dbcfc4bee6
6 changed files with 462 additions and 0 deletions
|
|
@ -768,6 +768,16 @@ if(onnxruntime_BUILD_BENCHMARKS)
|
|||
target_link_libraries(onnxruntime_benchmark PRIVATE onnx_test_runner_common benchmark::benchmark ${onnx_test_libs})
|
||||
add_dependencies(onnxruntime_benchmark ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
set_target_properties(onnxruntime_benchmark PROPERTIES FOLDER "ONNXRuntimeTest")
|
||||
|
||||
# Stand-alone MLAS micro-benchmark executable, built from every .cpp/.h file
# found under the mlas/bench test directory.
set(MLAS_BENCH_DIR ${TEST_SRC_DIR}/mlas/bench)
file(GLOB_RECURSE MLAS_BENCH_SOURCE_FILES
  "${MLAS_BENCH_DIR}/*.cpp"
  "${MLAS_BENCH_DIR}/*.h")
onnxruntime_add_executable(onnxruntime_mlas_benchmark ${MLAS_BENCH_SOURCE_FILES})
target_include_directories(onnxruntime_mlas_benchmark PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc)
target_link_libraries(onnxruntime_mlas_benchmark PRIVATE benchmark::benchmark onnxruntime_mlas onnxruntime_common)
# nsync is only linked on non-Windows platforms.
if(NOT WIN32)
  target_link_libraries(onnxruntime_mlas_benchmark PRIVATE nsync_cpp)
endif()
set_target_properties(onnxruntime_mlas_benchmark PROPERTIES FOLDER "ONNXRuntimeTest")
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
|
|
|
|||
6
onnxruntime/test/mlas/bench/bench_main.cpp
Normal file
6
onnxruntime/test/mlas/bench/bench_main.cpp
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <benchmark/benchmark.h>
|
||||
|
||||
BENCHMARK_MAIN();
|
||||
247
onnxruntime/test/mlas/bench/bench_sconv.cpp
Normal file
247
onnxruntime/test/mlas/bench/bench_sconv.cpp
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "mlas.h"
|
||||
#include "bench_util.h"
|
||||
|
||||
#include <stdexcept>
|
||||
#include <numeric>
|
||||
|
||||
// Builds the benchmark argument names for a convolution of the given rank.
//
// The layout is: "Rank", "N", "G", "Cpg", "Fpg", followed by `rank` slots for the
// input shape, `rank` for the kernel shape, `2 * rank` for the paddings, `rank` for
// the strides and `rank` for the dilations. Only the first slot of each group gets a
// label ("I", "K", "P", "S", "D"); the remaining slots are left as empty strings.
//
// A rank of 0 has no per-axis slots, so only the five fixed names are returned
// (previously this wrote past the end of the vector).
static std::vector<std::string> BuildArgNamesForConv(size_t rank) {
  std::vector<std::string> names = {"Rank", "N", "G", "Cpg", "Fpg"};
  if (rank == 0) {
    return names;
  }

  size_t arg_position = names.size();
  names.resize(arg_position + rank * 6, std::string(""));

  names[arg_position] = "I";  // input shape
  arg_position += rank;

  names[arg_position] = "K";  // kernel shape
  arg_position += rank;

  names[arg_position] = "P";  // paddings occupy 2 * rank slots
  arg_position += rank * 2;

  names[arg_position] = "S";  // strides
  arg_position += rank;

  names[arg_position] = "D";  // dilations

  return names;
}
|
||||
|
||||
static const std::vector<std::string>& ArgNamesForConv(size_t rank) {
|
||||
static std::map<size_t, std::vector<std::string>> rank_to_args_name;
|
||||
if (rank_to_args_name.find(rank) == rank_to_args_name.end()) {
|
||||
rank_to_args_name.emplace(std::make_pair(rank, BuildArgNamesForConv(rank)));
|
||||
}
|
||||
return rank_to_args_name[rank];
|
||||
}
|
||||
|
||||
// dummy for some strange build error when using Bench capture
|
||||
void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) {
|
||||
const int64_t rank = state.range(0); // Rank
|
||||
const int64_t batch_size = state.range(1); // N
|
||||
const int64_t groups = state.range(2); // G
|
||||
const int64_t input_channels_per_group = state.range(3); // Cpg
|
||||
const int64_t output_channels_per_group = state.range(4); // Fpg
|
||||
|
||||
if (rank <= 0) throw std::invalid_argument("Kernel rank must greater than 0!");
|
||||
if (batch_size <= 0) throw std::invalid_argument("Batch size must greater than 0!");
|
||||
if (groups <= 0) throw std::invalid_argument("Group count must greater than 0!");
|
||||
if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must greater than 0!");
|
||||
if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must greater than 0!");
|
||||
|
||||
size_t arg_position = 5;
|
||||
const auto input_shape = BenchArgsVector(state, arg_position, rank);
|
||||
const auto kernel_shape = BenchArgsVector(state, arg_position, rank);
|
||||
const auto paddings = BenchArgsVector(state, arg_position, rank * 2);
|
||||
const auto strides = BenchArgsVector(state, arg_position, rank);
|
||||
const auto dilations = BenchArgsVector(state, arg_position, rank);
|
||||
|
||||
// do not check the size of each vector as they are forced from args.
|
||||
if (std::any_of(input_shape.begin(), input_shape.end(), [](const int64_t& dim) { return dim <= 0; })) {
|
||||
throw std::invalid_argument("all input image dim must > 0");
|
||||
}
|
||||
|
||||
if (std::any_of(kernel_shape.begin(), kernel_shape.end(), [](const int64_t& dim) { return dim <= 0; })) {
|
||||
throw std::invalid_argument("all kernel dim must > 0");
|
||||
}
|
||||
|
||||
if (std::any_of(strides.begin(), strides.end(), [](const int64_t& dim) { return dim <= 0; })) {
|
||||
throw std::invalid_argument("all strides dim must > 0");
|
||||
}
|
||||
|
||||
if (std::any_of(dilations.begin(), dilations.end(), [](const int64_t& dim) { return dim <= 0; })) {
|
||||
throw std::invalid_argument("all dilations dim must > 0");
|
||||
}
|
||||
|
||||
const int64_t GC = groups * input_channels_per_group;
|
||||
const int64_t GF = groups * output_channels_per_group;
|
||||
std::vector<int64_t> x_shape = {batch_size, GC};
|
||||
x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end());
|
||||
std::vector<int64_t> f_shape = {GF, input_channels_per_group};
|
||||
f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end());
|
||||
|
||||
std::vector<int64_t> output_shape((size_t)rank);
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
auto km = 1 + dilations[i] * (kernel_shape[i] - 1);
|
||||
output_shape[i] = (paddings[i] + paddings[i + rank] + input_shape[i] - km) / strides[i] + 1;
|
||||
}
|
||||
std::vector<int64_t> y_shape = {batch_size, GF};
|
||||
y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end());
|
||||
|
||||
MLAS_ACTIVATION activation;
|
||||
activation.ActivationKind = MlasIdentityActivation;
|
||||
MLAS_CONV_PARAMETERS Parameters;
|
||||
size_t WorkingBufferSize = 0;
|
||||
MlasConvPrepare(&Parameters,
|
||||
static_cast<size_t>(rank),
|
||||
static_cast<size_t>(batch_size),
|
||||
static_cast<size_t>(groups),
|
||||
static_cast<size_t>(input_channels_per_group),
|
||||
input_shape.data(),
|
||||
kernel_shape.data(),
|
||||
dilations.data(),
|
||||
paddings.data(),
|
||||
strides.data(),
|
||||
output_shape.data(),
|
||||
static_cast<size_t>(output_channels_per_group),
|
||||
&activation,
|
||||
&WorkingBufferSize,
|
||||
nullptr);
|
||||
|
||||
auto X = RandomVectorUniform(x_shape, -2.0, 2.0);
|
||||
auto F = RandomVectorUniform(f_shape, -1.0, 1.0);
|
||||
int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies<int64_t>());
|
||||
std::vector<float> Y(static_cast<size_t>(y_size));
|
||||
std::vector<float> working_buffer(WorkingBufferSize);
|
||||
|
||||
// warm up first round.
|
||||
MlasConv(&Parameters,
|
||||
X.data(),
|
||||
F.data(),
|
||||
nullptr,
|
||||
working_buffer.data(),
|
||||
Y.data(),
|
||||
nullptr);
|
||||
|
||||
for (auto _ : state) {
|
||||
MlasConv(&Parameters,
|
||||
X.data(),
|
||||
F.data(),
|
||||
nullptr,
|
||||
working_buffer.data(),
|
||||
Y.data(),
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Registers the 2-D convolution configurations of ResNet50 on the benchmark.
// Each row follows the argument layout expected by SCONV_NCHW:
//   Rank, N, G, Cpg, Fpg, I[2], K[2], P[4], S[2], D[2]
// Configurations that would repeat an earlier row exactly are left out.
static void ResNet50(benchmark::internal::Benchmark* b) {
  const std::vector<std::vector<int64_t>> conv_configs = {
      // Conv 1
      {2, 1, 1, 3, 64, 224, 224, 7, 7, 3, 3, 3, 3, 2, 2, 1, 1},

      // Conv 2.1
      {2, 1, 1, 64, 64, 56, 56, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1},
      {2, 1, 1, 64, 256, 56, 56, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},

      // Conv 2.X
      {2, 1, 1, 256, 64, 56, 56, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},

      // Conv 3.1
      {2, 1, 1, 256, 128, 56, 56, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1},
      {2, 1, 1, 128, 512, 28, 28, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 256, 512, 56, 56, 1, 1, 0, 0, 0, 0, 2, 2, 1, 1},

      // Conv 3.X
      {2, 1, 1, 512, 128, 28, 28, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1},

      // Conv 4.1
      {2, 1, 1, 512, 256, 28, 28, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1},
      {2, 1, 1, 256, 1024, 14, 14, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 512, 1024, 28, 28, 1, 1, 0, 0, 0, 0, 2, 2, 1, 1},

      // Conv 4.X
      {2, 1, 1, 1024, 256, 14, 14, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1},

      // Conv 5.1
      {2, 1, 1, 1024, 512, 14, 14, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1},
      {2, 1, 1, 512, 2048, 7, 7, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 1024, 2048, 14, 14, 1, 1, 0, 0, 0, 0, 2, 2, 1, 1},

      // Conv 5.X
      {2, 1, 1, 2048, 512, 7, 7, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
      {2, 1, 1, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1},
  };

  b->ArgNames(ArgNamesForConv(2));
  for (const auto& config : conv_configs) {
    b->Args(config);
  }
}
|
||||
|
||||
BENCHMARK_CAPTURE(SCONV_NCHW, ResNet50, "")->Apply(ResNet50)->UseRealTime();
|
||||
|
||||
// Registers the 2-D convolution configurations labelled with node names of the
// "Teams" model. Each row follows the argument layout expected by SCONV_NCHW:
//   Rank, N, G, Cpg, Fpg, I[2], K[2], P[4], S[2], D[2]
// The trailing comment on each row is the original author's label and output size.
static void TeamsModel(benchmark::internal::Benchmark* b) {
  const std::vector<std::vector<int64_t>> conv_configs = {
      {2, 1, 1, 40, 24, 24, 40, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1},   // fused conv_349 => 24x40
      {2, 1, 1, 24, 24, 24, 40, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1},   // fused Conv_367 => 24x40
      {2, 1, 1, 4, 24, 96, 160, 3, 3, 0, 0, 1, 1, 2, 2, 1, 1},   // fused Conv_15 => 48x80
      {2, 1, 1, 12, 72, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},   // fused Conv_38 => 48x80
      {2, 1, 1, 12, 8, 48, 80, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1},    // fused Conv_395 => 48x80
      {2, 1, 24, 1, 1, 48, 80, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1},    // fused Conv_33 => 48x80
      {2, 1, 1, 8, 8, 48, 80, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1},     // fused Conv_413 => 48x80
      {2, 1, 72, 1, 1, 48, 80, 3, 3, 0, 0, 1, 1, 2, 2, 1, 1},    // fused Conv_56 => 24x40
      // NOTE(review): with a 24x40 input, 3x3 kernel, padding 1 and stride 2 the
      // output works out to 12x20, not the 24x40 stated here - confirm the intent.
      {2, 1, 72, 1, 1, 24, 40, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1},    // fused Conv_79 => 24x40
      {2, 1, 1, 24, 12, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},   // Conv_36 => 48x80
      {2, 1, 1, 12, 72, 24, 40, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},   // fused Conv_61/85 => 24x40
      {2, 1, 1, 24, 144, 12, 20, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},  // fused Conv_108/132 => 12x20

      {2, 1, 1, 12, 12, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},   // fused Conv_376 => 48x80
      // NOTE(review): this row is identical to the Conv_38 row above, and a 48x80
      // input with stride 1 does not give the 24x40 stated here - confirm the intent.
      {2, 1, 1, 12, 72, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},   // Conv_59 => 24x40
  };

  b->ArgNames(ArgNamesForConv(2));
  for (const auto& config : conv_configs) {
    b->Args(config);
  }
}
|
||||
|
||||
BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
|
||||
|
||||
// Registers a small sweep of generic 2-D convolutions: every combination of the
// candidate values below (one list per benchmark argument) becomes one case.
static void General_Conv2d(benchmark::internal::Benchmark* b) {
  const std::vector<std::vector<int64_t>> candidates = {
      {2},       // Rank
      {1},       // N
      {1, 2},    // Groups
      {3, 12},   // Cpg
      {6},       // Fpg
      {24, 72},  // input image shape, first axis
      {36},      // input image shape, second axis
      {3},       // kernel shape, first axis
      {3},       // kernel shape, second axis
      {0},       // paddings (four entries)
      {0},
      {0},
      {0},
      {1},       // strides, first axis
      {1},       // strides, second axis
      {1},       // dilations, first axis
      {1},       // dilations, second axis
  };

  b->ArgNames(ArgNamesForConv(2));
  ArgsProduct(b, candidates);
}
|
||||
|
||||
BENCHMARK_CAPTURE(SCONV_NCHW, 2d, "")->Apply(General_Conv2d)->UseRealTime();
|
||||
120
onnxruntime/test/mlas/bench/bench_sgemm.cpp
Normal file
120
onnxruntime/test/mlas/bench/bench_sgemm.cpp
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "mlas.h"
|
||||
#include "bench_util.h"
|
||||
|
||||
#include <stdexcept>
|
||||
#include <numeric>
|
||||
|
||||
static const std::vector<std::string> sgemm_bench_arg_names = {"M", "N", "K"};
|
||||
|
||||
// Benchmarks MlasGemm (single precision) for C = alpha * op(A) * op(B) + beta * C.
//
// Benchmark arguments: M, N, K (all must be positive, otherwise
// std::invalid_argument is thrown).
//
// pack_b   - when true, B is packed once with MlasGemmPackB and the packed-B
//            overload of MlasGemm is timed; otherwise the plain overload is timed.
// trans_a  - treat A as transposed (leading dimension M instead of K).
// trans_b  - treat B as transposed (leading dimension K instead of N). This is now
//            honoured on the packing path as well; previously packing always
//            assumed a non-transposed B regardless of this flag.
// alpha, beta - GEMM scaling factors.
//
// One untimed warm-up call is made before the measured loop.
void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, float alpha = 1.0f, float beta = 0.0f) {
  const int64_t M = state.range(0);
  const int64_t N = state.range(1);
  const int64_t K = state.range(2);

  if (M <= 0) throw std::invalid_argument("M must greater than 0!");
  if (N <= 0) throw std::invalid_argument("N must greater than 0!");
  if (K <= 0) throw std::invalid_argument("K must greater than 0!");

  // Convert once, after validation, instead of relying on implicit
  // signed-to-unsigned conversions at every call site.
  const size_t m = static_cast<size_t>(M);
  const size_t n = static_cast<size_t>(N);
  const size_t k = static_cast<size_t>(K);

  const auto op_a = trans_a ? CblasTrans : CblasNoTrans;
  const auto op_b = trans_b ? CblasTrans : CblasNoTrans;
  const size_t lda = trans_a ? m : k;
  const size_t ldb = trans_b ? k : n;
  const size_t ldc = n;

  auto A = RandomVectorUniform(m * k, -1.0f, 1.0f);
  auto B = RandomVectorUniform(n * k, -1.0f, 1.0f);
  std::vector<float> C(m * n);

  if (pack_b) {
    // NOTE(review): the buffer holds `pack_b_size` floats, as in the original; if
    // MlasGemmPackBSize reports bytes this over-allocates, which is harmless.
    const size_t pack_b_size = MlasGemmPackBSize(n, k);
    std::vector<float> B_packed(pack_b_size);
    MlasGemmPackB(op_b, n, k, B.data(), ldb, B_packed.data());

    const auto run_packed = [&]() {
      MlasGemm(op_a, m, n, k, alpha, A.data(), lda, B_packed.data(), beta, C.data(), ldc, nullptr);
    };

    run_packed();  // warm up
    for (auto _ : state) {
      run_packed();
    }
  } else {
    const auto run_plain = [&]() {
      MlasGemm(op_a, op_b, m, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc, nullptr);
    };

    run_plain();  // warm up
    for (auto _ : state) {
      run_plain();
    }
  }
}
|
||||
|
||||
// Registers GEMM shapes in which exactly one of M, N, K is fixed to 1 while the
// other two sweep over the common size list.
static void GemmSizeWithOne(benchmark::internal::Benchmark* b) {
  const std::vector<int64_t> one = {1};
  const std::vector<int64_t> sizes = {63, 255, 1023};

  b->ArgNames(sgemm_bench_arg_names);
  ArgsProduct(b, {one, sizes, sizes});  // M == 1
  ArgsProduct(b, {sizes, one, sizes});  // N == 1
  ArgsProduct(b, {sizes, sizes, one});  // K == 1
}
|
||||
|
||||
// Registers every (M, N, K) combination drawn from the common size list.
static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
  const std::vector<int64_t> sizes = {63, 255, 1023};

  b->ArgNames(sgemm_bench_arg_names);
  ArgsProduct(b, {sizes, sizes, sizes});
}
|
||||
|
||||
BENCHMARK_CAPTURE(SGEMM, NORMAL_NoTrans, false, false, false)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(SGEMM, NORMAL_TransA, false, true, false)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(SGEMM, NORMAL_TransB, false, false, true)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(SGEMM, NORMAL_ABTrans, false, true, true)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
|
||||
BENCHMARK_CAPTURE(SGEMM, GEMV_NoTrans, false, false, false)->Apply(GemmSizeWithOne)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(SGEMM, GEMV_TransA, false, true, false)->Apply(GemmSizeWithOne)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(SGEMM, GEMV_TransB, false, false, true)->Apply(GemmSizeWithOne)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(SGEMM, GEMV_ABTrans, false, true, true)->Apply(GemmSizeWithOne)->UseRealTime();
|
||||
|
||||
BENCHMARK_CAPTURE(SGEMM, PACKB_NoTransA, true, false, false)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(SGEMM, PACKB_TransA, true, true, false)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
64
onnxruntime/test/mlas/bench/bench_util.cpp
Normal file
64
onnxruntime/test/mlas/bench/bench_util.cpp
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "bench_util.h"
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
|
||||
// Copies `count` consecutive benchmark arguments, beginning at index `start`, into
// a vector and advances `start` past them so that successive calls read
// back-to-back argument groups.
std::vector<int64_t> BenchArgsVector(benchmark::State& state, size_t& start, size_t count) {
  const size_t first = start;

  std::vector<int64_t> values;
  values.reserve(count);
  for (size_t offset = 0; offset < count; ++offset) {
    values.push_back(state.range(first + offset));
  }

  start = first + count;
  return values;
}
|
||||
|
||||
// Returns N floats drawn uniformly between min_value and max_value.
//
// The generator is seeded with N, so the same arguments always produce the same
// sequence. When the range is empty or inverted (min_value >= max_value) every
// element is simply min_value.
std::vector<float> RandomVectorUniform(size_t N, float min_value, float max_value) {
  if (max_value <= min_value) {
    return std::vector<float>(N, min_value);
  }

  std::vector<float> values(N);
  std::default_random_engine engine(static_cast<unsigned>(N));
  std::uniform_real_distribution<float> uniform(min_value, max_value);
  for (float& value : values) {
    value = uniform(engine);
  }
  return values;
}
|
||||
|
||||
std::vector<float> RandomVectorUniform(std::vector<int64_t> shape, float min_value, float max_value) {
|
||||
int64_t sz = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
|
||||
if (sz <= 0) {
|
||||
throw std::invalid_argument("shape gives size must greater than 0!");
|
||||
}
|
||||
return RandomVectorUniform(static_cast<size_t>(sz), min_value, max_value);
|
||||
}
|
||||
|
||||
// The Benchmark used here do not contains this as in newer version.
|
||||
// Use the code from newer version.
|
||||
void ArgsProduct(benchmark::internal::Benchmark* bench,
|
||||
const std::vector<std::vector<int64_t>>& arglists) {
|
||||
std::vector<std::size_t> indices(arglists.size(), 0);
|
||||
const std::size_t total = std::accumulate(
|
||||
std::begin(arglists), std::end(arglists), std::size_t{1},
|
||||
[](const std::size_t res, const std::vector<int64_t>& arglist) {
|
||||
return res * arglist.size();
|
||||
});
|
||||
std::vector<int64_t> args;
|
||||
args.reserve(arglists.size());
|
||||
for (std::size_t i = 0; i < total; i++) {
|
||||
for (std::size_t arg = 0; arg < arglists.size(); arg++) {
|
||||
args.push_back(arglists[arg][indices[arg]]);
|
||||
}
|
||||
bench->Args(args);
|
||||
args.clear();
|
||||
|
||||
std::size_t arg = 0;
|
||||
do {
|
||||
indices[arg] = (indices[arg] + 1) % arglists[arg].size();
|
||||
} while (indices[arg++] == 0 && arg < arglists.size());
|
||||
}
|
||||
}
|
||||
15
onnxruntime/test/mlas/bench/bench_util.h
Normal file
15
onnxruntime/test/mlas/bench/bench_util.h
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <benchmark/benchmark.h>
|
||||
|
||||
void ArgsProduct(benchmark::internal::Benchmark* bench,
|
||||
const std::vector<std::vector<int64_t>>& arglists);
|
||||
|
||||
std::vector<float> RandomVectorUniform(size_t N, float min_value, float max_value);
|
||||
|
||||
std::vector<float> RandomVectorUniform(std::vector<int64_t> shape, float min_value, float max_value);
|
||||
|
||||
std::vector<int64_t> BenchArgsVector(benchmark::State& state, size_t& start, size_t count);
|
||||
Loading…
Reference in a new issue