mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
649 lines
25 KiB
C++
649 lines
25 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "core/framework/utils.h"
|
|
|
|
#include <iomanip>
|
|
|
|
#include "core/graph/graph_viewer.h"
|
|
#include "core/framework/data_transfer_manager.h"
|
|
#include "core/framework/execution_frame.h"
|
|
#include "core/framework/execution_providers.h"
|
|
#include "core/framework/feeds_fetches_manager.h"
|
|
#include "core/framework/kernel_def_builder.h"
|
|
#include "core/framework/kernel_registry_manager.h"
|
|
#include "core/framework/op_kernel_context_internal.h"
|
|
#include "core/framework/parallel_executor.h"
|
|
#include "core/framework/session_state.h"
|
|
#include "core/framework/sequential_executor.h"
|
|
|
|
namespace onnxruntime {
|
|
namespace utils {
|
|
AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorInfo& allocator_info) {
|
|
return session_state.GetExecutionProviders().GetAllocator(allocator_info);
|
|
}
|
|
|
|
bool ProviderIsCpuBased(const std::string& provider_type) {
|
|
return provider_type == onnxruntime::kCpuExecutionProvider ||
|
|
provider_type == onnxruntime::kMklDnnExecutionProvider ||
|
|
provider_type == onnxruntime::kNGraphExecutionProvider ||
|
|
provider_type == onnxruntime::kNupharExecutionProvider ||
|
|
provider_type == onnxruntime::kOpenVINOExecutionProvider ||
|
|
provider_type == onnxruntime::kNnapiExecutionProvider;
|
|
}
|
|
|
|
common::Status AllocateHelper(const IExecutionProvider& execution_provider, const OrtDevice& device, const Tensor& fetched_tensor,
|
|
OrtValue& output_mlvalue) {
|
|
auto allocator = execution_provider.GetAllocator(device.Id(), OrtMemTypeDefault);
|
|
if (!allocator) {
|
|
return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator");
|
|
}
|
|
|
|
std::unique_ptr<Tensor> p_tensor = std::make_unique<Tensor>(fetched_tensor.DataType(),
|
|
fetched_tensor.Shape(),
|
|
allocator);
|
|
output_mlvalue.Init(p_tensor.release(),
|
|
DataTypeImpl::GetType<Tensor>(),
|
|
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info) {
|
|
// the input index will be std::numeric_limits<size_t>::max() if it's an implicit input to a control flow node.
|
|
// the input will be processed fully when executing the subgraph that consumes the implicit input.
|
|
bool implicit_input = info.index == std::numeric_limits<size_t>::max();
|
|
|
|
// node may declare input_mem_type to be on CPU explicitly
|
|
// skip implicit inputs as they don't have a valid 'index' value
|
|
bool node_input_on_cpu = !implicit_input && info.kci && info.kci->kernel_def->IsInputOnCpu(info.index);
|
|
|
|
// need a std::string that doesn't go away for kCpuExecutionProvider so we can return a reference.
|
|
static const std::string cpu_execution_provider{onnxruntime::kCpuExecutionProvider};
|
|
|
|
auto& required_provider_type = node_input_on_cpu ? cpu_execution_provider
|
|
: info.p_node->GetExecutionProviderType();
|
|
|
|
return required_provider_type;
|
|
}
|
|
|
|
static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
|
|
const FeedsFetchesManager::MLValueCopyInfo& copy_info,
|
|
const OrtValue& source_mlvalue,
|
|
OrtValue& target_mlvalue) {
|
|
if (copy_info.allocation_provider == nullptr) {
|
|
target_mlvalue = source_mlvalue;
|
|
return Status::OK();
|
|
}
|
|
|
|
auto& source_tensor = source_mlvalue.Get<Tensor>();
|
|
if (!target_mlvalue.IsAllocated()) {
|
|
ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device,
|
|
source_tensor, target_mlvalue));
|
|
}
|
|
|
|
Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
|
|
|
|
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
// TODO should we handle the case of one input name feeding 2 nodes placed on different devices?
|
|
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
|
|
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue, bool& needed_copy,
|
|
FeedsFetchesManager::MLValueCopyInfo& copy_info) {
|
|
needed_copy = false;
|
|
|
|
std::vector<SessionState::NodeInfo> node_info_vec;
|
|
ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));
|
|
|
|
auto& exec_providers = session_state.GetExecutionProviders();
|
|
|
|
do {
|
|
// currently we only support one device per input. see SessionState::AddInputNameToNodeInfoMapping for more
|
|
// info on the logic to create the node_info_vec.
|
|
// for (auto& node_info : node_info_vec) {
|
|
auto& node_info = node_info_vec.front();
|
|
if (node_info.p_node == nullptr) {
|
|
// dummy entry for an input that we didn't find a use of in the graph.
|
|
// use the input as is given we don't believe it's actually needed.
|
|
new_mlvalue = orig_mlvalue;
|
|
break;
|
|
}
|
|
|
|
if (!orig_mlvalue.IsTensor()) {
|
|
// copying not supported for non-tensor types
|
|
new_mlvalue = orig_mlvalue;
|
|
break;
|
|
}
|
|
|
|
auto& required_device = *node_info.device;
|
|
auto& input_tensor_device = orig_mlvalue.Get<Tensor>().Location().device;
|
|
if (required_device == input_tensor_device) {
|
|
// No copy needed for same device.
|
|
new_mlvalue = orig_mlvalue;
|
|
break;
|
|
}
|
|
|
|
auto& required_provider_type = GetNodeInputProviderType(node_info);
|
|
auto* required_provider = exec_providers.Get(required_provider_type);
|
|
copy_info.target_device = required_device;
|
|
copy_info.allocation_provider = required_provider;
|
|
|
|
ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue));
|
|
|
|
needed_copy = true;
|
|
|
|
} while (false);
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
|
|
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue) {
|
|
bool needed_copy;
|
|
FeedsFetchesManager::MLValueCopyInfo ignored;
|
|
return CopyOneInputAcrossDevices(session_state, input_name, orig_mlvalue, new_mlvalue, needed_copy, ignored);
|
|
}
|
|
|
|
// copies inputs across devices only if required and save copy_info
|
|
static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
|
|
const std::vector<std::string>& feed_names,
|
|
const std::vector<OrtValue>& orig_feeds, std::vector<OrtValue>& new_feeds,
|
|
bool& needed_copy,
|
|
std::vector<FeedsFetchesManager::MLValueCopyInfo>* copy_info) {
|
|
bool copied = false;
|
|
size_t num_feeds = orig_feeds.size();
|
|
ORT_ENFORCE(feed_names.size() == num_feeds);
|
|
|
|
new_feeds.resize(num_feeds);
|
|
if (copy_info) {
|
|
copy_info->resize(num_feeds);
|
|
}
|
|
|
|
for (size_t idx = 0; idx < num_feeds; ++idx) {
|
|
bool copied_this_input = false;
|
|
FeedsFetchesManager::MLValueCopyInfo current_copy_info = {}; // init for each call
|
|
ORT_RETURN_IF_ERROR(CopyOneInputAcrossDevices(session_state, feed_names[idx], orig_feeds[idx], new_feeds[idx],
|
|
copied_this_input, current_copy_info));
|
|
|
|
if (copied_this_input) {
|
|
copied = true;
|
|
|
|
if (copy_info) {
|
|
(*copy_info)[idx] = current_copy_info;
|
|
}
|
|
}
|
|
}
|
|
|
|
needed_copy = copied;
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
// copies inputs across devices only if required using cached copy_info
|
|
static common::Status CachedCopyInputsAcrossDevices(
|
|
const std::vector<OrtValue>& orig_feeds, std::vector<OrtValue>& new_feeds,
|
|
const std::vector<FeedsFetchesManager::MLValueCopyInfo>& copy_info,
|
|
const DataTransferManager& data_transfer_mgr) {
|
|
size_t num_feeds = orig_feeds.size();
|
|
ORT_ENFORCE(copy_info.size() == num_feeds);
|
|
|
|
new_feeds.resize(num_feeds);
|
|
|
|
for (size_t idx = 0; idx < num_feeds; ++idx) {
|
|
ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], orig_feeds[idx], new_feeds[idx]));
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
// Setup fetches for execution. Use any provided fetches directly if the provider matches.
|
|
// If the provider doesn't match, we don't know what device the execution output may be on, so can't assume the output
|
|
// can be returned to the user directly.
|
|
static common::Status SetupFetchesForExecute(const SessionState& session_state,
|
|
const std::vector<std::string>& output_names,
|
|
std::vector<OrtValue>& fetches, std::vector<OrtValue>& new_fetches,
|
|
std::vector<bool>* copy_to_new_fetches_cached_values) {
|
|
ORT_ENFORCE(new_fetches.empty());
|
|
auto num_outputs = output_names.size();
|
|
new_fetches.resize(num_outputs);
|
|
|
|
const auto& name_to_id = session_state.GetOrtValueNameIdxMap();
|
|
const auto* exec_plan = session_state.GetExecutionPlan();
|
|
// track which fetches can be copied to new_fetches and used directly in the execution.
|
|
std::vector<bool> local_can_copy_flags(num_outputs, false);
|
|
|
|
std::set<std::string> seen_outputs;
|
|
auto p_graph = session_state.GetGraphViewer();
|
|
ORT_ENFORCE(p_graph);
|
|
|
|
auto contains = [](const std::vector<std::string>& output_names,
|
|
const std::string& name) {
|
|
auto it = std::find(std::begin(output_names), std::end(output_names), name);
|
|
if (it == output_names.end()) {
|
|
return std::make_pair(false, size_t(0));
|
|
}
|
|
|
|
return std::pair<bool, size_t>(true, it - output_names.begin());
|
|
};
|
|
|
|
std::pair<bool, size_t> found;
|
|
for (auto& node : p_graph->Nodes()) {
|
|
if (seen_outputs.size() == num_outputs) {
|
|
break;
|
|
}
|
|
|
|
for (auto* arg : node.OutputDefs()) {
|
|
if (!arg->Exists() ||
|
|
!(found = contains(output_names, arg->Name())).first) {
|
|
continue;
|
|
}
|
|
|
|
seen_outputs.insert(arg->Name());
|
|
size_t idx = found.second;
|
|
const OrtValue& provided_mlvalue = fetches[idx];
|
|
|
|
if (provided_mlvalue.IsAllocated()) {
|
|
if (!provided_mlvalue.IsTensor()) {
|
|
new_fetches[idx] = fetches[idx];
|
|
local_can_copy_flags[idx] = true;
|
|
continue;
|
|
}
|
|
|
|
int arg_index;
|
|
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg->Name(), arg_index));
|
|
const auto& planned_device = exec_plan->GetLocation(arg_index).device;
|
|
const auto& provided_tensor_device = provided_mlvalue.Get<Tensor>().Location().device;
|
|
|
|
if (planned_device == provided_tensor_device) {
|
|
new_fetches[idx] = fetches[idx];
|
|
local_can_copy_flags[idx] = true;
|
|
continue;
|
|
}
|
|
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (copy_to_new_fetches_cached_values) {
|
|
*copy_to_new_fetches_cached_values = local_can_copy_flags;
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
static common::Status CachedSetupFetchesForExecute(std::vector<OrtValue>& fetches, std::vector<OrtValue>& new_fetches,
|
|
const std::vector<bool>& copy_to_new_fetches_cached_values) {
|
|
auto num_outputs = fetches.size();
|
|
ORT_ENFORCE(new_fetches.empty());
|
|
ORT_ENFORCE(copy_to_new_fetches_cached_values.size() == num_outputs);
|
|
|
|
new_fetches.resize(num_outputs);
|
|
|
|
// use the cached values
|
|
for (size_t i = 0; i < num_outputs; ++i) {
|
|
if (copy_to_new_fetches_cached_values[i]) {
|
|
new_fetches[i] = fetches[i];
|
|
}
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
// copies outputs across devices only if required
|
|
static common::Status CopyOutputsAcrossDevices(const SessionState& session_state, const std::vector<OrtValue>& fetches,
|
|
std::vector<OrtValue>& user_fetches, bool& needed_copy,
|
|
std::vector<FeedsFetchesManager::MLValueCopyInfo>* copiers) {
|
|
needed_copy = false;
|
|
auto num_outputs = fetches.size();
|
|
|
|
if (copiers) {
|
|
// resize so we have default values and only need to update an entry if there's a device copy required.
|
|
copiers->resize(num_outputs);
|
|
}
|
|
|
|
auto& execution_providers = session_state.GetExecutionProviders();
|
|
|
|
// CPU execution provider is always registered so this is not null
|
|
const auto* cpu_execution_provider = execution_providers.Get(onnxruntime::kCpuExecutionProvider);
|
|
|
|
for (size_t idx = 0; idx < num_outputs; ++idx) {
|
|
auto& fetched_mlvalue = fetches[idx];
|
|
if (!fetched_mlvalue.IsTensor()) {
|
|
user_fetches[idx] = fetched_mlvalue;
|
|
continue;
|
|
}
|
|
|
|
const IExecutionProvider* p_output_provider = nullptr;
|
|
auto target_device = OrtDevice();
|
|
auto& output_mlvalue = user_fetches[idx];
|
|
if (output_mlvalue.IsAllocated()) {
|
|
Tensor* p_output_tensor = output_mlvalue.GetMutable<Tensor>();
|
|
target_device = p_output_tensor->Location().device;
|
|
p_output_provider = execution_providers.Get(p_output_tensor->Location());
|
|
}
|
|
auto fetch_result_device = fetched_mlvalue.Get<Tensor>().Location().device;
|
|
if (target_device == fetch_result_device) {
|
|
user_fetches[idx] = fetched_mlvalue;
|
|
continue;
|
|
}
|
|
|
|
if (!p_output_provider) {
|
|
p_output_provider = cpu_execution_provider;
|
|
}
|
|
|
|
needed_copy = true;
|
|
FeedsFetchesManager::MLValueCopyInfo copy_info{target_device, p_output_provider};
|
|
ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, fetched_mlvalue, output_mlvalue));
|
|
|
|
if (copiers) {
|
|
(*copiers)[idx] = copy_info;
|
|
}
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
static common::Status CachedCopyOutputsAcrossDevices(
|
|
const std::vector<OrtValue>& fetches, std::vector<OrtValue>& user_fetches,
|
|
const std::vector<FeedsFetchesManager::MLValueCopyInfo>& copy_info,
|
|
const DataTransferManager& data_transfer_mgr) {
|
|
auto num_outputs = fetches.size();
|
|
|
|
// internal logic error if these are mismatched
|
|
ORT_ENFORCE(num_outputs == copy_info.size());
|
|
|
|
// used the cached copy logic if available
|
|
for (size_t idx = 0; idx < num_outputs; ++idx) {
|
|
ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], fetches[idx], user_fetches[idx]));
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
static DeviceCopyCheck CheckExecutionProviders(const ExecutionProviders& execution_providers) {
|
|
for (const auto& execution_provider : execution_providers) {
|
|
if (!ProviderIsCpuBased(execution_provider->Type())) {
|
|
return DeviceCopyCheck::Unknown;
|
|
}
|
|
}
|
|
|
|
return DeviceCopyCheck::NoCopy;
|
|
}
|
|
|
|
// execute graph with cached info from FeedsFetchesManager.
|
|
common::Status ExecuteGraphWithCachedInfo(
|
|
const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
|
|
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
|
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators, bool sequential_execution,
|
|
const bool& terminate_flag, const logging::Logger& logger) {
|
|
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
|
|
auto device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();
|
|
|
|
std::unique_ptr<IExecutor> p_exec;
|
|
if (sequential_execution) {
|
|
p_exec = std::unique_ptr<IExecutor>(new SequentialExecutor(terminate_flag));
|
|
} else {
|
|
p_exec = std::unique_ptr<IExecutor>(new ParallelExecutor(session_state, terminate_flag));
|
|
}
|
|
|
|
if (device_copy_checks.status == DeviceCopyCheck::NoCopy) {
|
|
// no device copies are needed so simple execute
|
|
ORT_RETURN_IF_ERROR(p_exec->Execute(session_state,
|
|
feeds_fetches_info.feeds_mlvalue_idxs, feeds,
|
|
feeds_fetches_info.fetches_mlvalue_idxs, fetches, fetch_allocators, logger));
|
|
} else {
|
|
const std::vector<OrtValue>* p_feeds = &feeds;
|
|
std::vector<OrtValue>* p_fetches = &fetches;
|
|
std::vector<OrtValue> device_feeds;
|
|
std::vector<OrtValue> device_fetches;
|
|
|
|
// Copy inputs
|
|
if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) {
|
|
ORT_RETURN_IF_ERROR(CachedCopyInputsAcrossDevices(feeds, device_feeds,
|
|
feeds_fetches_manager.GetFeedsDeviceCopiers(),
|
|
session_state.GetDataTransferMgr()));
|
|
p_feeds = &device_feeds;
|
|
}
|
|
|
|
// setup fetches.
|
|
if (fetches.empty()) {
|
|
fetches.resize(feeds_fetches_info.output_names.size());
|
|
}
|
|
|
|
// if no output copy is needed, we can just use the fetches directly. otherwise we need to use a temporary set
|
|
// and run CopyOutputsAcrossDevices.
|
|
if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
|
|
ORT_RETURN_IF_ERROR(CachedSetupFetchesForExecute(fetches, device_fetches,
|
|
feeds_fetches_manager.GetCanUseFetchDuringExecutionFlags()));
|
|
p_fetches = &device_fetches;
|
|
}
|
|
|
|
ORT_RETURN_IF_ERROR(p_exec->Execute(session_state,
|
|
feeds_fetches_info.feeds_mlvalue_idxs, *p_feeds,
|
|
feeds_fetches_info.fetches_mlvalue_idxs, *p_fetches, fetch_allocators,
|
|
logger));
|
|
|
|
if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
|
|
ORT_RETURN_IF_ERROR(CachedCopyOutputsAcrossDevices(*p_fetches, fetches,
|
|
feeds_fetches_manager.GetFetchesDeviceCopiers(),
|
|
session_state.GetDataTransferMgr()));
|
|
}
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
// execute graph and update feeds_fetches_manager with cached copy info if cache_copy_info is true
|
|
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
|
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
|
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
|
|
bool sequential_execution, const bool& terminate_flag, const logging::Logger& logger,
|
|
bool cache_copy_info) {
|
|
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
|
|
auto device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();
|
|
|
|
ORT_ENFORCE(device_copy_checks.status == DeviceCopyCheck::Unknown);
|
|
|
|
std::unique_ptr<IExecutor> p_exec;
|
|
if (sequential_execution) {
|
|
p_exec = std::unique_ptr<IExecutor>(new SequentialExecutor(terminate_flag));
|
|
} else {
|
|
p_exec = std::unique_ptr<IExecutor>(new ParallelExecutor(session_state, terminate_flag));
|
|
}
|
|
|
|
// see if we can skip copies due to the types of execution providers available
|
|
if (CheckExecutionProviders(session_state.GetExecutionProviders()) == DeviceCopyCheck::NoCopy) {
|
|
device_copy_checks.input_copy_needed = DeviceCopyCheck::NoCopy;
|
|
device_copy_checks.output_copy_needed = DeviceCopyCheck::NoCopy;
|
|
|
|
// no device copies are needed so simple execute
|
|
ORT_RETURN_IF_ERROR(p_exec->Execute(session_state,
|
|
feeds_fetches_info.feeds_mlvalue_idxs, feeds,
|
|
feeds_fetches_info.fetches_mlvalue_idxs, fetches, fetch_allocators, logger));
|
|
} else {
|
|
bool copy_needed = false;
|
|
|
|
const std::vector<OrtValue>* p_feeds = &feeds;
|
|
std::vector<OrtValue>* p_fetches = &fetches;
|
|
std::vector<OrtValue> device_feeds;
|
|
std::vector<OrtValue> device_fetches;
|
|
|
|
// Copy inputs
|
|
auto* copiers = cache_copy_info ? &feeds_fetches_manager.GetMutableFeedsDeviceCopiers() : nullptr;
|
|
ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state,
|
|
feeds_fetches_info.feed_names, feeds, device_feeds,
|
|
copy_needed, copiers));
|
|
|
|
if (copy_needed) {
|
|
p_feeds = &device_feeds;
|
|
}
|
|
|
|
device_copy_checks.input_copy_needed = copy_needed ? DeviceCopyCheck::Copy
|
|
: DeviceCopyCheck::NoCopy;
|
|
|
|
// setup fetches.
|
|
if (fetches.empty()) {
|
|
fetches.resize(feeds_fetches_info.output_names.size());
|
|
}
|
|
|
|
auto* use_provided_fetch_flags =
|
|
cache_copy_info ? &feeds_fetches_manager.GetMutableCanUseFetchDuringExecutionFlags()
|
|
: nullptr;
|
|
|
|
ORT_RETURN_IF_ERROR(SetupFetchesForExecute(session_state, feeds_fetches_info.output_names,
|
|
fetches, device_fetches,
|
|
use_provided_fetch_flags));
|
|
p_fetches = &device_fetches;
|
|
|
|
ORT_RETURN_IF_ERROR(p_exec->Execute(session_state,
|
|
feeds_fetches_info.feeds_mlvalue_idxs, *p_feeds,
|
|
feeds_fetches_info.fetches_mlvalue_idxs, *p_fetches, fetch_allocators,
|
|
logger));
|
|
|
|
copiers = cache_copy_info ? &feeds_fetches_manager.GetMutableFetchesDeviceCopiers() : nullptr;
|
|
ORT_RETURN_IF_ERROR(CopyOutputsAcrossDevices(session_state, *p_fetches, fetches, copy_needed, copiers));
|
|
|
|
device_copy_checks.output_copy_needed = copy_needed ? DeviceCopyCheck::Copy : DeviceCopyCheck::NoCopy;
|
|
}
|
|
|
|
// save the result of all the checks and use cached info next time
|
|
if (cache_copy_info) {
|
|
feeds_fetches_manager.SetDeviceCopyChecks(device_copy_checks);
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
#if defined(DEBUG_NODE_INPUTS_OUTPUTS)
|
|
std::ostream& operator<<(std::ostream& out, const BFloat16& value) {
|
|
return out << value.ToFloat();
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const MLFloat16& value) {
|
|
return out << value.val;
|
|
}
|
|
|
|
template <typename T>
|
|
static void DumpTensor(const Tensor& tensor, const TensorShape& shape) {
|
|
auto num_items = shape.Size();
|
|
|
|
if (num_items == 0) {
|
|
std::cout << "no data";
|
|
return;
|
|
}
|
|
|
|
size_t num_dims = shape.NumDimensions();
|
|
size_t num_rows = 1;
|
|
if (num_dims > 1) {
|
|
num_rows = static_cast<size_t>(shape[0]);
|
|
}
|
|
|
|
size_t row_size = num_items / num_rows;
|
|
|
|
auto data = tensor.DataAsSpan<T>();
|
|
|
|
auto print_val = [](const T& value) {
|
|
if (std::is_floating_point_v<T>)
|
|
std::cout << std::setprecision(8) << value;
|
|
else
|
|
std::cout << value;
|
|
};
|
|
|
|
for (int row = 0; row < num_rows; ++row) {
|
|
print_val(data[row * row_size]);
|
|
for (int i = 1; i < row_size; ++i) {
|
|
std::cout << ", ";
|
|
print_val(data[row * row_size + i]);
|
|
}
|
|
std::cout << "\n";
|
|
}
|
|
|
|
std::cout << std::endl;
|
|
}
|
|
|
|
void DumpNodeInputs(const OpKernelContext& context, const Node& node) {
|
|
std::cout << "-----------\n";
|
|
std::cout << node.OpType() << " node: " << node.Name() << "\n";
|
|
|
|
const auto& input_defs = node.InputDefs();
|
|
|
|
for (auto i = 0, end = context.InputCount(); i < end; ++i) {
|
|
if (input_defs[i]->Exists()) {
|
|
std::cout << "Input " << i << " Name: " << input_defs[i]->Name();
|
|
|
|
const auto* type = context.InputType(i);
|
|
|
|
if (type) {
|
|
if (type->IsTensorType()) {
|
|
const auto& tensor = *context.Input<Tensor>(i);
|
|
const auto& shape = tensor.Shape();
|
|
|
|
std::cout << " Shape: " << shape << "\n";
|
|
} else {
|
|
std::cout << " is non-tensor type.\n";
|
|
}
|
|
} else {
|
|
// should never happen...
|
|
std::cout << " was missing data type\n";
|
|
}
|
|
} else {
|
|
std::cout << "Input " << i << " is optional and was not provided.\n";
|
|
}
|
|
}
|
|
}
|
|
|
|
void DumpNodeOutputs(OpKernelContext& context, const Node& node, const SessionState& session_state) {
|
|
std::cout << "-----------\n";
|
|
const auto& output_defs = node.OutputDefs();
|
|
|
|
const auto& execution_providers = session_state.GetExecutionProviders();
|
|
const auto* cpu_execution_provider = execution_providers.Get(onnxruntime::kCpuExecutionProvider);
|
|
|
|
for (auto i = 0, end = context.OutputCount(); i < end; ++i) {
|
|
if (output_defs[i]->Exists()) {
|
|
std::cout << "Output " << i << " Name: " << output_defs[i]->Name();
|
|
|
|
const auto* type = context.OutputType(i);
|
|
|
|
if (type) {
|
|
if (type->IsTensorType()) {
|
|
const auto& tensor = *context.Output<Tensor>(i);
|
|
const auto data_type = tensor.DataType();
|
|
const auto& shape = tensor.Shape();
|
|
|
|
std::cout << " Shape: " << shape << "\n";
|
|
|
|
// check tensor is on CPU before dumping it
|
|
auto& tensor_location = tensor.Location();
|
|
auto* provider = execution_providers.Get(tensor_location);
|
|
if (!provider) {
|
|
provider = cpu_execution_provider;
|
|
}
|
|
|
|
if (provider == cpu_execution_provider || tensor_location.mem_type == OrtMemTypeCPUOutput) {
|
|
DispatchOnTensorType(data_type, DumpTensor, tensor, shape);
|
|
} else {
|
|
std::cout << " is not on CPU. Provider=" << provider->Type() << "\n";
|
|
}
|
|
} else {
|
|
std::cout << " is non-tensor type.\n";
|
|
}
|
|
} else {
|
|
// should never happen...
|
|
std::cout << "missing data type\n";
|
|
}
|
|
} else {
|
|
std::cout << "Output " << i << " is optional and was not produced.\n";
|
|
}
|
|
|
|
std::cout << std::endl;
|
|
}
|
|
}
|
|
#endif
|
|
|
|
} // namespace utils
|
|
} // namespace onnxruntime
|