From d22f6fddf72cd984553b0331a23fbf89b215c7cd Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 3 Jul 2020 09:26:47 +1000 Subject: [PATCH] Add ability to specify just the device when using IOBinding for an output (#4386) * Add ability to specify just the device when using IOBinding for an output. This enables keeping an output on a different device GPU when it has a dynamic size that is not known ahead of graph execution. --- onnxruntime/core/session/IOBinding.cc | 21 ++++++---- onnxruntime/core/session/IOBinding.h | 13 +++++- onnxruntime/core/session/inference_session.cc | 22 +++++++--- onnxruntime/core/session/inference_session.h | 3 +- onnxruntime/core/session/onnxruntime_c_api.cc | 4 +- .../test/framework/execution_frame_test.cc | 4 +- .../test/framework/inference_session_test.cc | 41 +++++++++++++++---- 7 files changed, 77 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/session/IOBinding.cc b/onnxruntime/core/session/IOBinding.cc index 156ef6db45..053c221ff1 100644 --- a/onnxruntime/core/session/IOBinding.cc +++ b/onnxruntime/core/session/IOBinding.cc @@ -78,33 +78,36 @@ common::Status IOBinding::SynchronizeOutputs() { return Status::OK(); } -common::Status IOBinding::BindOutput(const std::string& name, const OrtValue& ml_value) { +common::Status IOBinding::BindOutput(const std::string& name, const OrtValue& ml_value, OrtDevice device) { auto rc = Contains(output_names_, name); if (rc.first) { outputs_[rc.second] = ml_value; - return Status::OK(); + outputs_device_info_[rc.second] = device; + } else { + output_names_.push_back(name); + outputs_.push_back(ml_value); + outputs_device_info_.push_back(device); } - output_names_.push_back(name); - outputs_.push_back(ml_value); return Status::OK(); } void IOBinding::ClearOutputs() { output_names_.clear(); outputs_.clear(); + outputs_device_info_.clear(); } -const std::vector& IOBinding::GetOutputNames() const { - return output_names_; -} +const std::vector& IOBinding::GetOutputNames() const { return output_names_; } std::vector& IOBinding::GetOutputs() { return outputs_; } -const std::vector& IOBinding::GetInputNames() const { - return feed_names_; +const std::vector& IOBinding::GetOutputsDeviceInfo() const { + return outputs_device_info_; } +const std::vector& IOBinding::GetInputNames() const { return feed_names_; } + const std::vector& IOBinding::GetInputs() const { return feeds_; } AllocatorPtr IOBinding::GetCPUAllocator(int id, onnxruntime::ProviderType provider_type) const { diff --git a/onnxruntime/core/session/IOBinding.h b/onnxruntime/core/session/IOBinding.h index 93eeb08680..150db507cd 100644 --- a/onnxruntime/core/session/IOBinding.h +++ b/onnxruntime/core/session/IOBinding.h @@ -58,10 +58,15 @@ class IOBinding { */ common::Status SynchronizeInputs(); common::Status SynchronizeOutputs(); + /** - * This simply provides the names and optionally allocated output containers. + * Bind an output name to a provided OrtValue. + * If the output is pre-allocated, the value in 'device' is not used. + * If the output is not pre-allocated, information on the device it should be allocated on can be provided. + * + * @param device Device to allocate on if not pre-allocated. Default is CPU. */ - common::Status BindOutput(const std::string& name, const OrtValue& ml_value); + common::Status BindOutput(const std::string& name, const OrtValue& ml_value, OrtDevice device = {}); /** * This simply collects the outputs obtained after calling Run() inside the @param outputs. @@ -69,6 +74,9 @@ class IOBinding { const std::vector& GetOutputNames() const; std::vector& GetOutputs(); + // device info for outputs that are not pre-allocated + const std::vector& GetOutputsDeviceInfo() const; + const std::vector& GetInputNames() const; const std::vector& GetInputs() const; @@ -94,6 +102,7 @@ class IOBinding { std::vector feeds_; std::vector output_names_; std::vector outputs_; + std::vector outputs_device_info_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(IOBinding); }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index d0a2b6a97c..2a64562282 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1131,9 +1131,10 @@ common::Status InferenceSession::ValidateOutputs(const std::vector& return common::Status::OK(); } -Status InferenceSession::Run(const RunOptions& run_options, const std::vector& feed_names, - const std::vector& feeds, const std::vector& output_names, - std::vector* p_fetches) { +Status InferenceSession::Run(const RunOptions& run_options, + const std::vector& feed_names, const std::vector& feeds, + const std::vector& output_names, std::vector* p_fetches, + const std::vector* p_fetches_device_info) { TimePoint tp; if (session_profiler_.IsEnabled()) { tp = session_profiler_.StartTime(); @@ -1171,6 +1172,16 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vectorGetOrtValueNameIdxMap()); FeedsFetchesManager feeds_fetches_manager{std::move(info)}; + if (p_fetches_device_info) { + // populate the target device info. ignored if pre-allocated fetches are provided + const auto& fetch_device_info = *p_fetches_device_info; + auto& fetch_info = feeds_fetches_manager.GetMutableFetchesDeviceCopyInfo(); + + for (size_t i = 0, end = output_names.size(); i < end; ++i) { + fetch_info[i].target_device = fetch_device_info[i]; + } + } + if (!run_options.run_tag.empty()) { LOGS(*session_logger_, INFO) << "Running with tag: " << run_options.run_tag; } @@ -1206,7 +1217,6 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector InferenceSession::GetModelMetadata() const { @@ -1341,7 +1351,7 @@ common::Status InferenceSession::Run(const RunOptions& run_options, IOBinding& i // TODO should Run() call io_binding.SynchronizeInputs() or should it let the callers do it? // io_binding.SynchronizeInputs(); return Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(), - &io_binding.GetOutputs()); + &io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo()); } common::Status InferenceSession::Run(IOBinding& io_binding) { diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 9841821eb9..c34d54c5b7 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -238,7 +238,8 @@ class InferenceSession { common::Status Run(const RunOptions& run_options, const std::vector& feed_names, const std::vector& feeds, const std::vector& output_names, - std::vector* p_fetches) ORT_MUST_USE_RESULT; + std::vector* p_fetches, + const std::vector* p_fetches_device_info = nullptr) ORT_MUST_USE_RESULT; /** * Run a pre-loaded and pre-intialized model. diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 1f5bf29ae2..3c10000ea4 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -503,9 +503,9 @@ ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRu Status status; if (run_options == nullptr) { OrtRunOptions op; - status = session->Run(op, feed_names, feeds, output_names, &fetches); + status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr); } else { - status = session->Run(*run_options, feed_names, feeds, output_names, &fetches); + status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr); } if (!status.IsOK()) diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index ae59605b51..13269c82f9 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -337,7 +337,7 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) { results[0].Init(p_tensor.release(), DataTypeImpl::GetType(), DataTypeImpl::GetType()->GetDeleteFunc()); RunOptions ro; - ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results)); + ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr)); EXPECT_EQ(results[0].Get().DataRaw(), orig_buffer); EXPECT_THAT(results[0].Get().DataAsSpan(), ::testing::ContainerEq(gsl::make_span(expected))); @@ -361,7 +361,7 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) { std::vector results; RunOptions ro; - ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results)); + ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr)); // output buffer should not be the same as the initializer in SessionState const auto& initializers = session.GetSessionState().GetInitializedTensors(); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 5b476ba2e6..dbd008ea10 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -24,12 +24,14 @@ #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/graph/op.h" +#include "core/optimizer/rule_based_graph_transformer.h" #include "core/platform/env.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/providers/cpu/math/element_wise_ops.h" #ifdef USE_CUDA #include "core/providers/cuda/gpu_data_transfer.h" #endif +#include "core/session/environment.h" #include "core/session/IOBinding.h" #include "dummy_provider.h" #include "test_utils.h" @@ -37,9 +39,9 @@ #include "test/test_environment.h" #include "test/providers/provider_test_utils.h" #include "test/optimizer/dummy_graph_transformer.h" -#include "core/optimizer/rule_based_graph_transformer.h" +#include "test/util/include/default_providers.h" + #include "gtest/gtest.h" -#include "core/session/environment.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -246,7 +248,8 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, const RunOptions& run_options, ProviderType bind_provider_type, bool is_preallocate_output_vec, - ProviderType allocation_provider) { + ProviderType allocation_provider, + OrtDevice* output_device) { unique_ptr io_binding; Status st = session_object.NewIOBinding(&io_binding); ASSERT_TRUE(st.IsOK()); @@ -303,7 +306,13 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, } } - io_binding->BindOutput("Y", output_ml_value); + if (output_device) { + // output should be allocated on specified device (if not preallocated here) + io_binding->BindOutput("Y", output_ml_value, *output_device); + } else { + io_binding->BindOutput("Y", output_ml_value); + } + ASSERT_TRUE(io_binding->SynchronizeInputs().IsOK()); // prepare expected inputs and outputs @@ -315,8 +324,8 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; ASSERT_TRUE(st.IsOK()); - if (is_preallocate_output_vec && - allocation_provider == kCudaExecutionProvider) { + if ((is_preallocate_output_vec && allocation_provider == kCudaExecutionProvider) || + (output_device && output_device->Type() == OrtDevice::GPU)) { #ifdef USE_CUDA // in this case we need to copy the tensor from cuda to cpu vector& outputs = io_binding->GetOutputs(); @@ -787,7 +796,8 @@ static void TestBindHelper(const std::string& log_str, ProviderType bind_provider_type, ProviderType run_provider_type, bool preallocate_output, - ProviderType allocation_provider = kCpuExecutionProvider) { + ProviderType allocation_provider = kCpuExecutionProvider, + OrtDevice* output_device = nullptr) { SessionOptions so; so.session_logid = "InferenceSessionTests." + log_str; @@ -815,11 +825,13 @@ static void TestBindHelper(const std::string& log_str, RunOptions run_options; run_options.run_log_verbosity_level = so.session_log_verbosity_level; run_options.run_tag = so.session_logid; + RunModelWithBindingMatMul(session_object, run_options, bind_provider_type, preallocate_output, - allocation_provider); + allocation_provider, + output_device); } TEST(InferenceSessionTests, TestBindCpu) { @@ -938,6 +950,17 @@ TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu2) { kCpuExecutionProvider); } +TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) { + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0); + + TestBindHelper("TestBindCudaPreallocateOutputOnCuda", + kCudaExecutionProvider, + kCudaExecutionProvider, + false /* preallocate output on GPU */, + kCudaExecutionProvider, + &device /* specify output device */); +} + #endif TEST(InferenceSessionTests, ModelWithoutOpset) { @@ -1814,7 +1837,7 @@ TEST(InferenceSessionTests, TestCopyToFromDevices) { RunOptions run_options; run_options.run_tag = "run:" + std::to_string(run_num); - common::Status st = session_object.Run(run_options, feed_names, feeds, output_names, &fetches); + common::Status st = session_object.Run(run_options, feed_names, feeds, output_names, &fetches, nullptr); ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); VerifyOutputs(fetches, expected_dims_mul_y, expected_values_mul_y);