mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Kezhan/execute graph refactoring (#1553)
* checking execution provider logic updated. * fix the logic of copy input and output. * update * update * update * update * update * update * fix ngraph failure. * fix comments
This commit is contained in:
parent
b405482cfa
commit
bd64ca3019
9 changed files with 102 additions and 116 deletions
|
|
@ -72,6 +72,10 @@ struct OrtDevice {
|
|||
DeviceId device_id;
|
||||
};
|
||||
|
||||
inline bool operator==(const OrtDevice& left, const OrtDevice& other) {
|
||||
return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type();
|
||||
}
|
||||
|
||||
struct OrtAllocatorInfo {
|
||||
// use string for name, so we could have customized allocator in execution provider.
|
||||
const char* name;
|
||||
|
|
|
|||
|
|
@ -347,6 +347,10 @@ class PlannerImpl {
|
|||
|
||||
Status ComputeUseCounts() {
|
||||
// Note: for every ml-value, its definition must appear before all its uses in a topological sort of a valid model
|
||||
std::unordered_set<std::string> graph_inputs;
|
||||
for (auto& graph_input : graph_viewer_.GetInputsIncludingInitializers()) {
|
||||
graph_inputs.insert(graph_input->Name());
|
||||
}
|
||||
|
||||
for (auto graph_input : graph_viewer_.GetInputs()) {
|
||||
OrtValueIndex index = Index(graph_input->Name());
|
||||
|
|
@ -371,15 +375,7 @@ class PlannerImpl {
|
|||
for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
|
||||
auto pnode = graph_viewer_.GetNode(step.node_index);
|
||||
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);
|
||||
for (auto node_input : pnode->InputDefs()) {
|
||||
if (node_input->Exists())
|
||||
UseCount(node_input->Name())++;
|
||||
}
|
||||
|
||||
for (auto node_input : pnode->ImplicitInputDefs()) {
|
||||
if (node_input->Exists())
|
||||
UseCount(node_input->Name())++;
|
||||
}
|
||||
// Identify where each output of this node should be allocated.
|
||||
// This is determined by the opkernel bound to the node.
|
||||
const KernelCreateInfo* kernel_create_info = nullptr;
|
||||
|
|
@ -394,31 +390,45 @@ class PlannerImpl {
|
|||
if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")";
|
||||
return Status(ONNXRUNTIME, FAIL, errormsg.str());
|
||||
}
|
||||
|
||||
auto exec_provider = execution_providers_.Get(*pnode);
|
||||
if (exec_provider == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the execution provider ",
|
||||
pnode->GetExecutionProviderType());
|
||||
}
|
||||
|
||||
auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
|
||||
// increment UseCount and add location information if applicable for the provided input def
|
||||
auto process_input = [&graph_inputs, &exec_provider, &p_kernelDef, this](const NodeArg& input, size_t arg_idx) {
|
||||
const auto& name = input.Name();
|
||||
UseCount(name)++;
|
||||
|
||||
// If it's a graph input or outer scope node arg, set its plan.
|
||||
// NOTE: Copy nodes should have already been added if a graph input is fed as input
|
||||
// to nodes assigned to different providers.
|
||||
if (graph_inputs.find(name) != graph_inputs.cend() ||
|
||||
std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(),
|
||||
[&name](const NodeArg* value) {
|
||||
return value && value->Name() == name;
|
||||
}) != outer_scope_node_args_.cend()) {
|
||||
OrtValueIndex index = Index(name);
|
||||
plan_.SetLocation(static_cast<size_t>(index),
|
||||
exec_provider->GetAllocator(0, p_kernelDef->InputMemoryType(arg_idx))->Info());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->InputDefs(), process_input));
|
||||
ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->ImplicitInputDefs(), process_input));
|
||||
|
||||
auto outputs = pnode->OutputDefs();
|
||||
auto num_outputs = outputs.size();
|
||||
|
||||
for (size_t i = 0; i < num_outputs; ++i) {
|
||||
auto* node_output = outputs[i];
|
||||
if (!node_output->Exists()) continue;
|
||||
OrtValueIndex index = Index(node_output->Name());
|
||||
ProcessDef(index, node_output);
|
||||
++UseCount(index);
|
||||
if (strcmp(default_allocator_info.name, CPU) != 0) {
|
||||
// By default, outputs of this node are allocated on the default device allocator,
|
||||
// except for outputs marked for allocation in MemoryType:
|
||||
auto memory_type = p_kernelDef->OutputMemoryType(i);
|
||||
plan_.SetLocation(static_cast<size_t>(index), memory_type == OrtMemTypeDefault
|
||||
? default_allocator_info
|
||||
: exec_provider->GetAllocator(0, memory_type)->Info());
|
||||
}
|
||||
plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, p_kernelDef->OutputMemoryType(i))->Info());
|
||||
}
|
||||
// if sync is needed, mark allocation plan as create_fence_if_async=true
|
||||
// note that the input arg may come from an execution provider (i.e. CPU) that does not support async,
|
||||
|
|
|
|||
|
|
@ -48,9 +48,8 @@ struct FeedsFetchesInfo {
|
|||
class FeedsFetchesManager {
|
||||
public:
|
||||
struct MLValueCopyInfo {
|
||||
int allocation_device_id = 0;
|
||||
OrtDevice target_device;
|
||||
const IExecutionProvider* allocation_provider = nullptr;
|
||||
const IExecutionProvider* copy_provider = nullptr;
|
||||
};
|
||||
|
||||
static Status Create(const std::vector<std::string>& feed_names, const std::vector<std::string>& output_names,
|
||||
|
|
|
|||
|
|
@ -175,10 +175,6 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
|
|||
//prepare the func kernel
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, *node);
|
||||
if (node->GetExecutionProviderType() == onnxruntime::kNGraphExecutionProvider || node->GetExecutionProviderType() == onnxruntime::kNnapiExecutionProvider) {
|
||||
builder.SetDefaultInputsMemoryType(OrtMemTypeCPUInput);
|
||||
builder.SetDefaultOutputMemoryType(OrtMemTypeCPUOutput);
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(fused_kernel_registry->Register(
|
||||
builder, static_cast<KernelCreatePtrFn>([](const OpKernelInfo& info) -> OpKernel* { return new FunctionKernel(info); })));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -141,17 +141,18 @@ class SessionState {
|
|||
* \param p_node0 Nullable
|
||||
* \param kci0 Nullable
|
||||
*/
|
||||
NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0)
|
||||
NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0)
|
||||
: index(index0),
|
||||
p_node(p_node0),
|
||||
kci(kci0) {
|
||||
}
|
||||
kci(kci0),
|
||||
device(&device0) {}
|
||||
|
||||
size_t index;
|
||||
// Nullable
|
||||
const onnxruntime::Node* p_node = nullptr;
|
||||
// Nullable
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
const OrtDevice* device = nullptr;
|
||||
};
|
||||
|
||||
using NameNodeInfoMapType = std::unordered_map<std::string, std::vector<NodeInfo>>;
|
||||
|
|
|
|||
|
|
@ -351,6 +351,8 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
|
|||
if (implicit_inputs && implicit_inputs->empty()) {
|
||||
implicit_inputs = nullptr;
|
||||
}
|
||||
const auto* exec_plan = session_state.GetExecutionPlan();
|
||||
const auto& name_to_id = session_state.GetOrtValueNameIdxMap();
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
// note that KernelCreateInfo may not exist for custom kernel
|
||||
|
|
@ -365,7 +367,11 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
SessionState::NodeInfo node_info(index, &node, kci);
|
||||
int arg_index;
|
||||
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index));
|
||||
const auto& device = exec_plan->GetLocation(arg_index).device;
|
||||
|
||||
SessionState::NodeInfo node_info(index, &node, kci, device);
|
||||
|
||||
if (IsArgNameInInputsOutputs(arg.Name(), graph_inputs)) {
|
||||
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(arg.Name(), node_info));
|
||||
|
|
@ -397,8 +403,13 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
|
|||
// copy to/from CPU to go through the control flow nodes where possible/applicable.
|
||||
// the processing for the subgraph where the implicit input is consumed will do the real check on whether any
|
||||
// copy to a different device is required
|
||||
SessionState::NodeInfo node_info(std::numeric_limits<size_t>::max(), &node, kci);
|
||||
for (const auto& input_def : node_implicit_inputs) {
|
||||
int arg_index;
|
||||
//Question: the implicit input may not be found in this session state name to id map, but in parent session state name to id map.
|
||||
//@Scott
|
||||
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index));
|
||||
auto& device = exec_plan->GetLocation(arg_index).device;
|
||||
SessionState::NodeInfo node_info(std::numeric_limits<size_t>::max(), &node, kci, device);
|
||||
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(input_def->Name(), node_info));
|
||||
}
|
||||
}
|
||||
|
|
@ -413,7 +424,6 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
|
|||
|
||||
auto& input_map = session_state.GetInputNodeInfoMap();
|
||||
auto end_map = input_map.cend();
|
||||
SessionState::NodeInfo empty_node_info(std::numeric_limits<size_t>::max(), nullptr, nullptr);
|
||||
|
||||
for (const auto& graph_input : graph_inputs) {
|
||||
const auto& name = graph_input->Name();
|
||||
|
|
@ -422,6 +432,10 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
|
|||
// utils::CopyOneInputAcrossDevices will use the input OrtValue as is given we don't believe it's used anywhere.
|
||||
LOGS(session_state.Logger(), INFO) << (graph.IsSubgraph() ? "Subgraph" : "Graph") << " input with name "
|
||||
<< name << " is not used by any node.";
|
||||
int arg_index;
|
||||
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(name, arg_index));
|
||||
auto& device = exec_plan->GetLocation(arg_index).device;
|
||||
SessionState::NodeInfo empty_node_info(std::numeric_limits<size_t>::max(), nullptr, nullptr, device);
|
||||
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(name, empty_node_info));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,9 +23,18 @@ AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorI
|
|||
return session_state.GetExecutionProviders().GetAllocator(allocator_info);
|
||||
}
|
||||
|
||||
common::Status AllocateHelper(const IExecutionProvider& execution_provider, int device_id, const Tensor& fetched_tensor,
|
||||
bool ProviderIsCpuBased(const std::string& provider_type) {
|
||||
return provider_type == onnxruntime::kCpuExecutionProvider ||
|
||||
provider_type == onnxruntime::kMklDnnExecutionProvider ||
|
||||
provider_type == onnxruntime::kNGraphExecutionProvider ||
|
||||
provider_type == onnxruntime::kNupharExecutionProvider ||
|
||||
provider_type == onnxruntime::kOpenVINOExecutionProvider ||
|
||||
provider_type == onnxruntime::kNnapiExecutionProvider;
|
||||
}
|
||||
|
||||
common::Status AllocateHelper(const IExecutionProvider& execution_provider, const OrtDevice& device, const Tensor& fetched_tensor,
|
||||
OrtValue& output_mlvalue) {
|
||||
auto allocator = execution_provider.GetAllocator(device_id, OrtMemTypeDefault);
|
||||
auto allocator = execution_provider.GetAllocator(device.Id(), OrtMemTypeDefault);
|
||||
if (!allocator) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator");
|
||||
}
|
||||
|
|
@ -62,21 +71,21 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
|
|||
const FeedsFetchesManager::MLValueCopyInfo& copy_info,
|
||||
const OrtValue& source_mlvalue,
|
||||
OrtValue& target_mlvalue) {
|
||||
if (copy_info.copy_provider == nullptr) {
|
||||
if (copy_info.allocation_provider == nullptr){
|
||||
target_mlvalue = source_mlvalue;
|
||||
} else {
|
||||
auto& source_tensor = source_mlvalue.Get<Tensor>();
|
||||
|
||||
if (!target_mlvalue.IsAllocated()) {
|
||||
ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.allocation_device_id,
|
||||
source_tensor, target_mlvalue));
|
||||
}
|
||||
|
||||
Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
|
||||
|
||||
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto& source_tensor = source_mlvalue.Get<Tensor>();
|
||||
if (!target_mlvalue.IsAllocated()) {
|
||||
ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device,
|
||||
source_tensor, target_mlvalue));
|
||||
}
|
||||
|
||||
Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
|
||||
|
||||
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -86,8 +95,6 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
|
|||
FeedsFetchesManager::MLValueCopyInfo& copy_info) {
|
||||
needed_copy = false;
|
||||
|
||||
//TODO: make it configurable
|
||||
const int target_device_id = 0;
|
||||
std::vector<SessionState::NodeInfo> node_info_vec;
|
||||
ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));
|
||||
|
||||
|
|
@ -111,51 +118,23 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
|
|||
break;
|
||||
}
|
||||
|
||||
auto& required_device = *node_info.device;
|
||||
auto& input_tensor_device = orig_mlvalue.Get<Tensor>().Location().device;
|
||||
if (required_device == input_tensor_device) {
|
||||
// No copy needed for same device.
|
||||
new_mlvalue = orig_mlvalue;
|
||||
break;
|
||||
}
|
||||
|
||||
auto& required_provider_type = GetNodeInputProviderType(node_info);
|
||||
auto& input_tensor = orig_mlvalue.Get<Tensor>();
|
||||
auto& input_tensor_loc = input_tensor.Location();
|
||||
|
||||
auto* p_input_provider = exec_providers.Get(input_tensor_loc);
|
||||
if (!p_input_provider) {
|
||||
p_input_provider = exec_providers.Get(onnxruntime::kCpuExecutionProvider);
|
||||
ORT_ENFORCE(p_input_provider);
|
||||
}
|
||||
|
||||
//no copy for nGraph
|
||||
if (required_provider_type == onnxruntime::kNGraphExecutionProvider) {
|
||||
new_mlvalue = orig_mlvalue;
|
||||
break;
|
||||
}
|
||||
|
||||
auto input_provider_type = p_input_provider->Type();
|
||||
if (input_provider_type == required_provider_type && input_tensor_loc.mem_type == OrtMemTypeDefault) {
|
||||
new_mlvalue = orig_mlvalue;
|
||||
break;
|
||||
}
|
||||
|
||||
// If a node requires input on cpu and input tensor is allocated with pinned memory allocator, don't do copy
|
||||
if (required_provider_type == onnxruntime::kCpuExecutionProvider &&
|
||||
input_tensor_loc.mem_type == OrtMemTypeCPU) {
|
||||
new_mlvalue = orig_mlvalue;
|
||||
break;
|
||||
}
|
||||
|
||||
auto* required_provider = exec_providers.Get(required_provider_type);
|
||||
ORT_ENFORCE(required_provider);
|
||||
|
||||
auto* p_copy_provider = (required_provider_type != onnxruntime::kCpuExecutionProvider)
|
||||
? required_provider
|
||||
: p_input_provider;
|
||||
|
||||
copy_info.allocation_device_id = target_device_id;
|
||||
copy_info.target_device = required_device;
|
||||
copy_info.allocation_provider = required_provider;
|
||||
copy_info.copy_provider = p_copy_provider;
|
||||
|
||||
ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue));
|
||||
|
||||
needed_copy = true;
|
||||
|
||||
// } loop of node_info_vec
|
||||
} while (false);
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -344,43 +323,26 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
|
|||
continue;
|
||||
}
|
||||
|
||||
auto& fetched_tensor = fetched_mlvalue.Get<Tensor>();
|
||||
auto& fetched_tensor_location = fetched_tensor.Location();
|
||||
auto* p_fetched_provider = execution_providers.Get(fetched_tensor_location);
|
||||
if (!p_fetched_provider) {
|
||||
p_fetched_provider = cpu_execution_provider;
|
||||
}
|
||||
|
||||
auto fetched_provider_type = p_fetched_provider->Type();
|
||||
auto& output_mlvalue = user_fetches[idx];
|
||||
|
||||
const IExecutionProvider* p_output_provider = nullptr;
|
||||
|
||||
auto target_device = OrtDevice();
|
||||
auto& output_mlvalue = user_fetches[idx];
|
||||
if (output_mlvalue.IsAllocated()) {
|
||||
Tensor* p_output_tensor = output_mlvalue.GetMutable<Tensor>();
|
||||
target_device = p_output_tensor->Location().device;
|
||||
p_output_provider = execution_providers.Get(p_output_tensor->Location());
|
||||
}
|
||||
auto fetch_result_device = fetched_mlvalue.Get<Tensor>().Location().device;
|
||||
if (target_device == fetch_result_device) {
|
||||
user_fetches[idx] = fetched_mlvalue;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!p_output_provider) {
|
||||
p_output_provider = cpu_execution_provider;
|
||||
}
|
||||
|
||||
auto output_provider_type = p_output_provider->Type();
|
||||
|
||||
if (fetched_provider_type == output_provider_type ||
|
||||
(p_output_provider == cpu_execution_provider && fetched_tensor_location.mem_type == OrtMemTypeCPUOutput)) {
|
||||
user_fetches[idx] = fetched_mlvalue;
|
||||
continue;
|
||||
}
|
||||
|
||||
needed_copy = true;
|
||||
|
||||
auto* p_copy_provider = (fetched_provider_type != onnxruntime::kCpuExecutionProvider)
|
||||
? p_fetched_provider
|
||||
: p_output_provider;
|
||||
|
||||
const int device_id = 0; // TODO: As per comment in the copy input code, make this configurable.
|
||||
FeedsFetchesManager::MLValueCopyInfo copy_info{device_id, p_output_provider, p_copy_provider};
|
||||
FeedsFetchesManager::MLValueCopyInfo copy_info{target_device, p_output_provider};
|
||||
ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, fetched_mlvalue, output_mlvalue));
|
||||
|
||||
if (copiers) {
|
||||
|
|
@ -410,11 +372,7 @@ static common::Status CachedCopyOutputsAcrossDevices(
|
|||
|
||||
static DeviceCopyCheck CheckExecutionProviders(const ExecutionProviders& execution_providers) {
|
||||
for (const auto& execution_provider : execution_providers) {
|
||||
if (execution_provider->Type() != onnxruntime::kCpuExecutionProvider &&
|
||||
execution_provider->Type() != onnxruntime::kMklDnnExecutionProvider &&
|
||||
execution_provider->Type() != onnxruntime::kNGraphExecutionProvider &&
|
||||
execution_provider->Type() != onnxruntime::kNupharExecutionProvider &&
|
||||
execution_provider->Type() != onnxruntime::kOpenVINOExecutionProvider) {
|
||||
if (!ProviderIsCpuBased(execution_provider->Type())) {
|
||||
return DeviceCopyCheck::Unknown;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -72,6 +72,11 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in
|
|||
DeviceAllocatorRegistrationInfo pinned_allocator_info(
|
||||
{OrtMemTypeCPUOutput, [](int) { return std::make_unique<CUDAPinnedAllocator>(0, CUDA_PINNED); }, std::numeric_limits<size_t>::max()});
|
||||
InsertAllocator(CreateAllocator(pinned_allocator_info, device_id_));
|
||||
|
||||
// TODO: this is actually used for the cuda kernels which explicitly ask for inputs from CPU.
|
||||
// This will be refactored/removed when allocator and execution provider are decoupled.
|
||||
DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUInput, [](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>("CUDA_CPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUInput)); }, std::numeric_limits<size_t>::max()});
|
||||
InsertAllocator(CreateAllocator(cpu_allocator_info));
|
||||
}
|
||||
|
||||
CUDAExecutionProvider::~CUDAExecutionProvider() {
|
||||
|
|
|
|||
|
|
@ -448,7 +448,6 @@ void OpTester::Run(ExpectResult expect_result,
|
|||
std::unordered_map<std::string, OrtValue> feeds;
|
||||
std::vector<std::string> output_names;
|
||||
FillFeedsAndOutputNames(feeds, output_names);
|
||||
|
||||
// Run the model
|
||||
SessionOptions so;
|
||||
so.session_logid = op_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue