onnxruntime/onnxruntime/core/framework/kernel_registry.cc
Ke Zhang 3bf0e364e2
Move CopyTensor out of IExecutionProvider interface. (#1268)
* add ortdevice class

* add data transfer manager for copying tensors.

* update

* add data transfer for GPU

* fix constexpr build break.

* update

* remove unnecessary header files.

* remove unnecessary header files.

* add dependency

* add dependency

* add dependency

* add dependency

* fix linux build break.

* update

* fix build break

* fix build break

* fix build break

* update

* update

* update c api.

* update to not use OrtCreateAllocatorInfo

* change to all EPs.

* fix linux build break

* remove unused code.

* update

* move datatransfermanager in session state

* update

* fix cuda build break.

* fix comments

* fix windows GPU build.

* fix comments

* fix build break

* fix comments

* fix test failure

* update

* fix comments

* fix onnx runtime server.

* update

* fix test failure.

* fix comments

* fix comment
2019-07-11 14:49:20 -07:00

302 lines
13 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <memory>
#include <unordered_map>
#include "core/framework/kernel_registry.h"
#include "core/framework/session_state.h"
using namespace ::onnxruntime::common;
namespace onnxruntime {
namespace {
// Traverses the node's formal parameters and calls TraverseFn with the formal
// parameter and its associated TypeProto.
// Inputs are visited before outputs. Actual arguments that do not exist (missing
// optional arguments) are skipped. Actual outputs beyond the last formal output
// are paired with the last formal output (presumably a variadic output).
// node - the node to traverse
// param_filter_fn - called to determine whether to consider a given formal parameter:
//   bool ParamFilterFn(const ONNX_NAMESPACE::OpSchema::FormalParameter& param)
//     param - the formal parameter
//     returns true if the formal parameter should be considered, false otherwise
// traverse_fn - called to process the formal parameter and its associated TypeProto:
//   bool TraverseFn(const ONNX_NAMESPACE::OpSchema::FormalParameter& param,
//                   const ONNX_NAMESPACE::TypeProto* type)
//     param - the formal parameter
//     type - the associated TypeProto
//     returns true if traversal should continue, false otherwise
template <typename ParamFilterFn, typename TraverseFn>
void TraverseFormalParametersWithTypeProto(const Node& node,
                                           ParamFilterFn param_filter_fn,
                                           TraverseFn traverse_fn) {
  const ONNX_NAMESPACE::OpSchema& op_schema = *node.Op();

  // process inputs:
  // InputArgCount()[k] is the number of actual arguments bound to formal input k.
  const size_t len = node.InputArgCount().size();
  ORT_ENFORCE(len <= op_schema.inputs().size());
  int actual_index = 0;
  for (size_t formal_index = 0; formal_index != len; ++formal_index) {
    const auto& param = op_schema.inputs()[formal_index];
    if (param_filter_fn(param)) {
      // get type of any corresponding actual parameter, if present
      for (int i = 0, end = node.InputArgCount()[formal_index]; i < end; ++i) {
        const NodeArg* arg = node.InputDefs()[actual_index + i];
        if (!arg->Exists()) continue;  // a missing optional argument
        if (!traverse_fn(param, arg->TypeAsProto())) return;
      }
    }
    actual_index += node.InputArgCount()[formal_index];
  }

  // process outputs:
  const auto& formal_outputs = op_schema.outputs();
  if (formal_outputs.empty()) {
    // Nothing to bind actual outputs to. This guard also prevents computing
    // size() - 1 on an empty vector, which would wrap around to SIZE_MAX and
    // make the index below go out of bounds.
    return;
  }
  auto actual_outputs = node.OutputDefs();
  const size_t num_actual_outputs = actual_outputs.size();
  const size_t last_formal = formal_outputs.size() - 1;
  for (size_t i = 0; i != num_actual_outputs; ++i) {
    const auto& formal = formal_outputs[std::min(i, last_formal)];
    if (!param_filter_fn(formal)) continue;
    const NodeArg* arg = actual_outputs[i];
    if (!arg->Exists()) continue;
    if (!traverse_fn(formal, arg->TypeAsProto())) return;
  }
}
class TypeBindingResolver {
public:
TypeBindingResolver(const Node& node, bool use_lookup_map)
: node_{node},
type_binding_map_{} {
if (use_lookup_map) {
type_binding_map_ = std::make_unique<TypeBindingMap>();
TraverseFormalParametersWithTypeProto(
node_,
[](const ONNX_NAMESPACE::OpSchema::FormalParameter&) -> bool { return true; },
[this](const ONNX_NAMESPACE::OpSchema::FormalParameter& param,
const ONNX_NAMESPACE::TypeProto* type) -> bool {
type_binding_map_->emplace(param.GetName(), type);
type_binding_map_->emplace(param.GetTypeStr(), type);
return true;
});
}
}
// Resolves a name to a TypeProto* for a given node.
// The name can represent either a type parameter or an input/output parameter.
// Returns the resolved TypeProto* or nullptr if unable to resolve.
const ONNX_NAMESPACE::TypeProto* Resolve(const std::string& name_or_type_str) const {
// lookup if available
if (type_binding_map_) {
auto found_it = type_binding_map_->find(name_or_type_str);
if (found_it == type_binding_map_->end()) return nullptr;
return found_it->second;
}
// fall back to node parameter traversal
const ONNX_NAMESPACE::TypeProto* result{};
TraverseFormalParametersWithTypeProto(
node_,
[&name_or_type_str](const ONNX_NAMESPACE::OpSchema::FormalParameter& param) -> bool {
return param.GetName() == name_or_type_str || param.GetTypeStr() == name_or_type_str;
},
[&result](const ONNX_NAMESPACE::OpSchema::FormalParameter&,
const ONNX_NAMESPACE::TypeProto* type) -> bool {
result = type;
return false;
});
return result;
}
private:
// map from input/output name or type string to TypeProto pointer
using TypeBindingMap = std::unordered_map<std::string, const ONNX_NAMESPACE::TypeProto*>;
const Node& node_;
std::unique_ptr<TypeBindingMap> type_binding_map_;
};
}; // namespace
// Checks whether kernel_def can be used to run node.
// The checks run in order: domain, execution provider, op version, then each of
// the kernel's type constraints. The first failing check sets error_str and the
// function returns false.
// node          - the node to match
// kernel_def    - the candidate kernel definition
// error_str     - receives a human-readable reason on failure; not written on success
// exec_provider - provider to match when the node has not been assigned one yet;
//                 ignored when node.GetExecutionProviderType() is non-empty
// Returns true if kernel_def matches node, false otherwise.
bool KernelRegistry::VerifyKernelDef(const onnxruntime::Node& node,
                                     const KernelDef& kernel_def,
                                     std::string& error_str,
                                     onnxruntime::ProviderType exec_provider) {
  // check if domain matches
  if (node.Domain() != kernel_def.Domain()) {
    std::ostringstream ostr;
    ostr << "Op: " << node.OpType()
         << " Domain mismatch."
         << " Expected: " << kernel_def.Domain()
         << " Actual: " << node.Domain();
    error_str = ostr.str();
    return false;
  }

  // check if execution provider matches
  const auto& node_provider = node.GetExecutionProviderType();
  const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider);
  if (expected_provider != kernel_def.Provider()) {
    std::ostringstream ostr;
    ostr << "Op: " << node.OpType()
         << " Execution provider mismatch."
         << " Expected: " << expected_provider
         << " Actual: " << kernel_def.Provider();
    error_str = ostr.str();
    return false;
  }

  // check if version matches
  int kernel_start_version = 0;
  int kernel_end_version = 0;
  kernel_def.SinceVersion(&kernel_start_version, &kernel_end_version);
  const int node_since_version = node.Op()->since_version();
  // Ideally, if the schema is Since(5) and the current opset version is 7:
  //   kernel_def Since(8)     Invalid
  //   kernel_def Since(6)     Valid
  //   kernel_def Since(5)     Valid
  //   kernel_def Since(4)     Invalid
  //   kernel_def Since(4, 6)  Valid
  // Right now there is no "until version" on the schema, so the opset version is hard
  // to obtain here (it would require a lot of interface changes).
  // As a trade-off, we temporarily require the kernel definition to have the same since
  // version as the schema definition, so kernel_def Since(6) is currently rejected.
  // A kernel whose range starts earlier is still accepted when it has an explicit end
  // version (not INT_MAX) that reaches the schema's since version.
  // Once ONNX adds "until version" to the schema object, the first condition should become
  // kernel_start_version >= node_since_version && kernel_start_version <= until_version.
  const bool valid_version =
      kernel_start_version == node_since_version ||
      (kernel_start_version < node_since_version && kernel_end_version != INT_MAX &&
       kernel_end_version >= node_since_version);
  if (!valid_version) {
    std::ostringstream ostr;
    ostr << "Op: " << node.OpType()
         << " Version mismatch."
         << " node_version: " << node_since_version
         << " kernel start version: " << kernel_start_version
         << " kernel_end_version: " << kernel_end_version;
    error_str = ostr.str();
    return false;
  }

  // check if type matches
  const auto& kernel_type_constraints = kernel_def.TypeConstraints();
  // Note: The number of formal input/output parameters is N and the number of
  // type constraints is M. We select between an O(N*M) and an O(N+M) approach.
  // The O(N*M) approach has lower initial overhead.
  // kTypeBindingResolverComplexityThreshold is the value of N*M above which we
  // will use the O(N+M) approach.
  constexpr size_t kTypeBindingResolverComplexityThreshold = 50 * 50;
  const size_t num_formal_params = node.Op()->inputs().size() + node.Op()->outputs().size();
  const bool use_lookup_map =
      kernel_type_constraints.size() * num_formal_params > kTypeBindingResolverComplexityThreshold;
  TypeBindingResolver type_binding_resolver{node, use_lookup_map};

  for (const auto& constraint : kernel_type_constraints) {
    const std::string& name = constraint.first;
    const std::vector<MLDataType>& allowed_types = constraint.second;
    const ONNX_NAMESPACE::TypeProto* actual_type = type_binding_resolver.Resolve(name);

    // If actual_type is null, this represents a type-constraint on a
    // missing optional parameter, which can be skipped.
    // TODO: We should check that names specified in kernel_type_constraints are
    // valid names (of types or parameters) at the time that kernels are registered.
    if (actual_type == nullptr) continue;

    const bool is_compatible =
        std::any_of(allowed_types.begin(), allowed_types.end(),
                    [actual_type](const DataTypeImpl* expected_type) {
                      return expected_type->IsCompatible(*actual_type);
                    });
    if (!is_compatible) {
      // Only report once every allowed type has been rejected, so a successful
      // match never leaves a stale message in error_str.
      // TODO print the actual type information as well
      error_str = "Op: " + node.OpType() + " Incompatible types. Type constraint: " + name;
      return false;
    }
  }

  return true;
}
// Convenience overload: builds the KernelDef from kernel_builder, pairs it with
// kernel_creator, and registers the result via the KernelCreateInfo overload
// (which performs the conflict check).
Status KernelRegistry::Register(KernelDefBuilder& kernel_builder,
                                const KernelCreateFn& kernel_creator) {
  KernelCreateInfo create_info(kernel_builder.Build(), kernel_creator);
  return Register(std::move(create_info));
}
// Adds create_info to the registry, keyed by the kernel's op name.
// If a valid kernel already registered under the same op name conflicts with the
// new one (per KernelDef::IsConflict), the new entry is still stored but carries
// a failure status, and that status is returned. Code reading the map must
// therefore check each entry's status.
Status KernelRegistry::Register(KernelCreateInfo&& create_info) {
  const auto& op_name = create_info.kernel_def->OpName();

  // Check op version conflicts against the valid entries for the same op.
  const auto same_op_entries = kernel_creator_fn_map_.equal_range(op_name);
  const bool conflicts = std::any_of(
      same_op_entries.first, same_op_entries.second,
      [&create_info](const auto& entry) {
        const auto& registered = entry.second;
        return registered.kernel_def &&
               registered.status.IsOK() &&
               registered.kernel_def->IsConflict(*create_info.kernel_def);
      });

  Status result = Status::OK();
  if (conflicts) {
    result = Status(ONNXRUNTIME, FAIL,
                    "Failed to add kernel for " + op_name +
                        ": Conflicting with a registered kernel with op versions.");
    create_info.status = result;
  }

  // Ownership of the KernelDef moves into the map, even for a conflicting entry.
  kernel_creator_fn_map_.emplace(op_name, std::move(create_info));
  return result;
}
// Finds a kernel registered for node on execution_provider and instantiates it into op_kernel.
// Returns a FAIL status, leaving op_kernel untouched, when no matching kernel is registered.
Status KernelRegistry::TryCreateKernel(const onnxruntime::Node& node,
                                       const IExecutionProvider& execution_provider,
                                       const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
                                       const OrtValueNameIdxMap& ort_value_name_idx_map,
                                       const FuncManager& funcs_mgr,
                                       const DataTransferManager& data_transfer_mgr,
                                       /*out*/ std::unique_ptr<OpKernel>& op_kernel) const {
  const KernelCreateInfo* const create_info = TryFindKernel(node, execution_provider.Type());
  if (create_info == nullptr) {
    return Status(ONNXRUNTIME, FAIL, "Failed to find kernel for " + node.OpType());
  }

  // Everything the kernel's factory may inspect while constructing the kernel.
  OpKernelInfo info(node, *create_info->kernel_def, execution_provider, constant_initialized_tensors,
                    ort_value_name_idx_map, funcs_mgr, data_transfer_mgr);

  // The factory returns a raw pointer; op_kernel takes ownership of it here.
  op_kernel.reset(create_info->kernel_create_func(info));
  return Status::OK();
}
// Joins the messages into one string, appending a single space after each
// message (including the last one). An empty input yields an empty string.
static std::string ToString(const std::vector<std::string>& error_strs) {
  std::string joined;
  for (const std::string& message : error_strs) {
    joined += message;
    joined += ' ';
  }
  return joined;
}
// TODO: return a Status instead of logging error messages here,
// because this function often returns nullptr, which is totally expected.
// If this function is called before graph partitioning, node.provider is not set;
// in that case the kernel's provider must equal exec_provider.
// Otherwise kernel_def.provider must equal node.provider, and exec_provider is ignored.
// Returns the first registered, well-formed kernel whose definition matches node,
// or nullptr (after logging the rejection reasons at INFO level) if there is none.
const KernelCreateInfo* KernelRegistry::TryFindKernel(const onnxruntime::Node& node,
                                                      onnxruntime::ProviderType exec_provider) const {
  const auto candidates = kernel_creator_fn_map_.equal_range(node.OpType());

  // One reason per well-formed candidate that was rejected by VerifyKernelDef.
  std::vector<std::string> rejection_reasons;
  for (auto it = candidates.first; it != candidates.second; ++it) {
    const KernelCreateInfo& candidate = it->second;
    if (candidate.status.IsOK()) {
      std::string reason;
      if (VerifyKernelDef(node, *candidate.kernel_def, reason, exec_provider)) {
        return &candidate;
      }
      rejection_reasons.push_back(std::move(reason));
    } else {
      // Entries that failed registration are kept in the map but never selected.
      LOGS_DEFAULT(ERROR) << "Failed to create kernel for op: " << node.OpType()
                          << " since it was ill-formed during registration";
    }
  }

  const std::string& node_provider = node.GetExecutionProviderType();
  const std::string expected_provider = node_provider.empty() ? exec_provider : node_provider;
  LOGS_DEFAULT(INFO) << node.OpType() << " kernel is not supported in " << expected_provider
                     << " Encountered following errors: " << ToString(rejection_reasons);
  return nullptr;
}
} // namespace onnxruntime