Add ability to specify just the device when using IOBinding for an output (#4386)

* Add ability to specify just the device when using IOBinding for an output. This enables keeping an output on a different device GPU when it has a dynamic size that is not known ahead of graph execution.
This commit is contained in:
Scott McKay 2020-07-03 09:26:47 +10:00 committed by GitHub
parent 28e4c0edf5
commit d22f6fddf7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 77 additions and 31 deletions

View file

@ -78,33 +78,36 @@ common::Status IOBinding::SynchronizeOutputs() {
return Status::OK();
}
common::Status IOBinding::BindOutput(const std::string& name, const OrtValue& ml_value) {
common::Status IOBinding::BindOutput(const std::string& name, const OrtValue& ml_value, OrtDevice device) {
auto rc = Contains(output_names_, name);
if (rc.first) {
outputs_[rc.second] = ml_value;
return Status::OK();
outputs_device_info_[rc.second] = device;
} else {
output_names_.push_back(name);
outputs_.push_back(ml_value);
outputs_device_info_.push_back(device);
}
output_names_.push_back(name);
outputs_.push_back(ml_value);
return Status::OK();
}
void IOBinding::ClearOutputs() {
output_names_.clear();
outputs_.clear();
outputs_device_info_.clear();
}
const std::vector<std::string>& IOBinding::GetOutputNames() const {
return output_names_;
}
const std::vector<std::string>& IOBinding::GetOutputNames() const { return output_names_; }
std::vector<OrtValue>& IOBinding::GetOutputs() { return outputs_; }
const std::vector<std::string>& IOBinding::GetInputNames() const {
return feed_names_;
const std::vector<OrtDevice>& IOBinding::GetOutputsDeviceInfo() const {
return outputs_device_info_;
}
const std::vector<std::string>& IOBinding::GetInputNames() const { return feed_names_; }
const std::vector<OrtValue>& IOBinding::GetInputs() const { return feeds_; }
AllocatorPtr IOBinding::GetCPUAllocator(int id, onnxruntime::ProviderType provider_type) const {

View file

@ -58,10 +58,15 @@ class IOBinding {
*/
common::Status SynchronizeInputs();
common::Status SynchronizeOutputs();
/**
* This simply provides the names and optionally allocated output containers.
* Bind an output name to a provided OrtValue.
* If the output is pre-allocated, the value in 'device' is not used.
* If the output is not pre-allocated, information on the device it should be allocated on can be provided.
*
* @param device Device to allocate on if not pre-allocated. Default is CPU.
*/
common::Status BindOutput(const std::string& name, const OrtValue& ml_value);
common::Status BindOutput(const std::string& name, const OrtValue& ml_value, OrtDevice device = {});
/**
* This simply collects the outputs obtained after calling Run() inside the @param outputs.
@ -69,6 +74,9 @@ class IOBinding {
const std::vector<std::string>& GetOutputNames() const;
std::vector<OrtValue>& GetOutputs();
// device info for outputs that are not pre-allocated
const std::vector<OrtDevice>& GetOutputsDeviceInfo() const;
const std::vector<std::string>& GetInputNames() const;
const std::vector<OrtValue>& GetInputs() const;
@ -94,6 +102,7 @@ class IOBinding {
std::vector<OrtValue> feeds_;
std::vector<std::string> output_names_;
std::vector<OrtValue> outputs_;
std::vector<OrtDevice> outputs_device_info_;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(IOBinding);
};

View file

@ -1131,9 +1131,10 @@ common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>&
return common::Status::OK();
}
Status InferenceSession::Run(const RunOptions& run_options, const std::vector<std::string>& feed_names,
const std::vector<OrtValue>& feeds, const std::vector<std::string>& output_names,
std::vector<OrtValue>* p_fetches) {
Status InferenceSession::Run(const RunOptions& run_options,
const std::vector<std::string>& feed_names, const std::vector<OrtValue>& feeds,
const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info) {
TimePoint tp;
if (session_profiler_.IsEnabled()) {
tp = session_profiler_.StartTime();
@ -1171,6 +1172,16 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
FeedsFetchesInfo info(feed_names, output_names, session_state_->GetOrtValueNameIdxMap());
FeedsFetchesManager feeds_fetches_manager{std::move(info)};
if (p_fetches_device_info) {
// populate the target device info. ignored if pre-allocated fetches are provided
const auto& fetch_device_info = *p_fetches_device_info;
auto& fetch_info = feeds_fetches_manager.GetMutableFetchesDeviceCopyInfo();
for (size_t i = 0, end = output_names.size(); i < end; ++i) {
fetch_info[i].target_device = fetch_device_info[i];
}
}
if (!run_options.run_tag.empty()) {
LOGS(*session_logger_, INFO) << "Running with tag: " << run_options.run_tag;
}
@ -1206,7 +1217,6 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
session_options_.execution_mode, run_options.terminate, run_logger,
run_options.only_execute_path_to_fetches));
} catch (const std::exception& e) {
retval = Status(common::ONNXRUNTIME, common::FAIL, e.what());
} catch (...) {
@ -1270,7 +1280,7 @@ common::Status InferenceSession::Run(const RunOptions& run_options, const NameML
feeds.push_back(pair.second);
}
return Run(run_options, feed_names, feeds, output_names, p_fetches);
return Run(run_options, feed_names, feeds, output_names, p_fetches, nullptr);
}
std::pair<common::Status, const ModelMetadata*> InferenceSession::GetModelMetadata() const {
@ -1341,7 +1351,7 @@ common::Status InferenceSession::Run(const RunOptions& run_options, IOBinding& i
// TODO should Run() call io_binding.SynchronizeInputs() or should it let the callers do it?
// io_binding.SynchronizeInputs();
return Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(),
&io_binding.GetOutputs());
&io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo());
}
common::Status InferenceSession::Run(IOBinding& io_binding) {

View file

@ -238,7 +238,8 @@ class InferenceSession {
common::Status Run(const RunOptions& run_options, const std::vector<std::string>& feed_names,
const std::vector<OrtValue>& feeds, const std::vector<std::string>& output_names,
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info = nullptr) ORT_MUST_USE_RESULT;
/**
* Run a pre-loaded and pre-intialized model.

View file

@ -503,9 +503,9 @@ ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRu
Status status;
if (run_options == nullptr) {
OrtRunOptions op;
status = session->Run(op, feed_names, feeds, output_names, &fetches);
status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr);
} else {
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches);
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr);
}
if (!status.IsOK())

View file

@ -337,7 +337,7 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
results[0].Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(),
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
RunOptions ro;
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results));
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
EXPECT_EQ(results[0].Get<Tensor>().DataRaw(), orig_buffer);
EXPECT_THAT(results[0].Get<Tensor>().DataAsSpan<float>(), ::testing::ContainerEq(gsl::make_span(expected)));
@ -361,7 +361,7 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
std::vector<OrtValue> results;
RunOptions ro;
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results));
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
// output buffer should not be the same as the initializer in SessionState
const auto& initializers = session.GetSessionState().GetInitializedTensors();

View file

@ -24,12 +24,14 @@
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/graph/op.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/platform/env.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/providers/cpu/math/element_wise_ops.h"
#ifdef USE_CUDA
#include "core/providers/cuda/gpu_data_transfer.h"
#endif
#include "core/session/environment.h"
#include "core/session/IOBinding.h"
#include "dummy_provider.h"
#include "test_utils.h"
@ -37,9 +39,9 @@
#include "test/test_environment.h"
#include "test/providers/provider_test_utils.h"
#include "test/optimizer/dummy_graph_transformer.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "test/util/include/default_providers.h"
#include "gtest/gtest.h"
#include "core/session/environment.h"
using namespace std;
using namespace ONNX_NAMESPACE;
@ -246,7 +248,8 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
const RunOptions& run_options,
ProviderType bind_provider_type,
bool is_preallocate_output_vec,
ProviderType allocation_provider) {
ProviderType allocation_provider,
OrtDevice* output_device) {
unique_ptr<IOBinding> io_binding;
Status st = session_object.NewIOBinding(&io_binding);
ASSERT_TRUE(st.IsOK());
@ -303,7 +306,13 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
}
}
io_binding->BindOutput("Y", output_ml_value);
if (output_device) {
// output should be allocated on specified device (if not preallocated here)
io_binding->BindOutput("Y", output_ml_value, *output_device);
} else {
io_binding->BindOutput("Y", output_ml_value);
}
ASSERT_TRUE(io_binding->SynchronizeInputs().IsOK());
// prepare expected inputs and outputs
@ -315,8 +324,8 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
std::cout << "Run returned status: " << st.ErrorMessage() << std::endl;
ASSERT_TRUE(st.IsOK());
if (is_preallocate_output_vec &&
allocation_provider == kCudaExecutionProvider) {
if ((is_preallocate_output_vec && allocation_provider == kCudaExecutionProvider) ||
(output_device && output_device->Type() == OrtDevice::GPU)) {
#ifdef USE_CUDA
// in this case we need to copy the tensor from cuda to cpu
vector<OrtValue>& outputs = io_binding->GetOutputs();
@ -787,7 +796,8 @@ static void TestBindHelper(const std::string& log_str,
ProviderType bind_provider_type,
ProviderType run_provider_type,
bool preallocate_output,
ProviderType allocation_provider = kCpuExecutionProvider) {
ProviderType allocation_provider = kCpuExecutionProvider,
OrtDevice* output_device = nullptr) {
SessionOptions so;
so.session_logid = "InferenceSessionTests." + log_str;
@ -815,11 +825,13 @@ static void TestBindHelper(const std::string& log_str,
RunOptions run_options;
run_options.run_log_verbosity_level = so.session_log_verbosity_level;
run_options.run_tag = so.session_logid;
RunModelWithBindingMatMul(session_object,
run_options,
bind_provider_type,
preallocate_output,
allocation_provider);
allocation_provider,
output_device);
}
TEST(InferenceSessionTests, TestBindCpu) {
@ -938,6 +950,17 @@ TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu2) {
kCpuExecutionProvider);
}
TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) {
OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0);
TestBindHelper("TestBindCudaPreallocateOutputOnCuda",
kCudaExecutionProvider,
kCudaExecutionProvider,
false /* preallocate output on GPU */,
kCudaExecutionProvider,
&device /* specify output device */);
}
#endif
TEST(InferenceSessionTests, ModelWithoutOpset) {
@ -1814,7 +1837,7 @@ TEST(InferenceSessionTests, TestCopyToFromDevices) {
RunOptions run_options;
run_options.run_tag = "run:" + std::to_string(run_num);
common::Status st = session_object.Run(run_options, feed_names, feeds, output_names, &fetches);
common::Status st = session_object.Run(run_options, feed_names, feeds, output_names, &fetches, nullptr);
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
VerifyOutputs(fetches, expected_dims_mul_y, expected_values_mul_y);