mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-19 02:03:52 +00:00
Add ability to specify just the device when using IOBinding for an output (#4386)
* Add ability to specify just the device when using IOBinding for an output. This enables keeping an output on a different device GPU when it has a dynamic size that is not known ahead of graph execution.
This commit is contained in:
parent
28e4c0edf5
commit
d22f6fddf7
7 changed files with 77 additions and 31 deletions
|
|
@ -78,33 +78,36 @@ common::Status IOBinding::SynchronizeOutputs() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status IOBinding::BindOutput(const std::string& name, const OrtValue& ml_value) {
|
||||
common::Status IOBinding::BindOutput(const std::string& name, const OrtValue& ml_value, OrtDevice device) {
|
||||
auto rc = Contains(output_names_, name);
|
||||
if (rc.first) {
|
||||
outputs_[rc.second] = ml_value;
|
||||
return Status::OK();
|
||||
outputs_device_info_[rc.second] = device;
|
||||
} else {
|
||||
output_names_.push_back(name);
|
||||
outputs_.push_back(ml_value);
|
||||
outputs_device_info_.push_back(device);
|
||||
}
|
||||
|
||||
output_names_.push_back(name);
|
||||
outputs_.push_back(ml_value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void IOBinding::ClearOutputs() {
|
||||
output_names_.clear();
|
||||
outputs_.clear();
|
||||
outputs_device_info_.clear();
|
||||
}
|
||||
|
||||
const std::vector<std::string>& IOBinding::GetOutputNames() const {
|
||||
return output_names_;
|
||||
}
|
||||
const std::vector<std::string>& IOBinding::GetOutputNames() const { return output_names_; }
|
||||
|
||||
std::vector<OrtValue>& IOBinding::GetOutputs() { return outputs_; }
|
||||
|
||||
const std::vector<std::string>& IOBinding::GetInputNames() const {
|
||||
return feed_names_;
|
||||
const std::vector<OrtDevice>& IOBinding::GetOutputsDeviceInfo() const {
|
||||
return outputs_device_info_;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& IOBinding::GetInputNames() const { return feed_names_; }
|
||||
|
||||
const std::vector<OrtValue>& IOBinding::GetInputs() const { return feeds_; }
|
||||
|
||||
AllocatorPtr IOBinding::GetCPUAllocator(int id, onnxruntime::ProviderType provider_type) const {
|
||||
|
|
|
|||
|
|
@ -58,10 +58,15 @@ class IOBinding {
|
|||
*/
|
||||
common::Status SynchronizeInputs();
|
||||
common::Status SynchronizeOutputs();
|
||||
|
||||
/**
|
||||
* This simply provides the names and optionally allocated output containers.
|
||||
* Bind an output name to a provided OrtValue.
|
||||
* If the output is pre-allocated, the value in 'device' is not used.
|
||||
* If the output is not pre-allocated, information on the device it should be allocated on can be provided.
|
||||
*
|
||||
* @param device Device to allocate on if not pre-allocated. Default is CPU.
|
||||
*/
|
||||
common::Status BindOutput(const std::string& name, const OrtValue& ml_value);
|
||||
common::Status BindOutput(const std::string& name, const OrtValue& ml_value, OrtDevice device = {});
|
||||
|
||||
/**
|
||||
* This simply collects the outputs obtained after calling Run() inside the @param outputs.
|
||||
|
|
@ -69,6 +74,9 @@ class IOBinding {
|
|||
const std::vector<std::string>& GetOutputNames() const;
|
||||
std::vector<OrtValue>& GetOutputs();
|
||||
|
||||
// device info for outputs that are not pre-allocated
|
||||
const std::vector<OrtDevice>& GetOutputsDeviceInfo() const;
|
||||
|
||||
const std::vector<std::string>& GetInputNames() const;
|
||||
const std::vector<OrtValue>& GetInputs() const;
|
||||
|
||||
|
|
@ -94,6 +102,7 @@ class IOBinding {
|
|||
std::vector<OrtValue> feeds_;
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<OrtValue> outputs_;
|
||||
std::vector<OrtDevice> outputs_device_info_;
|
||||
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(IOBinding);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1131,9 +1131,10 @@ common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>&
|
|||
return common::Status::OK();
|
||||
}
|
||||
|
||||
Status InferenceSession::Run(const RunOptions& run_options, const std::vector<std::string>& feed_names,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<std::string>& output_names,
|
||||
std::vector<OrtValue>* p_fetches) {
|
||||
Status InferenceSession::Run(const RunOptions& run_options,
|
||||
const std::vector<std::string>& feed_names, const std::vector<OrtValue>& feeds,
|
||||
const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches,
|
||||
const std::vector<OrtDevice>* p_fetches_device_info) {
|
||||
TimePoint tp;
|
||||
if (session_profiler_.IsEnabled()) {
|
||||
tp = session_profiler_.StartTime();
|
||||
|
|
@ -1171,6 +1172,16 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
|
|||
FeedsFetchesInfo info(feed_names, output_names, session_state_->GetOrtValueNameIdxMap());
|
||||
FeedsFetchesManager feeds_fetches_manager{std::move(info)};
|
||||
|
||||
if (p_fetches_device_info) {
|
||||
// populate the target device info. ignored if pre-allocated fetches are provided
|
||||
const auto& fetch_device_info = *p_fetches_device_info;
|
||||
auto& fetch_info = feeds_fetches_manager.GetMutableFetchesDeviceCopyInfo();
|
||||
|
||||
for (size_t i = 0, end = output_names.size(); i < end; ++i) {
|
||||
fetch_info[i].target_device = fetch_device_info[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (!run_options.run_tag.empty()) {
|
||||
LOGS(*session_logger_, INFO) << "Running with tag: " << run_options.run_tag;
|
||||
}
|
||||
|
|
@ -1206,7 +1217,6 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
|
|||
ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
|
||||
session_options_.execution_mode, run_options.terminate, run_logger,
|
||||
run_options.only_execute_path_to_fetches));
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
retval = Status(common::ONNXRUNTIME, common::FAIL, e.what());
|
||||
} catch (...) {
|
||||
|
|
@ -1270,7 +1280,7 @@ common::Status InferenceSession::Run(const RunOptions& run_options, const NameML
|
|||
feeds.push_back(pair.second);
|
||||
}
|
||||
|
||||
return Run(run_options, feed_names, feeds, output_names, p_fetches);
|
||||
return Run(run_options, feed_names, feeds, output_names, p_fetches, nullptr);
|
||||
}
|
||||
|
||||
std::pair<common::Status, const ModelMetadata*> InferenceSession::GetModelMetadata() const {
|
||||
|
|
@ -1341,7 +1351,7 @@ common::Status InferenceSession::Run(const RunOptions& run_options, IOBinding& i
|
|||
// TODO should Run() call io_binding.SynchronizeInputs() or should it let the callers do it?
|
||||
// io_binding.SynchronizeInputs();
|
||||
return Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(),
|
||||
&io_binding.GetOutputs());
|
||||
&io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo());
|
||||
}
|
||||
|
||||
common::Status InferenceSession::Run(IOBinding& io_binding) {
|
||||
|
|
|
|||
|
|
@ -238,7 +238,8 @@ class InferenceSession {
|
|||
|
||||
common::Status Run(const RunOptions& run_options, const std::vector<std::string>& feed_names,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<std::string>& output_names,
|
||||
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
|
||||
std::vector<OrtValue>* p_fetches,
|
||||
const std::vector<OrtDevice>* p_fetches_device_info = nullptr) ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
* Run a pre-loaded and pre-intialized model.
|
||||
|
|
|
|||
|
|
@ -503,9 +503,9 @@ ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRu
|
|||
Status status;
|
||||
if (run_options == nullptr) {
|
||||
OrtRunOptions op;
|
||||
status = session->Run(op, feed_names, feeds, output_names, &fetches);
|
||||
status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr);
|
||||
} else {
|
||||
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches);
|
||||
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr);
|
||||
}
|
||||
|
||||
if (!status.IsOK())
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
|
|||
results[0].Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(),
|
||||
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
|
||||
RunOptions ro;
|
||||
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results));
|
||||
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
|
||||
|
||||
EXPECT_EQ(results[0].Get<Tensor>().DataRaw(), orig_buffer);
|
||||
EXPECT_THAT(results[0].Get<Tensor>().DataAsSpan<float>(), ::testing::ContainerEq(gsl::make_span(expected)));
|
||||
|
|
@ -361,7 +361,7 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
|
|||
|
||||
std::vector<OrtValue> results;
|
||||
RunOptions ro;
|
||||
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results));
|
||||
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
|
||||
|
||||
// output buffer should not be the same as the initializer in SessionState
|
||||
const auto& initializers = session.GetSessionState().GetInitializedTensors();
|
||||
|
|
|
|||
|
|
@ -24,12 +24,14 @@
|
|||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/gpu_data_transfer.h"
|
||||
#endif
|
||||
#include "core/session/environment.h"
|
||||
#include "core/session/IOBinding.h"
|
||||
#include "dummy_provider.h"
|
||||
#include "test_utils.h"
|
||||
|
|
@ -37,9 +39,9 @@
|
|||
#include "test/test_environment.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/optimizer/dummy_graph_transformer.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "core/session/environment.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
|
@ -246,7 +248,8 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
|
|||
const RunOptions& run_options,
|
||||
ProviderType bind_provider_type,
|
||||
bool is_preallocate_output_vec,
|
||||
ProviderType allocation_provider) {
|
||||
ProviderType allocation_provider,
|
||||
OrtDevice* output_device) {
|
||||
unique_ptr<IOBinding> io_binding;
|
||||
Status st = session_object.NewIOBinding(&io_binding);
|
||||
ASSERT_TRUE(st.IsOK());
|
||||
|
|
@ -303,7 +306,13 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
|
|||
}
|
||||
}
|
||||
|
||||
io_binding->BindOutput("Y", output_ml_value);
|
||||
if (output_device) {
|
||||
// output should be allocated on specified device (if not preallocated here)
|
||||
io_binding->BindOutput("Y", output_ml_value, *output_device);
|
||||
} else {
|
||||
io_binding->BindOutput("Y", output_ml_value);
|
||||
}
|
||||
|
||||
ASSERT_TRUE(io_binding->SynchronizeInputs().IsOK());
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
|
|
@ -315,8 +324,8 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
|
|||
std::cout << "Run returned status: " << st.ErrorMessage() << std::endl;
|
||||
ASSERT_TRUE(st.IsOK());
|
||||
|
||||
if (is_preallocate_output_vec &&
|
||||
allocation_provider == kCudaExecutionProvider) {
|
||||
if ((is_preallocate_output_vec && allocation_provider == kCudaExecutionProvider) ||
|
||||
(output_device && output_device->Type() == OrtDevice::GPU)) {
|
||||
#ifdef USE_CUDA
|
||||
// in this case we need to copy the tensor from cuda to cpu
|
||||
vector<OrtValue>& outputs = io_binding->GetOutputs();
|
||||
|
|
@ -787,7 +796,8 @@ static void TestBindHelper(const std::string& log_str,
|
|||
ProviderType bind_provider_type,
|
||||
ProviderType run_provider_type,
|
||||
bool preallocate_output,
|
||||
ProviderType allocation_provider = kCpuExecutionProvider) {
|
||||
ProviderType allocation_provider = kCpuExecutionProvider,
|
||||
OrtDevice* output_device = nullptr) {
|
||||
SessionOptions so;
|
||||
|
||||
so.session_logid = "InferenceSessionTests." + log_str;
|
||||
|
|
@ -815,11 +825,13 @@ static void TestBindHelper(const std::string& log_str,
|
|||
RunOptions run_options;
|
||||
run_options.run_log_verbosity_level = so.session_log_verbosity_level;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
||||
RunModelWithBindingMatMul(session_object,
|
||||
run_options,
|
||||
bind_provider_type,
|
||||
preallocate_output,
|
||||
allocation_provider);
|
||||
allocation_provider,
|
||||
output_device);
|
||||
}
|
||||
|
||||
TEST(InferenceSessionTests, TestBindCpu) {
|
||||
|
|
@ -938,6 +950,17 @@ TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu2) {
|
|||
kCpuExecutionProvider);
|
||||
}
|
||||
|
||||
TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) {
|
||||
OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0);
|
||||
|
||||
TestBindHelper("TestBindCudaPreallocateOutputOnCuda",
|
||||
kCudaExecutionProvider,
|
||||
kCudaExecutionProvider,
|
||||
false /* preallocate output on GPU */,
|
||||
kCudaExecutionProvider,
|
||||
&device /* specify output device */);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TEST(InferenceSessionTests, ModelWithoutOpset) {
|
||||
|
|
@ -1814,7 +1837,7 @@ TEST(InferenceSessionTests, TestCopyToFromDevices) {
|
|||
RunOptions run_options;
|
||||
run_options.run_tag = "run:" + std::to_string(run_num);
|
||||
|
||||
common::Status st = session_object.Run(run_options, feed_names, feeds, output_names, &fetches);
|
||||
common::Status st = session_object.Run(run_options, feed_names, feeds, output_names, &fetches, nullptr);
|
||||
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
|
||||
|
||||
VerifyOutputs(fetches, expected_dims_mul_y, expected_values_mul_y);
|
||||
|
|
|
|||
Loading…
Reference in a new issue