checking execution provider logic updated. (#1547)

This commit is contained in:
Ke Zhang 2019-08-02 13:29:39 -07:00 committed by GitHub
parent 93cb29f958
commit cb71c69d5e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -98,7 +98,6 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
// info on the logic to create the node_info_vec.
// for (auto& node_info : node_info_vec) {
auto& node_info = node_info_vec.front();
if (node_info.p_node == nullptr) {
// dummy entry for an input that we didn't find a use of in the graph.
// use the input as is given we don't believe it's actually needed.
@ -409,25 +408,18 @@ static common::Status CachedCopyOutputsAcrossDevices(
return Status::OK();
}
// check if all the execution providers use the same allocator. if so, no copies between devices should be required,
// and the overall status for DeviceCopyChecks can be set to NoCopy
static DeviceCopyCheck CheckExecutionProviders(const ExecutionProviders& execution_providers) {
bool all_cpu = true;
for (const auto& execution_provider : execution_providers) {
const auto& allocators = execution_provider->GetAllocators();
// this won't work as desired until multiple providers can share the CPU Allocator and the logic here is updated
// to detect that..
// it will currently handle the scenario when only the CPUExecutionProvider is registered though
if (!std::all_of(allocators.cbegin(), allocators.cend(),
[](const gsl::not_null<const IAllocator*>& allocator) {
return strcmp(allocator->Info().name, CPU) == 0;
})) {
all_cpu = false;
break;
if (execution_provider->Type() != onnxruntime::kCpuExecutionProvider &&
execution_provider->Type() != onnxruntime::kMklDnnExecutionProvider &&
execution_provider->Type() != onnxruntime::kNGraphExecutionProvider &&
execution_provider->Type() != onnxruntime::kNupharExecutionProvider &&
execution_provider->Type() != onnxruntime::kOpenVINOExecutionProvider) {
return DeviceCopyCheck::Unknown;
}
}
return all_cpu ? DeviceCopyCheck::NoCopy : DeviceCopyCheck::Unknown;
return DeviceCopyCheck::NoCopy;
}
// execute graph with cached info from FeedsFetchesManager.