Preserve relative order of the results and the tests. (#5225)

This commit is contained in:
Dmitri Smirnov 2020-09-19 00:45:44 -07:00 committed by GitHub
parent b49f6a5e2c
commit 8ee4e8226e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 19 deletions

View file

@ -17,25 +17,23 @@ TestCaseDriver::TestCaseDriver(const TestEnv& env, size_t concurrent_runs)
tests_started_(0),
tests_inprogress_(0),
finished_(false) {
results_.reserve(env.GetTests().size());
CallableFactory<TestCaseDriver, void, std::shared_ptr<TestCaseResult>> f(this);
results_.resize(env.GetTests().size());
CallableFactory<TestCaseDriver, void, size_t, std::shared_ptr<TestCaseResult>> f(this);
on_test_case_complete_ = f.GetCallable<&TestCaseDriver::OnTestCaseComplete>();
}
std::vector<std::shared_ptr<TestCaseResult>> TestCaseDriver::Run(const TestEnv& env, size_t concurrent_runs, size_t repeat_count) {
std::vector<std::shared_ptr<TestCaseResult>> results;
for (const auto& c : env.GetTests()) {
auto result = TestCaseRequestContext::Run(env.GetThreadPool(),
*c, env.Env(), env.GetSessionOptions(), concurrent_runs, repeat_count);
*c, env.Env(), env.GetSessionOptions(), concurrent_runs, repeat_count);
results.push_back(std::move(result));
}
return results;
}
std::vector<std::shared_ptr<TestCaseResult>> TestCaseDriver::RunParallel(const TestEnv& test_env, size_t parallel_models,
size_t concurrent_runs) {
std::vector<std::shared_ptr<TestCaseResult>> TestCaseDriver::RunParallel(const TestEnv& test_env, size_t parallel_models,
size_t concurrent_runs) {
assert(parallel_models > 1);
parallel_models = std::min(parallel_models, test_env.GetTests().size());
LOGF_DEFAULT(ERROR, "Running tests in parallel: at most %u models at any time", static_cast<unsigned int>(parallel_models));
@ -55,7 +53,7 @@ void TestCaseDriver::RunModelsAsync(size_t parallel_models) {
}
tests_inprogress_.fetch_add(1, std::memory_order_relaxed);
TestCaseRequestContext::Request(on_test_case_complete_, env_.GetThreadPool(), *tests[next_to_run],
env_.Env(), env_.GetSessionOptions(), concurrent_runs_);
env_.Env(), env_.GetSessionOptions(), next_to_run, concurrent_runs_);
}
// This thread is not on a threadpool so we are not using it
// to run anything. Just wait.
@ -63,10 +61,11 @@ void TestCaseDriver::RunModelsAsync(size_t parallel_models) {
LOGF_DEFAULT(ERROR, "Running tests finished. Generating report");
}
void TestCaseDriver::OnTestCaseComplete(std::shared_ptr<TestCaseResult> result) {
void TestCaseDriver::OnTestCaseComplete(size_t test_case_id, std::shared_ptr<TestCaseResult> result) {
assert(test_case_id < results_.size());
{
std::lock_guard<std::mutex> g(mut_);
results_.push_back(std::move(result));
results_ [test_case_id] = std::move(result);
}
const auto& tests = env_.GetTests();
@ -75,7 +74,7 @@ void TestCaseDriver::OnTestCaseComplete(std::shared_ptr<TestCaseResult> result)
if (next_to_run < total_models) {
tests_inprogress_.fetch_add(1, std::memory_order_relaxed);
TestCaseRequestContext::Request(on_test_case_complete_, env_.GetThreadPool(), *tests[next_to_run],
env_.Env(), env_.GetSessionOptions(), concurrent_runs_);
env_.Env(), env_.GetSessionOptions(), next_to_run, concurrent_runs_);
}
auto before_we_done = tests_inprogress_.fetch_sub(1, std::memory_order_acq_rel);

View file

@ -62,7 +62,7 @@ class TestCaseDriver {
return std::move(results_);
}
void OnTestCaseComplete(std::shared_ptr<TestCaseResult>);
void OnTestCaseComplete(size_t, std::shared_ptr<TestCaseResult>);
const TestEnv& env_;
size_t concurrent_runs_;

View file

@ -14,13 +14,14 @@ namespace onnxruntime {
namespace test {
TestCaseRequestContext::TestCaseRequestContext(const Callback& cb, PThreadPool tp, const ITestCase& test_case, Ort::Env& env,
const Ort::SessionOptions& session_opts)
const Ort::SessionOptions& session_opts, size_t test_case_id)
: cb_(cb),
tp_(tp),
test_case_(test_case),
env_(env),
session_opts_(session_opts.Clone()),
session_(nullptr),
test_case_id_(test_case_id),
allocator_(),
result_(),
data_tasks_started_(0),
@ -58,8 +59,9 @@ std::shared_ptr<TestCaseResult> TestCaseRequestContext::Run(PThreadPool tpool,
concurrent_runs = 1;
}
// No callback, test_case_id is zero.
Callback empty_cb;
TestCaseRequestContext ctx(empty_cb, tpool, c, env, session_opts);
TestCaseRequestContext ctx(empty_cb, tpool, c, env, session_opts, 0U);
const size_t data_count = c.GetDataCount();
if (concurrent_runs > 1 && data_count > 1) {
@ -76,13 +78,14 @@ void TestCaseRequestContext::Request(const Callback& cb, PThreadPool tpool,
const ITestCase& c,
Ort::Env& env,
const Ort::SessionOptions& session_opts,
size_t test_case_id,
size_t concurrent_runs) {
//temp hack. Because we have no resource control. We may not have enough memory to run this test in parallel
if (c.GetTestCaseName() == "coreml_FNS-Candy_ImageNet") {
concurrent_runs = 1;
}
std::unique_ptr<TestCaseRequestContext> self(new TestCaseRequestContext(cb, tpool, c, env, session_opts));
std::unique_ptr<TestCaseRequestContext> self(new TestCaseRequestContext(cb, tpool, c, env, session_opts, test_case_id));
CallableFactory<TestCaseRequestContext, void, size_t> f(self.get());
auto runnable = f.GetCallable<&TestCaseRequestContext::RunAsync>();
tpool->Schedule([runnable, concurrent_runs]() { runnable.Invoke(concurrent_runs); });
@ -139,7 +142,7 @@ void TestCaseRequestContext::OnDataTaskComplete(size_t task_id, EXECUTE_RESULT r
void TestCaseRequestContext::OnTestCaseComplete() {
if (cb_) {
std::unique_ptr<TestCaseRequestContext> self(this);
cb_.Invoke(std::move(result_));
cb_.Invoke(test_case_id_, std::move(result_));
// No member access beyond this point
} else {
std::lock_guard<std::mutex> g(mut_);

View file

@ -32,7 +32,7 @@ namespace test {
/// </summary>
class TestCaseRequestContext {
public:
using Callback = Callable<void, std::shared_ptr<TestCaseResult>>;
using Callback = Callable<void, size_t, std::shared_ptr<TestCaseResult>>;
/// <summary>
/// Runs data tests on the model sequentially (concurrent_runs < 2)
@ -64,7 +64,7 @@ class TestCaseRequestContext {
/// <param name="concurrent_runs"></param>
static void Request(const Callback& cb, PThreadPool tpool, const ITestCase& c,
Ort::Env& env, const Ort::SessionOptions& sf,
size_t concurrent_runs);
size_t test_case_id, size_t concurrent_runs);
const TIME_SPEC& GetTimeSpend() const {
return test_case_time_;
@ -81,7 +81,7 @@ class TestCaseRequestContext {
private:
TestCaseRequestContext(const Callback& cb, PThreadPool tp, const ITestCase& test_case, Ort::Env& env,
const Ort::SessionOptions& session_opts);
const Ort::SessionOptions& session_opts, size_t test_case_id);
bool SetupSession();
@ -106,6 +106,7 @@ class TestCaseRequestContext {
Ort::Env& env_;
Ort::SessionOptions session_opts_;
Ort::Session session_;
size_t test_case_id_;
MockedOrtAllocator allocator_;
std::shared_ptr<TestCaseResult> result_;
TIME_SPEC test_case_time_;