From 8ee4e8226eafd76ee3bd29b302ca4f6b03395576 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Sat, 19 Sep 2020 00:45:44 -0700 Subject: [PATCH] Preserve relative order of the results and the tests. (#5225) --- onnxruntime/test/onnx/testcase_driver.cc | 21 ++++++++++----------- onnxruntime/test/onnx/testcase_driver.h | 2 +- onnxruntime/test/onnx/testcase_request.cc | 11 +++++++---- onnxruntime/test/onnx/testcase_request.h | 7 ++++--- 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/onnxruntime/test/onnx/testcase_driver.cc b/onnxruntime/test/onnx/testcase_driver.cc index 9348422aa8..494ca23354 100644 --- a/onnxruntime/test/onnx/testcase_driver.cc +++ b/onnxruntime/test/onnx/testcase_driver.cc @@ -17,25 +17,23 @@ TestCaseDriver::TestCaseDriver(const TestEnv& env, size_t concurrent_runs) tests_started_(0), tests_inprogress_(0), finished_(false) { - results_.reserve(env.GetTests().size()); - CallableFactory> f(this); + results_.resize(env.GetTests().size()); + CallableFactory> f(this); on_test_case_complete_ = f.GetCallable<&TestCaseDriver::OnTestCaseComplete>(); } - std::vector> TestCaseDriver::Run(const TestEnv& env, size_t concurrent_runs, size_t repeat_count) { std::vector> results; for (const auto& c : env.GetTests()) { auto result = TestCaseRequestContext::Run(env.GetThreadPool(), - *c, env.Env(), env.GetSessionOptions(), concurrent_runs, repeat_count); + *c, env.Env(), env.GetSessionOptions(), concurrent_runs, repeat_count); results.push_back(std::move(result)); } return results; } -std::vector> TestCaseDriver::RunParallel(const TestEnv& test_env, size_t parallel_models, - size_t concurrent_runs) { - +std::vector> TestCaseDriver::RunParallel(const TestEnv& test_env, size_t parallel_models, + size_t concurrent_runs) { assert(parallel_models > 1); parallel_models = std::min(parallel_models, test_env.GetTests().size()); LOGF_DEFAULT(ERROR, "Running tests in parallel: at most %u models at any time", static_cast(parallel_models)); @@ -55,7 +53,7 @@ void TestCaseDriver::RunModelsAsync(size_t parallel_models) { } tests_inprogress_.fetch_add(1, std::memory_order_relaxed); TestCaseRequestContext::Request(on_test_case_complete_, env_.GetThreadPool(), *tests[next_to_run], - env_.Env(), env_.GetSessionOptions(), concurrent_runs_); + env_.Env(), env_.GetSessionOptions(), next_to_run, concurrent_runs_); } // This thread is not on a threadpool so we are not using it // to run anything. Just wait. @@ -63,10 +61,11 @@ void TestCaseDriver::RunModelsAsync(size_t parallel_models) { LOGF_DEFAULT(ERROR, "Running tests finished. Generating report"); } -void TestCaseDriver::OnTestCaseComplete(std::shared_ptr result) { +void TestCaseDriver::OnTestCaseComplete(size_t test_case_id, std::shared_ptr result) { + assert(test_case_id < results_.size()); { std::lock_guard g(mut_); - results_.push_back(std::move(result)); + results_ [test_case_id] = std::move(result); } const auto& tests = env_.GetTests(); @@ -75,7 +74,7 @@ void TestCaseDriver::OnTestCaseComplete(std::shared_ptr result) if (next_to_run < total_models) { tests_inprogress_.fetch_add(1, std::memory_order_relaxed); TestCaseRequestContext::Request(on_test_case_complete_, env_.GetThreadPool(), *tests[next_to_run], - env_.Env(), env_.GetSessionOptions(), concurrent_runs_); + env_.Env(), env_.GetSessionOptions(), next_to_run, concurrent_runs_); } auto before_we_done = tests_inprogress_.fetch_sub(1, std::memory_order_acq_rel); diff --git a/onnxruntime/test/onnx/testcase_driver.h b/onnxruntime/test/onnx/testcase_driver.h index 417755f20e..50a96f4b28 100644 --- a/onnxruntime/test/onnx/testcase_driver.h +++ b/onnxruntime/test/onnx/testcase_driver.h @@ -62,7 +62,7 @@ class TestCaseDriver { return std::move(results_); } - void OnTestCaseComplete(std::shared_ptr); + void OnTestCaseComplete(size_t, std::shared_ptr); const TestEnv& env_; size_t concurrent_runs_; diff --git a/onnxruntime/test/onnx/testcase_request.cc b/onnxruntime/test/onnx/testcase_request.cc index dc208c7a17..10de15aa12 100644 --- a/onnxruntime/test/onnx/testcase_request.cc +++ b/onnxruntime/test/onnx/testcase_request.cc @@ -14,13 +14,14 @@ namespace onnxruntime { namespace test { TestCaseRequestContext::TestCaseRequestContext(const Callback& cb, PThreadPool tp, const ITestCase& test_case, Ort::Env& env, - const Ort::SessionOptions& session_opts) + const Ort::SessionOptions& session_opts, size_t test_case_id) : cb_(cb), tp_(tp), test_case_(test_case), env_(env), session_opts_(session_opts.Clone()), session_(nullptr), + test_case_id_(test_case_id), allocator_(), result_(), data_tasks_started_(0), @@ -58,8 +59,9 @@ std::shared_ptr TestCaseRequestContext::Run(PThreadPool tpool, concurrent_runs = 1; } + // No callback, test_case_id is zero. Callback empty_cb; - TestCaseRequestContext ctx(empty_cb, tpool, c, env, session_opts); + TestCaseRequestContext ctx(empty_cb, tpool, c, env, session_opts, 0U); const size_t data_count = c.GetDataCount(); if (concurrent_runs > 1 && data_count > 1) { @@ -76,13 +78,14 @@ void TestCaseRequestContext::Request(const Callback& cb, PThreadPool tpool, const ITestCase& c, Ort::Env& env, const Ort::SessionOptions& session_opts, + size_t test_case_id, size_t concurrent_runs) { //temp hack. Because we have no resource control. We may not have enough memory to run this test in parallel if (c.GetTestCaseName() == "coreml_FNS-Candy_ImageNet") { concurrent_runs = 1; } - std::unique_ptr self(new TestCaseRequestContext(cb, tpool, c, env, session_opts)); + std::unique_ptr self(new TestCaseRequestContext(cb, tpool, c, env, session_opts, test_case_id)); CallableFactory f(self.get()); auto runnable = f.GetCallable<&TestCaseRequestContext::RunAsync>(); tpool->Schedule([runnable, concurrent_runs]() { runnable.Invoke(concurrent_runs); }); @@ -139,7 +142,7 @@ void TestCaseRequestContext::OnDataTaskComplete(size_t task_id, EXECUTE_RESULT r void TestCaseRequestContext::OnTestCaseComplete() { if (cb_) { std::unique_ptr self(this); - cb_.Invoke(std::move(result_)); + cb_.Invoke(test_case_id_, std::move(result_)); // No member access beyond this point } else { std::lock_guard g(mut_); diff --git a/onnxruntime/test/onnx/testcase_request.h b/onnxruntime/test/onnx/testcase_request.h index 8d95858156..75d02db3d6 100644 --- a/onnxruntime/test/onnx/testcase_request.h +++ b/onnxruntime/test/onnx/testcase_request.h @@ -32,7 +32,7 @@ namespace test { /// class TestCaseRequestContext { public: - using Callback = Callable>; + using Callback = Callable>; /// /// Runs data tests on the model sequentially (concurrent_runs < 2) @@ -64,7 +64,7 @@ class TestCaseRequestContext { /// static void Request(const Callback& cb, PThreadPool tpool, const ITestCase& c, Ort::Env& env, const Ort::SessionOptions& sf, - size_t concurrent_runs); + size_t test_case_id, size_t concurrent_runs); const TIME_SPEC& GetTimeSpend() const { return test_case_time_; @@ -81,7 +81,7 @@ class TestCaseRequestContext { private: TestCaseRequestContext(const Callback& cb, PThreadPool tp, const ITestCase& test_case, Ort::Env& env, - const Ort::SessionOptions& session_opts); + const Ort::SessionOptions& session_opts, size_t test_case_id); bool SetupSession(); @@ -106,6 +106,7 @@ class TestCaseRequestContext { Ort::Env& env_; Ort::SessionOptions session_opts_; Ort::Session session_; + size_t test_case_id_; MockedOrtAllocator allocator_; std::shared_ptr result_; TIME_SPEC test_case_time_;