onnxruntime/onnxruntime/core/framework/execution_provider.cc
Changming Sun bd78364411
Parallelize all the activation ops (#3722)
1. Parallelize all the activation ops.
2. Parallelize the performance-critical path of the LRN op, which makes the ONNX model zoo GoogLeNet model run 60% faster (latency reduced from 21ms to 13ms).
3. Make the Gemm-Activation fusion support all the activation ops. Before this change, it only supported LeakyRelu/Relu/Sigmoid/Tanh.
4. Delete onnxruntime/test/framework/op_kernel_test.cc because the file is almost empty.
5. Remove the logging in KernelRegistry::TryFindKernel and return a Status with an error message instead.
2020-05-05 01:18:17 -07:00

79 lines
2.8 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/execution_provider.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/compute_capability.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry.h"
namespace onnxruntime {
namespace {
// Packs a (device id, memory type) pair into one int used as the key of allocators_.
// The low two bits carry mem_type offset by +2, and the remaining high bits carry id.
// This only yields distinct keys if every OrtMemType value lies in [-2, 1], so that
// mem_type + 2 fits in two bits. That range is the assumption the original author recorded.
// NOTE(review): a negative id would be left-shifted here, which is undefined before C++20;
// callers presumably always pass id >= 0 (TODO confirm).
inline int MakeKey(int id, OrtMemType mem_type) {
  const int mem_bits = mem_type + 2;  // maps [-2, 1] onto [0, 3]
  const int id_bits = id << 2;        // frees the two low bits for mem_bits
  return id_bits | mem_bits;
}
}  // namespace
// Looks up the allocator previously registered via InsertAllocator for this
// (id, mem_type) pair. Returns nullptr when no such allocator was inserted.
AllocatorPtr IExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const {
  const auto found = allocators_.find(MakeKey(id, mem_type));
  if (found == allocators_.end()) {
    return nullptr;
  }
  return found->second;
}
// Default capability query. For every node in `graph` that at least one of
// `kernel_registries` has a kernel for (for this provider's Type()), emit a
// ComputeCapability wrapping a single-node IndexedSubGraph. Nodes with no
// matching kernel are skipped. Results follow graph.Nodes() iteration order.
std::vector<std::unique_ptr<ComputeCapability>>
IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
                                  const std::vector<const KernelRegistry*>& kernel_registries) const {
  // Registries are queried in order, and the search stops at the first one that
  // reports an implementation.
  auto has_kernel_for = [&](const auto& node) {
    for (const KernelRegistry* reg : kernel_registries) {
      if (KernelRegistry::HasImplementationOf(*reg, node, Type())) {
        return true;
      }
    }
    return false;
  };

  std::vector<std::unique_ptr<ComputeCapability>> capabilities;
  for (const auto& node : graph.Nodes()) {
    if (!has_kernel_for(node)) {
      continue;
    }
    auto single_node = onnxruntime::make_unique<IndexedSubGraph>();
    single_node->nodes.push_back(node.Index());
    capabilities.push_back(onnxruntime::make_unique<ComputeCapability>(std::move(single_node)));
  }
  return capabilities;
}
// Base-class default hooks. Each one does nothing and reports success.
// (Removed the stray ';' that followed Sync()'s body. It formed an empty
// declaration at namespace scope and triggers -Wextra-semi / -pedantic warnings.)
common::Status IExecutionProvider::Sync() const { return Status::OK(); }
common::Status IExecutionProvider::OnRunStart() { return Status::OK(); }
common::Status IExecutionProvider::OnRunEnd() { return Status::OK(); }
common::Status IExecutionProvider::OnSessionInitializationEnd() { return Status::OK(); }
// Registers `allocator` under the key derived from its OrtMemoryInfo (id, mem_type).
// On success, the allocator is also appended to allocator_list_ as a non-null raw pointer.
// Throws if `allocator` is null, or ("duplicated allocator") if an allocator with the
// same key is already registered. In both failure cases nothing is added.
void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) {
  // Previously a null allocator was dereferenced below; fail with a clear error instead.
  ORT_ENFORCE(allocator != nullptr, "InsertAllocator: allocator must not be null");
  const OrtMemoryInfo& info = allocator->Info();
  const int key = MakeKey(info.id, info.mem_type);
  // One emplace both checks for and performs the registration; .second is false
  // when the key already exists, in which case the map is left unchanged.
  const bool inserted = allocators_.emplace(key, allocator).second;
  if (!inserted) {
    ORT_THROW("duplicated allocator");
  }
  allocator_list_.emplace_back(gsl::not_null<IAllocator*>(allocator.get()));
}
// Base implementation of the NodeComputeInfo-producing Compile overload. It ignores
// its arguments and always returns an ONNXRUNTIME / NOT_IMPLEMENTED status.
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_nodes*/,
                                           std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
  return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
}
// Base implementation of the dll_path-producing Compile overload. It ignores its
// arguments, leaves dll_path untouched, and always returns an ONNXRUNTIME /
// NOT_IMPLEMENTED status.
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_nodes*/,
                                           std::string& /*dll_path*/) {
  return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
}
// Base implementation: no kernel registry is provided, so an empty shared_ptr is returned.
std::shared_ptr<KernelRegistry> IExecutionProvider::GetKernelRegistry() const {
  return {};
}
} // namespace onnxruntime