mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Preserve relative order of the results and the tests. (#5225)
This commit is contained in:
parent
b49f6a5e2c
commit
8ee4e8226e
4 changed files with 22 additions and 19 deletions
|
|
@ -17,25 +17,23 @@ TestCaseDriver::TestCaseDriver(const TestEnv& env, size_t concurrent_runs)
|
|||
tests_started_(0),
|
||||
tests_inprogress_(0),
|
||||
finished_(false) {
|
||||
results_.reserve(env.GetTests().size());
|
||||
CallableFactory<TestCaseDriver, void, std::shared_ptr<TestCaseResult>> f(this);
|
||||
results_.resize(env.GetTests().size());
|
||||
CallableFactory<TestCaseDriver, void, size_t, std::shared_ptr<TestCaseResult>> f(this);
|
||||
on_test_case_complete_ = f.GetCallable<&TestCaseDriver::OnTestCaseComplete>();
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::shared_ptr<TestCaseResult>> TestCaseDriver::Run(const TestEnv& env, size_t concurrent_runs, size_t repeat_count) {
|
||||
std::vector<std::shared_ptr<TestCaseResult>> results;
|
||||
for (const auto& c : env.GetTests()) {
|
||||
auto result = TestCaseRequestContext::Run(env.GetThreadPool(),
|
||||
*c, env.Env(), env.GetSessionOptions(), concurrent_runs, repeat_count);
|
||||
*c, env.Env(), env.GetSessionOptions(), concurrent_runs, repeat_count);
|
||||
results.push_back(std::move(result));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<TestCaseResult>> TestCaseDriver::RunParallel(const TestEnv& test_env, size_t parallel_models,
|
||||
size_t concurrent_runs) {
|
||||
|
||||
std::vector<std::shared_ptr<TestCaseResult>> TestCaseDriver::RunParallel(const TestEnv& test_env, size_t parallel_models,
|
||||
size_t concurrent_runs) {
|
||||
assert(parallel_models > 1);
|
||||
parallel_models = std::min(parallel_models, test_env.GetTests().size());
|
||||
LOGF_DEFAULT(ERROR, "Running tests in parallel: at most %u models at any time", static_cast<unsigned int>(parallel_models));
|
||||
|
|
@ -55,7 +53,7 @@ void TestCaseDriver::RunModelsAsync(size_t parallel_models) {
|
|||
}
|
||||
tests_inprogress_.fetch_add(1, std::memory_order_relaxed);
|
||||
TestCaseRequestContext::Request(on_test_case_complete_, env_.GetThreadPool(), *tests[next_to_run],
|
||||
env_.Env(), env_.GetSessionOptions(), concurrent_runs_);
|
||||
env_.Env(), env_.GetSessionOptions(), next_to_run, concurrent_runs_);
|
||||
}
|
||||
// This thread is not on a threadpool so we are not using it
|
||||
// to run anything. Just wait.
|
||||
|
|
@ -63,10 +61,11 @@ void TestCaseDriver::RunModelsAsync(size_t parallel_models) {
|
|||
LOGF_DEFAULT(ERROR, "Running tests finished. Generating report");
|
||||
}
|
||||
|
||||
void TestCaseDriver::OnTestCaseComplete(std::shared_ptr<TestCaseResult> result) {
|
||||
void TestCaseDriver::OnTestCaseComplete(size_t test_case_id, std::shared_ptr<TestCaseResult> result) {
|
||||
assert(test_case_id < results_.size());
|
||||
{
|
||||
std::lock_guard<std::mutex> g(mut_);
|
||||
results_.push_back(std::move(result));
|
||||
results_ [test_case_id] = std::move(result);
|
||||
}
|
||||
|
||||
const auto& tests = env_.GetTests();
|
||||
|
|
@ -75,7 +74,7 @@ void TestCaseDriver::OnTestCaseComplete(std::shared_ptr<TestCaseResult> result)
|
|||
if (next_to_run < total_models) {
|
||||
tests_inprogress_.fetch_add(1, std::memory_order_relaxed);
|
||||
TestCaseRequestContext::Request(on_test_case_complete_, env_.GetThreadPool(), *tests[next_to_run],
|
||||
env_.Env(), env_.GetSessionOptions(), concurrent_runs_);
|
||||
env_.Env(), env_.GetSessionOptions(), next_to_run, concurrent_runs_);
|
||||
}
|
||||
|
||||
auto before_we_done = tests_inprogress_.fetch_sub(1, std::memory_order_acq_rel);
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class TestCaseDriver {
|
|||
return std::move(results_);
|
||||
}
|
||||
|
||||
void OnTestCaseComplete(std::shared_ptr<TestCaseResult>);
|
||||
void OnTestCaseComplete(size_t, std::shared_ptr<TestCaseResult>);
|
||||
|
||||
const TestEnv& env_;
|
||||
size_t concurrent_runs_;
|
||||
|
|
|
|||
|
|
@ -14,13 +14,14 @@ namespace onnxruntime {
|
|||
namespace test {
|
||||
|
||||
TestCaseRequestContext::TestCaseRequestContext(const Callback& cb, PThreadPool tp, const ITestCase& test_case, Ort::Env& env,
|
||||
const Ort::SessionOptions& session_opts)
|
||||
const Ort::SessionOptions& session_opts, size_t test_case_id)
|
||||
: cb_(cb),
|
||||
tp_(tp),
|
||||
test_case_(test_case),
|
||||
env_(env),
|
||||
session_opts_(session_opts.Clone()),
|
||||
session_(nullptr),
|
||||
test_case_id_(test_case_id),
|
||||
allocator_(),
|
||||
result_(),
|
||||
data_tasks_started_(0),
|
||||
|
|
@ -58,8 +59,9 @@ std::shared_ptr<TestCaseResult> TestCaseRequestContext::Run(PThreadPool tpool,
|
|||
concurrent_runs = 1;
|
||||
}
|
||||
|
||||
// No callback, test_case_id is zero.
|
||||
Callback empty_cb;
|
||||
TestCaseRequestContext ctx(empty_cb, tpool, c, env, session_opts);
|
||||
TestCaseRequestContext ctx(empty_cb, tpool, c, env, session_opts, 0U);
|
||||
|
||||
const size_t data_count = c.GetDataCount();
|
||||
if (concurrent_runs > 1 && data_count > 1) {
|
||||
|
|
@ -76,13 +78,14 @@ void TestCaseRequestContext::Request(const Callback& cb, PThreadPool tpool,
|
|||
const ITestCase& c,
|
||||
Ort::Env& env,
|
||||
const Ort::SessionOptions& session_opts,
|
||||
size_t test_case_id,
|
||||
size_t concurrent_runs) {
|
||||
//temp hack. Because we have no resource control. We may not have enough memory to run this test in parallel
|
||||
if (c.GetTestCaseName() == "coreml_FNS-Candy_ImageNet") {
|
||||
concurrent_runs = 1;
|
||||
}
|
||||
|
||||
std::unique_ptr<TestCaseRequestContext> self(new TestCaseRequestContext(cb, tpool, c, env, session_opts));
|
||||
std::unique_ptr<TestCaseRequestContext> self(new TestCaseRequestContext(cb, tpool, c, env, session_opts, test_case_id));
|
||||
CallableFactory<TestCaseRequestContext, void, size_t> f(self.get());
|
||||
auto runnable = f.GetCallable<&TestCaseRequestContext::RunAsync>();
|
||||
tpool->Schedule([runnable, concurrent_runs]() { runnable.Invoke(concurrent_runs); });
|
||||
|
|
@ -139,7 +142,7 @@ void TestCaseRequestContext::OnDataTaskComplete(size_t task_id, EXECUTE_RESULT r
|
|||
void TestCaseRequestContext::OnTestCaseComplete() {
|
||||
if (cb_) {
|
||||
std::unique_ptr<TestCaseRequestContext> self(this);
|
||||
cb_.Invoke(std::move(result_));
|
||||
cb_.Invoke(test_case_id_, std::move(result_));
|
||||
// No member access beyond this point
|
||||
} else {
|
||||
std::lock_guard<std::mutex> g(mut_);
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ namespace test {
|
|||
/// </summary>
|
||||
class TestCaseRequestContext {
|
||||
public:
|
||||
using Callback = Callable<void, std::shared_ptr<TestCaseResult>>;
|
||||
using Callback = Callable<void, size_t, std::shared_ptr<TestCaseResult>>;
|
||||
|
||||
/// <summary>
|
||||
/// Runs data tests on the model sequentially (concurrent_runs < 2)
|
||||
|
|
@ -64,7 +64,7 @@ class TestCaseRequestContext {
|
|||
/// <param name="concurrent_runs"></param>
|
||||
static void Request(const Callback& cb, PThreadPool tpool, const ITestCase& c,
|
||||
Ort::Env& env, const Ort::SessionOptions& sf,
|
||||
size_t concurrent_runs);
|
||||
size_t test_case_id, size_t concurrent_runs);
|
||||
|
||||
const TIME_SPEC& GetTimeSpend() const {
|
||||
return test_case_time_;
|
||||
|
|
@ -81,7 +81,7 @@ class TestCaseRequestContext {
|
|||
private:
|
||||
|
||||
TestCaseRequestContext(const Callback& cb, PThreadPool tp, const ITestCase& test_case, Ort::Env& env,
|
||||
const Ort::SessionOptions& session_opts);
|
||||
const Ort::SessionOptions& session_opts, size_t test_case_id);
|
||||
|
||||
bool SetupSession();
|
||||
|
||||
|
|
@ -106,6 +106,7 @@ class TestCaseRequestContext {
|
|||
Ort::Env& env_;
|
||||
Ort::SessionOptions session_opts_;
|
||||
Ort::Session session_;
|
||||
size_t test_case_id_;
|
||||
MockedOrtAllocator allocator_;
|
||||
std::shared_ptr<TestCaseResult> result_;
|
||||
TIME_SPEC test_case_time_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue