diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index dc5499d0a7..ceb2b75795 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -44,6 +44,25 @@ class KernelRegistry { const TypeConstraintMap& type_constraints, const KernelCreateInfo** out) const; + /** + * @brief Find out whether a kernel is registered, without a node. + * This should be useful in graph optimizers, to check whether + * the node it is about to generate, is supported or not. + * @param exec_provider + * @param op_type + * @param domain + * @param version + * @param type_constraints + * @param out + * @return + */ + Status TryFindKernel(ProviderType exec_provider, + std::string_view op_type, + std::string_view domain, + int version, + const KernelRegistry::TypeConstraintMap& type_constraints, + const KernelCreateInfo** out) const; + static bool HasImplementationOf(const KernelRegistry& r, const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver) { diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index a4ab218ceb..d695e0e04c 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -116,35 +116,41 @@ bool MatchKernelDefTypes(const std::unordered_map= node_version); + kernel_start_version <= since_ver && kernel_end_version >= since_ver); if (!valid_version) { std::ostringstream ostr; - ostr << "Op with name (" << node.Name() << ")" - << " and type (" << node.OpType() << ")" - << " Version mismatch." - << " node_version: " << node_version + ostr << " Version mismatch." + << " node_version: " << since_ver << " kernel start version: " << kernel_start_version << " kernel_end_version: " << kernel_end_version; error_str = ostr.str(); + } + return valid_version; +} + +bool KernelRegistry::VerifyKernelDef(const Node& node, + const KernelDef& kernel_def, + const IKernelTypeStrResolver* kernel_type_str_resolver, + const TypeConstraintMap* type_constraint_values, + std::string& error_str) { + // check if version matches + bool valid_version = VerifyVersion(node.SinceVersion(), kernel_def, error_str); + + if (!valid_version) { return false; } @@ -157,12 +163,9 @@ bool KernelRegistry::VerifyKernelDef(const Node& node, if (!matched) { std::ostringstream ostr; - ostr << "Found kernel for Op with name (" << node.Name() << ")" - << " and type (" << node.OpType() << ")" + ostr << "Kernel found kernel" << " in the supported version range" - << " (node_version: " << node_version - << " kernel start version: " << kernel_start_version - << " kernel_end_version: " << kernel_end_version << ")." + << " (node_version: " << node.SinceVersion() << ")." << " However the types are incompatible. " << mismatch_reason; error_str = ostr.str(); } @@ -203,6 +206,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, if (!verify_kernel_def_error_strs.empty()) { std::ostringstream oss; oss << "Op with name (" << node.Name() << ")" + << " domain (" << node.Domain() << ")" << " and type (" << node.OpType() << ")" << " kernel is not supported in " << expected_provider << "." << " Encountered following errors: ("; @@ -229,6 +233,68 @@ Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provide return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out); } +static bool KernelDefCompatible(int version, const KernelDef& kernel_def, + const KernelRegistry::TypeConstraintMap& type_constraint_values, + std::string& error_str) { + if (!VerifyVersion(version, kernel_def, error_str)) { + return false; + } + + const auto& kernel_type_constraints = kernel_def.TypeConstraints(); + bool matched = MatchKernelDefTypes(kernel_type_constraints, type_constraint_values); + + if (!matched) { + std::ostringstream ostr; + ostr << "Kernel found kernel" + << " in the supported version range" + << " (node_version: " << version << ")." + << " However the types are incompatible."; + error_str = ostr.str(); + } + + return matched; +} + +Status KernelRegistry::TryFindKernel(ProviderType exec_provider, + std::string_view op_type, + std::string_view domain, + int version, + const KernelRegistry::TypeConstraintMap& type_constraints, + const KernelCreateInfo** out) const { + auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider)); + if (out) *out = nullptr; + + std::vector verify_kernel_def_error_strs; + + for (auto i = range.first; i != range.second; ++i) { + std::string error_str; + if (KernelDefCompatible(version, *i->second.kernel_def, type_constraints, error_str)) { + if (out) { + *out = &i->second; + } + return Status::OK(); + } + + verify_kernel_def_error_strs.push_back(error_str); + } + + if (!verify_kernel_def_error_strs.empty()) { + std::ostringstream oss; + oss << "Op type (" << op_type << ")" + << " domain (" << domain << ")" + << " kernel is not supported in " << exec_provider << "." + << " Encountered following errors: ("; + std::copy(verify_kernel_def_error_strs.begin(), verify_kernel_def_error_strs.end(), + std::ostream_iterator(oss, "\n")); + oss << ")"; + + VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); + } + + return Status(common::ONNXRUNTIME, common::FAIL, "Kernel not found"); +} + Status KernelRegistry::Register(KernelDefBuilder& kernel_builder, const KernelCreateFn& kernel_creator) { return Register(KernelCreateInfo(kernel_builder.Build(), kernel_creator)); diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 1a7fabdbe7..7c087ec77d 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -84,9 +84,7 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { // going to a node that will need a Cast. // // Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32. -static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph) { - bool isolated_fp16_node = false; - +static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { // we can check if it's an isolated fp16 node // if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't), // does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node), @@ -96,70 +94,135 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: !node.ContainsSubgraph() && !graph.NodeProducesGraphOutput(node) && node.GetExecutionProviderType() == kCpuExecutionProvider) { - do { - // find the number of fp16 inputs as we need to make sure they're all coming from nodes that will be cast - const auto& input_defs = node.InputDefs(); - size_t num_fp16_inputs = std::count_if(input_defs.cbegin(), input_defs.cend(), - [](const NodeArg* input_def) { - return IsMLFloat16Tensor(*input_def); - }); + // + // Three tasks here: + // 1. make sure all tensor(float16) inputs and first output coming from or + // going to nodes that will be cast to fp32 + // 2. check the current node is float16 node. + // 3. check the current node has a float32 implementation + // Only return true when all three are satisfied + // + const auto* schema = node.Op(); + if (!schema) { + // no way to know whether it is safe to convert this to fp32, give up + return false; + } - if (num_fp16_inputs == 0) { - break; + const TypeConstraintMap& type_schema = schema->typeConstraintMap(); + InlinedHashMap type_constraint_map; + type_constraint_map.reserve(type_schema.size()); + + // For each formal parameters, there might be 0-n + // actual inputs, this makes it very tricky to find out which + // actual input should map to which formal parameter + + const auto& input_arg_counts = node.InputArgCount(); + const auto& input_defs = node.InputDefs(); + const auto& formal_inputs = schema->inputs(); + const size_t num_inputs = std::min(formal_inputs.size(), input_arg_counts.size()); + + InlinedHashSet fp16_args; + int input_idx_start = 0; + for (size_t formal_idx = 0; + formal_idx < num_inputs; + input_idx_start += input_arg_counts[formal_idx], formal_idx++) { + const auto& type_str = formal_inputs[formal_idx].GetTypeStr(); + TypeConstraintMap::const_iterator it = type_schema.find(type_str); + if (it == type_schema.end()) { + // Don't care about parameter that does not have a type constraint. + continue; } - size_t num_fp16_input_edges = 0; + // type_str is like T, T1 or T2 ... + for (int input_idx = 0; input_idx < input_arg_counts[formal_idx]; input_idx++) { + const size_t idx = static_cast(input_idx_start) + static_cast(input_idx); + ORT_ENFORCE(idx < input_defs.size()); + const NodeArg* input_def = input_defs[idx]; + if (!input_def || !input_def->Exists()) { + continue; + } + if (IsMLFloat16Tensor(*input_def)) { + fp16_args.emplace(static_cast(idx)); + type_constraint_map[type_str] = DataTypeImpl::GetTensorType(); + break; // we don't have multiple tensors feeding into one input + } + type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(input_def->TypeAsProto())); + break; // we don't have multiple tensors feeding into one input + } + } - // check if all nodes providing our fp16 input need to be cast to fp32 - for (auto input_edge = node.InputEdgesBegin(), end = node.InputEdgesEnd(); input_edge != end; ++input_edge) { - const NodeArg& input_def = *input_defs[input_edge->GetDstArgIndex()]; + if (fp16_args.empty()) { + return false; + } - if (IsMLFloat16Tensor(input_def)) { - // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 - if (!NodeNeedsInputCastToFp32(input_edge->GetNode())) { - break; - } - - ++num_fp16_input_edges; + // check if all nodes providing our fp16 input need to be cast to fp32 + for (auto input_edge = node.InputEdgesBegin(), end = node.InputEdgesEnd(); input_edge != end; ++input_edge) { + const int arg_idx = input_edge->GetDstArgIndex(); + if (fp16_args.find(arg_idx) != fp16_args.end()) { + // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 + if (!NodeNeedsInputCastToFp32(input_edge->GetNode())) { + return false; } } + } - // one or more fp16 inputs are coming from a graph input or initializer - if (num_fp16_inputs != num_fp16_input_edges) { - break; + // if we got here all nodes providing our fp16 input/s will be cast to fp32. + // check if the same applies to the nodes consuming our fp16 output. + fp16_args.clear(); + const auto& output_defs = node.OutputDefs(); + const auto& formal_outputs = schema->outputs(); + const size_t num_outputs = std::min(formal_outputs.size(), output_defs.size()); + for (size_t idx = 0; idx < num_outputs; idx++) { + const auto& type_str = formal_outputs[idx].GetTypeStr(); + TypeConstraintMap::const_iterator it = type_schema.find(type_str); + if (it == type_schema.end()) { + // Don't care about parameter that does not have a type constraint. + continue; } - // if we got here all nodes providing our fp16 input/s will be cast to fp32. - // check if the same applies to all nodes consuming our fp16 output. + const NodeArg* output_def = output_defs[idx]; + if (!output_def || !output_def->Exists()) { + continue; + } + if (IsMLFloat16Tensor(*output_def)) { + fp16_args.emplace((int)idx); + type_constraint_map[type_str] = DataTypeImpl::GetTensorType(); + } else { + type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(output_def->TypeAsProto())); + } + } - bool node_has_fp16_output = false; + if (fp16_args.empty()) { + return false; // no fp16 output + } - for (auto output_edge = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge != end; ++output_edge) { - const NodeArg& output_def = *node.OutputDefs()[output_edge->GetSrcArgIndex()]; - if (IsMLFloat16Tensor(output_def)) { - node_has_fp16_output = true; - - // if the node consuming our fp16 output does not need a cast, we should run in fp16 - if (!NodeNeedsInputCastToFp32(output_edge->GetNode())) { - break; - } + for (auto output_edge = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge != end; ++output_edge) { + const int arg_idx = output_edge->GetSrcArgIndex(); + if (fp16_args.find(arg_idx) != fp16_args.end()) { + // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 + if (!NodeNeedsInputCastToFp32(output_edge->GetNode())) { + return false; } } + } - if (node_has_fp16_output) { - // all nodes providing our fp16 input/s will be cast to fp32, and - // we produce one or more fp16 outputs, and all nodes consuming those outputs will be cast to fp32 - isolated_fp16_node = true; - } - } while (false); + // now all fp16 inputs and outputs would have a cast + // make sure fp32 version of the kernel is available. + const KernelCreateInfo* kernel_create_info{}; + const auto lookup_status = cpu_kernel_registry.TryFindKernel( + kCpuExecutionProvider, node.OpType(), node.Domain(), + node.SinceVersion(), type_constraint_map, &kernel_create_info); + if (lookup_status.IsOK() && kernel_create_info != nullptr) { + return true; + } } - return isolated_fp16_node; + return false; } -Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph) { +static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph)) { + if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 node.SetExecutionProviderType(""); } @@ -338,7 +401,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { if (force_cpu_fp32_) - ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph)); + ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_)); GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.h b/onnxruntime/core/optimizer/insert_cast_transformer.h index 86d3a3a960..8be08d5158 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.h +++ b/onnxruntime/core/optimizer/insert_cast_transformer.h @@ -6,6 +6,8 @@ #include "core/graph/graph_viewer.h" #include "core/framework/op_kernel.h" #include "core/optimizer/graph_transformer.h" +#include "core/framework/kernel_registry_manager.h" +#include "core/framework/kernel_registry.h" namespace onnxruntime { @@ -16,19 +18,26 @@ Transformer to insert cast node that casts float16 to float for cpu nodes */ class InsertCastTransformer : public onnxruntime::GraphTransformer { public: - InsertCastTransformer(const std::string& name) + /** + * @brief Initializer + * @param name for logging purpose + * @param cpu_kernel_registry used to query whether an op node can be safely created + */ + InsertCastTransformer(const std::string& name, const KernelRegistry* cpu_kernel_registry) : onnxruntime::GraphTransformer(name), - force_cpu_fp32_(true) { - } + cpu_kernel_registries_(cpu_kernel_registry), + force_cpu_fp32_(cpu_kernel_registry != nullptr) {} private: Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const; + const KernelRegistry* cpu_kernel_registries_; + // Currently because we only have very few cpu kernels support float16, place those nodes on float16 // will introduce many cast between fp32 and fp16, which will slow the execution. // A better solution is to have a cost model to evaluate does it works to place the node on float16. // Here for simplify, we only force the single-node-float16 sub-graph to float32 - bool force_cpu_fp32_; + const bool force_cpu_fp32_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5ae8ab8cf7..b5709698f7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -976,7 +976,13 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // Insert cast node/s. { - InsertCastTransformer insert_cast_transformer{"CastFloat16Transformer"}; + const InlinedVector> kernel_regs = + kernel_registry_manager_.GetKernelRegistriesByProviderType(kCpuExecutionProvider); + const KernelRegistry* cpu_regs = nullptr; + if (!kernel_regs.empty()) { + cpu_regs = kernel_regs[0]; + } + InsertCastTransformer insert_cast_transformer{"CastFloat16Transformer", cpu_regs}; ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(insert_cast_transformer, *session_logger_, graph)); } diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 3f8d37d3ed..0196e3385f 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -8,6 +8,7 @@ #include "gtest/gtest.h" #include "test_utils.h" #include "test/test_environment.h" +#include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/asserts.h" @@ -38,7 +39,7 @@ TEST(TransformerTest, InsertCastGPUTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = true; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -86,7 +87,7 @@ TEST(TransformerTest, InsertCastAllCPUTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = true; EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK()); @@ -123,7 +124,7 @@ TEST(TransformerTest, ThreeInARowRemoval) { // we want to remove 2 of the first 3 ASSERT_TRUE(op_to_count["Cast"] == 4); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -146,7 +147,7 @@ TEST(TransformerTest, RandomNormalLikeWithFloat16Inputs) { ASSERT_TRUE(status.IsOK()) << status; Graph& graph = model->MainGraph(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -166,7 +167,7 @@ TEST(TransformerTest, MultinomialWithFloat16Input) { ASSERT_TRUE(status.IsOK()) << status; Graph& graph = model->MainGraph(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -186,7 +187,7 @@ TEST(TransformerTest, InsertCastNodeTwice) { ASSERT_TRUE(status.IsOK()) << status; Graph& graph = model->MainGraph(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); // First insert bool modified = false; @@ -279,7 +280,7 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = true; EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK()); diff --git a/onnxruntime/test/providers/compare_provider_test_utils.cc b/onnxruntime/test/providers/compare_provider_test_utils.cc index 540b9c0592..ebf06009c4 100644 --- a/onnxruntime/test/providers/compare_provider_test_utils.cc +++ b/onnxruntime/test/providers/compare_provider_test_utils.cc @@ -63,7 +63,7 @@ void CompareOpTester::CompareWithCPU(const std::string& target_provider_type, // the function body is instead used for CPU pass. This option allows the comparison with // the CPU kernel by adding the input/output casts before looking for a registered CPU kernel. if (need_cpu_cast) { - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", GetExecutionProvider(kCpuExecutionProvider)->GetKernelRegistry().get()); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(status.IsOK());