// Mirror of https://github.com/saymrwulf/onnxruntime.git
// Synced 2026-05-16 21:00:14 +00:00.
// C++ source, 336 lines, 12 KiB.
#include "testPch.h"
|
|
|
|
#include "concurrencytests.h"
|
|
#include "model.h"
|
|
#include "SqueezeNetValidator.h"
|
|
#include "threadPool.h"
|
|
|
|
#include <chrono>
|
|
#include <cstdlib>
|
|
#include <ctime>
|
|
#include <fstream>
|
|
|
|
using namespace winml;
|
|
using namespace winrt;
|
|
|
|
namespace {
|
|
void LoadBindEvalSqueezenetRealDataWithValidationConcurrently() {
|
|
WINML_SKIP_TEST("Skipping due to bug 21617097");
|
|
|
|
constexpr auto load_test_model = [](const std::string& instance, LearningModelDeviceKind device) {
|
|
WinML::Engine::Test::ModelValidator::SqueezeNet(instance, device, 0.00001f, false);
|
|
};
|
|
|
|
std::vector<std::thread> threads;
|
|
for (const auto& instance : {"1", "2", "3", "4"}) {
|
|
threads.emplace_back(load_test_model, instance, LearningModelDeviceKind::Cpu);
|
|
}
|
|
if (SkipGpuTests()) {
|
|
} else {
|
|
for (const auto& instance : {"GPU_1", "GPU_2", "GPU_3", "GPU_4"}) {
|
|
threads.emplace_back(load_test_model, instance, LearningModelDeviceKind::DirectX);
|
|
}
|
|
}
|
|
|
|
for (auto& thread : threads) {
|
|
thread.join();
|
|
}
|
|
}
|
|
|
|
// Class-level setup for the concurrency tests: calls init_apartment(), seeds
// std::rand() from the current time and, when BUILD_INBOX is defined, points
// winrt_activation_handler at WINRT_RoGetActivationFactory.
void ConcurrencyTestsClassSetup() {
  init_apartment();

  const auto seed = static_cast<unsigned>(std::time(nullptr));
  std::srand(seed);
#ifdef BUILD_INBOX
  winrt_activation_handler = WINRT_RoGetActivationFactory;
#endif
}
|
|
|
|
struct EvaluationUnit {
|
|
LearningModel model;
|
|
LearningModelSession session;
|
|
LearningModelBinding binding;
|
|
wf::IAsyncOperation<LearningModelEvaluationResult> operation;
|
|
LearningModelEvaluationResult result;
|
|
|
|
EvaluationUnit() : model(nullptr), session(nullptr), binding(nullptr), result(nullptr) {}
|
|
};
|
|
|
|
// Run EvalAsync for each unit concurrently and get results
|
|
void RunAsync(std::vector<EvaluationUnit>& evaluation_units) {
|
|
std::for_each(evaluation_units.begin(), evaluation_units.end(), [](EvaluationUnit& unit) {
|
|
unit.operation = unit.session.EvaluateAsync(unit.binding, L"");
|
|
});
|
|
// get results
|
|
std::for_each(evaluation_units.begin(), evaluation_units.end(), [](EvaluationUnit& unit) {
|
|
unit.result = unit.operation.get();
|
|
});
|
|
}
|
|
|
|
void VerifyEvaluation(const std::vector<EvaluationUnit>& evaluation_units, std::vector<uint32_t> expected_indices) {
|
|
assert(evaluation_units.size() == expected_indices.size());
|
|
for (size_t i = 0; i < evaluation_units.size(); ++i) {
|
|
auto unit = evaluation_units[i];
|
|
auto expectedIndex = expected_indices[i];
|
|
auto result = unit.result.Outputs().Lookup(L"softmaxout_1").as<TensorFloat>().GetAsVectorView();
|
|
int64_t maxIndex = 0;
|
|
float maxValue = 0;
|
|
for (uint32_t j = 0; j < result.Size(); ++j) {
|
|
float val = result.GetAt(j);
|
|
if (val > maxValue) {
|
|
maxValue = val;
|
|
maxIndex = j;
|
|
}
|
|
}
|
|
WINML_EXPECT_TRUE(maxIndex == expectedIndex);
|
|
}
|
|
}
|
|
|
|
// Copies the file named `from` to the file named `to`. Both names are appended
// to FileHelpers::GetModulePath(), and both streams are opened in binary mode.
void CopyLocalFile(LPCWSTR from, LPCWSTR to) {
  const auto source_path = FileHelpers::GetModulePath() + from;
  std::ifstream source(source_path, std::ios::binary);

  const auto dest_path = FileHelpers::GetModulePath() + to;
  std::ofstream dest(dest_path, std::ios::binary);

  // Stream the whole source buffer into the destination.
  dest << source.rdbuf();
}
|
|
|
|
// Run evaluations with different models
|
|
void EvalAsyncDifferentModels() {
|
|
CopyLocalFile(L"model.onnx", L"model2.onnx");
|
|
|
|
std::vector<std::wstring> model_paths = {L"model.onnx", L"model2.onnx"};
|
|
const unsigned int num_units = static_cast<unsigned int>(model_paths.size());
|
|
std::vector<EvaluationUnit> evaluation_units(num_units, EvaluationUnit());
|
|
|
|
auto ifv = FileHelpers::LoadImageFeatureValue(L"kitten_224.png");
|
|
|
|
for (unsigned int i = 0; i < num_units; ++i) {
|
|
evaluation_units[i].model = LearningModel::LoadFromFilePath(FileHelpers::GetModulePath() + model_paths[i]);
|
|
evaluation_units[i].session = LearningModelSession(evaluation_units[i].model);
|
|
evaluation_units[i].binding = LearningModelBinding(evaluation_units[i].session);
|
|
evaluation_units[i].binding.Bind(L"data_0", ifv);
|
|
}
|
|
|
|
RunAsync(evaluation_units);
|
|
std::vector<uint32_t> indices(num_units, TABBY_CAT_INDEX);
|
|
VerifyEvaluation(evaluation_units, indices);
|
|
}
|
|
|
|
// Run evaluations with same model, different sessions
|
|
void EvalAsyncDifferentSessions() {
|
|
unsigned int num_units = 3;
|
|
std::vector<EvaluationUnit> evaluation_units(num_units, EvaluationUnit());
|
|
auto ifv = FileHelpers::LoadImageFeatureValue(L"kitten_224.png");
|
|
|
|
// same model, different session
|
|
auto model = LearningModel::LoadFromFilePath(FileHelpers::GetModulePath() + L"model.onnx");
|
|
for (unsigned int i = 0; i < num_units; ++i) {
|
|
evaluation_units[i].model = model;
|
|
evaluation_units[i].session = LearningModelSession(evaluation_units[i].model);
|
|
evaluation_units[i].binding = LearningModelBinding(evaluation_units[i].session);
|
|
evaluation_units[i].binding.Bind(L"data_0", ifv);
|
|
}
|
|
|
|
RunAsync(evaluation_units);
|
|
std::vector<uint32_t> indices(num_units, TABBY_CAT_INDEX);
|
|
VerifyEvaluation(evaluation_units, indices);
|
|
}
|
|
|
|
// Run evaluations with same session (and model), with different bindings
|
|
void EvalAsyncDifferentBindings() {
|
|
unsigned int num_units = 2;
|
|
std::vector<EvaluationUnit> evaluation_units(num_units, EvaluationUnit());
|
|
|
|
std::vector<ImageFeatureValue> ifvs = {
|
|
FileHelpers::LoadImageFeatureValue(L"kitten_224.png"), FileHelpers::LoadImageFeatureValue(L"fish.png")
|
|
};
|
|
|
|
// same session, different binding
|
|
auto model = LearningModel::LoadFromFilePath(FileHelpers::GetModulePath() + L"model.onnx");
|
|
auto session = LearningModelSession(model);
|
|
for (unsigned int i = 0; i < num_units; ++i) {
|
|
evaluation_units[i].model = model;
|
|
evaluation_units[i].session = session;
|
|
evaluation_units[i].binding = LearningModelBinding(evaluation_units[i].session);
|
|
evaluation_units[i].binding.Bind(L"data_0", ifvs[i]);
|
|
}
|
|
|
|
RunAsync(evaluation_units);
|
|
VerifyEvaluation(evaluation_units, {TABBY_CAT_INDEX, TENCH_INDEX});
|
|
}
|
|
|
|
// Declaration only: no definition or call is visible in this portion of the file.
// NOTE(review): presumably kept so the onnxruntime/onnx types stay referenced
// from this translation unit — confirm before removing.
winml::ILearningModelFeatureDescriptor UnusedCreateFeatureDescriptor(
  std::shared_ptr<onnxruntime::Model> model, const std::wstring& name, const std::wstring& description,
  bool is_required, const ::onnx::TypeProto* type_proto
);
|
|
|
|
// Returns a pseudo-random number in the closed interval [1, max_number], drawn
// from std::rand().
// The interval is empty when max_number is 0; in that case the lower bound 1 is
// returned instead of evaluating `% 0`, which is undefined behavior.
unsigned int GetRandomNumber(unsigned int max_number) {
  if (max_number == 0) {
    return 1;
  }
  return static_cast<unsigned int>(std::rand()) % max_number + 1;
}
|
|
|
|
void MultiThreadLoadModel() {
|
|
// load same model
|
|
auto path = FileHelpers::GetModulePath() + L"model.onnx";
|
|
ThreadPool pool(NUM_THREADS);
|
|
try {
|
|
for (unsigned int i = 0; i < NUM_THREADS; ++i) {
|
|
pool.SubmitWork([&path]() {
|
|
auto model = LearningModel::LoadFromFilePath(path);
|
|
std::wstring name(model.Name());
|
|
WINML_EXPECT_EQUAL(name, L"squeezenet_old");
|
|
});
|
|
}
|
|
} catch (...) {
|
|
WINML_LOG_ERROR("Failed to load model concurrently.");
|
|
}
|
|
}
|
|
|
|
void MultiThreadMultiSessionOnDevice(const LearningModelDevice& device) {
|
|
auto path = FileHelpers::GetModulePath() + L"model.onnx";
|
|
auto model = LearningModel::LoadFromFilePath(path);
|
|
std::vector<ImageFeatureValue> ivfs = {
|
|
FileHelpers::LoadImageFeatureValue(L"kitten_224.png"), FileHelpers::LoadImageFeatureValue(L"fish.png")
|
|
};
|
|
std::vector<int> max_indices = {
|
|
281, // tabby, tabby cat
|
|
0 // tench, Tinca tinca
|
|
};
|
|
std::vector<float> max_values = {0.9314f, 0.7385f};
|
|
float tolerance = 0.001f;
|
|
std::vector<LearningModelSession> modelSessions(NUM_THREADS, nullptr);
|
|
ThreadPool pool(NUM_THREADS);
|
|
try {
|
|
device.as<IMetacommandsController>()->SetMetacommandsEnabled(false);
|
|
// create all the sessions
|
|
for (unsigned i = 0; i < NUM_THREADS; ++i) {
|
|
modelSessions[i] = LearningModelSession(model, device);
|
|
}
|
|
// start all the threads
|
|
for (unsigned i_thread = 0; i_thread < NUM_THREADS; ++i_thread) {
|
|
LearningModelSession& model_session = modelSessions[i_thread];
|
|
pool.SubmitWork([&model_session, &ivfs, &max_indices, &max_values, tolerance, i_thread]() {
|
|
ULONGLONG start_time = GetTickCount64();
|
|
while (((GetTickCount64() - start_time) / 1000) < NUM_SECONDS) {
|
|
auto j = i_thread % ivfs.size();
|
|
auto input = ivfs[j];
|
|
auto expected_index = max_indices[j];
|
|
auto expected_value = max_values[j];
|
|
LearningModelBinding bind(model_session);
|
|
bind.Bind(L"data_0", input);
|
|
auto result = model_session.Evaluate(bind, L"").Outputs();
|
|
auto softmax = result.Lookup(L"softmaxout_1");
|
|
if (auto tensor = softmax.try_as<ITensorFloat>()) {
|
|
auto view = tensor.GetAsVectorView();
|
|
float max_val = .0f;
|
|
int max_index = -1;
|
|
for (uint32_t i = 0; i < view.Size(); ++i) {
|
|
auto val = view.GetAt(i);
|
|
if (val > max_val) {
|
|
max_index = i;
|
|
max_val = val;
|
|
}
|
|
}
|
|
WINML_EXPECT_EQUAL(expected_index, max_index);
|
|
WINML_EXPECT_TRUE(std::abs(expected_value - max_val) < tolerance);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
} catch (...) {
|
|
WINML_LOG_ERROR("Failed to create session concurrently.");
|
|
}
|
|
}
|
|
|
|
void MultiThreadMultiSession() {
|
|
MultiThreadMultiSessionOnDevice(LearningModelDevice(LearningModelDeviceKind::Cpu));
|
|
}
|
|
|
|
void MultiThreadMultiSessionGpu() {
|
|
MultiThreadMultiSessionOnDevice(LearningModelDevice(LearningModelDeviceKind::DirectX));
|
|
}
|
|
|
|
// Create different sessions for each thread, and evaluate
|
|
void MultiThreadSingleSessionOnDevice(const LearningModelDevice& device) {
|
|
auto path = FileHelpers::GetModulePath() + L"model.onnx";
|
|
auto model = LearningModel::LoadFromFilePath(path);
|
|
LearningModelSession model_session = nullptr;
|
|
WINML_EXPECT_NO_THROW(model_session = LearningModelSession(model, device));
|
|
std::vector<ImageFeatureValue> ivfs = {
|
|
FileHelpers::LoadImageFeatureValue(L"kitten_224.png"), FileHelpers::LoadImageFeatureValue(L"fish.png")
|
|
};
|
|
std::vector<int> max_indices = {
|
|
281, // tabby, tabby cat
|
|
0 // tench, Tinca tinca
|
|
};
|
|
std::vector<float> max_values = {0.9314f, 0.7385f};
|
|
float tolerance = 0.001f;
|
|
|
|
ThreadPool pool(NUM_THREADS);
|
|
try {
|
|
for (unsigned i = 0; i < NUM_THREADS; ++i) {
|
|
pool.SubmitWork([&model_session, &ivfs, &max_indices, &max_values, tolerance, i]() {
|
|
ULONGLONG start_time = GetTickCount64();
|
|
while (((GetTickCount64() - start_time) / 1000) < NUM_SECONDS) {
|
|
auto j = i % ivfs.size();
|
|
auto input = ivfs[j];
|
|
auto expected_index = max_indices[j];
|
|
auto expected_value = max_values[j];
|
|
std::wstring name(model_session.Model().Name());
|
|
LearningModelBinding bind(model_session);
|
|
bind.Bind(L"data_0", input);
|
|
auto result = model_session.Evaluate(bind, L"").Outputs();
|
|
auto softmax = result.Lookup(L"softmaxout_1");
|
|
if (auto tensor = softmax.try_as<ITensorFloat>()) {
|
|
auto view = tensor.GetAsVectorView();
|
|
float max_val = .0f;
|
|
int max_index = -1;
|
|
for (uint32_t k = 0; k < view.Size(); ++k) {
|
|
auto val = view.GetAt(k);
|
|
if (val > max_val) {
|
|
max_index = k;
|
|
max_val = val;
|
|
}
|
|
}
|
|
WINML_EXPECT_EQUAL(expected_index, max_index);
|
|
WINML_EXPECT_TRUE(std::abs(expected_value - max_val) < tolerance);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
} catch (...) {
|
|
WINML_LOG_ERROR("Failed to create session concurrently.");
|
|
}
|
|
}
|
|
|
|
void MultiThreadSingleSession() {
|
|
MultiThreadSingleSessionOnDevice(LearningModelDevice(LearningModelDeviceKind::Cpu));
|
|
}
|
|
|
|
void MultiThreadSingleSessionGpu() {
|
|
MultiThreadSingleSessionOnDevice(LearningModelDevice(LearningModelDeviceKind::DirectX));
|
|
}
|
|
} // namespace
|
|
|
|
// Returns the table of concurrency test entry points. The table is a function-
// local static initialized on first use. On every call, if SkipGpuTests()
// reports true, the two GPU entries are replaced with SkipTest.
const ConcurrencyTestsApi& getapi() {
  static ConcurrencyTestsApi api = {
    ConcurrencyTestsClassSetup,
    LoadBindEvalSqueezenetRealDataWithValidationConcurrently,
    MultiThreadLoadModel,
    MultiThreadMultiSession,
    MultiThreadMultiSessionGpu,
    MultiThreadSingleSession,
    MultiThreadSingleSessionGpu,
    EvalAsyncDifferentModels,
    EvalAsyncDifferentSessions,
    EvalAsyncDifferentBindings
  };

  const bool skip_gpu = SkipGpuTests();
  if (skip_gpu) {
    api.MultiThreadMultiSessionGpu = SkipTest;
    api.MultiThreadSingleSessionGpu = SkipTest;
  }

  return api;
}
|