From be08b47e7bc08285b48a119f1b756de67e593ce9 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Fri, 28 Apr 2023 09:32:54 -0700 Subject: [PATCH] Refine cast optimizer for safety (#15658) ### Description Cast optimizer may convert a fp16 node to fp32. This used to be safe as all fp16 kernels has fp32 implementation. As this assumption is no longer true, we need to check the validity of the operation ### Motivation and Context Main work here is to introduce an API to check whether a kernel is registered. Currently we don't have a way to do that without an operator node. This needs to be augmented. We need to query whether a kernel is registered by its property only, so that we can judge whether it is safe to construct a node long before we actually do so. --- .../core/framework/kernel_registry.h | 19 +++ onnxruntime/core/framework/kernel_registry.cc | 100 +++++++++-- .../core/optimizer/insert_cast_transformer.cc | 159 ++++++++++++------ .../core/optimizer/insert_cast_transformer.h | 17 +- onnxruntime/core/session/inference_session.cc | 8 +- .../framework/insert_cast_transformer_test.cc | 15 +- .../providers/compare_provider_test_utils.cc | 2 +- 7 files changed, 242 insertions(+), 78 deletions(-) diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index dc5499d0a7..ceb2b75795 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -44,6 +44,25 @@ class KernelRegistry { const TypeConstraintMap& type_constraints, const KernelCreateInfo** out) const; + /** + * @brief Find out whether a kernel is registered, without a node. + * This should be useful in graph optimizers, to check whether + * the node it is about to generate, is supported or not. + * @param exec_provider + * @param op_type + * @param domain + * @param version + * @param type_constraints + * @param out + * @return + */ + Status TryFindKernel(ProviderType exec_provider, + std::string_view op_type, + std::string_view domain, + int version, + const KernelRegistry::TypeConstraintMap& type_constraints, + const KernelCreateInfo** out) const; + static bool HasImplementationOf(const KernelRegistry& r, const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver) { diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index a4ab218ceb..d695e0e04c 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -116,35 +116,41 @@ bool MatchKernelDefTypes(const std::unordered_map= node_version); + kernel_start_version <= since_ver && kernel_end_version >= since_ver); if (!valid_version) { std::ostringstream ostr; - ostr << "Op with name (" << node.Name() << ")" - << " and type (" << node.OpType() << ")" - << " Version mismatch." - << " node_version: " << node_version + ostr << " Version mismatch." + << " node_version: " << since_ver << " kernel start version: " << kernel_start_version << " kernel_end_version: " << kernel_end_version; error_str = ostr.str(); + } + return valid_version; +} + +bool KernelRegistry::VerifyKernelDef(const Node& node, + const KernelDef& kernel_def, + const IKernelTypeStrResolver* kernel_type_str_resolver, + const TypeConstraintMap* type_constraint_values, + std::string& error_str) { + // check if version matches + bool valid_version = VerifyVersion(node.SinceVersion(), kernel_def, error_str); + + if (!valid_version) { return false; } @@ -157,12 +163,9 @@ bool KernelRegistry::VerifyKernelDef(const Node& node, if (!matched) { std::ostringstream ostr; - ostr << "Found kernel for Op with name (" << node.Name() << ")" - << " and type (" << node.OpType() << ")" + ostr << "Kernel found kernel" << " in the supported version range" - << " (node_version: " << node_version - << " kernel start version: " << kernel_start_version - << " kernel_end_version: " << kernel_end_version << ")." + << " (node_version: " << node.SinceVersion() << ")." << " However the types are incompatible. " << mismatch_reason; error_str = ostr.str(); } @@ -203,6 +206,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, if (!verify_kernel_def_error_strs.empty()) { std::ostringstream oss; oss << "Op with name (" << node.Name() << ")" + << " domain (" << node.Domain() << ")" << " and type (" << node.OpType() << ")" << " kernel is not supported in " << expected_provider << "." << " Encountered following errors: ("; @@ -229,6 +233,68 @@ Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provide return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out); } +static bool KernelDefCompatible(int version, const KernelDef& kernel_def, + const KernelRegistry::TypeConstraintMap& type_constraint_values, + std::string& error_str) { + if (!VerifyVersion(version, kernel_def, error_str)) { + return false; + } + + const auto& kernel_type_constraints = kernel_def.TypeConstraints(); + bool matched = MatchKernelDefTypes(kernel_type_constraints, type_constraint_values); + + if (!matched) { + std::ostringstream ostr; + ostr << "Kernel found kernel" + << " in the supported version range" + << " (node_version: " << version << ")." + << " However the types are incompatible."; + error_str = ostr.str(); + } + + return matched; +} + +Status KernelRegistry::TryFindKernel(ProviderType exec_provider, + std::string_view op_type, + std::string_view domain, + int version, + const KernelRegistry::TypeConstraintMap& type_constraints, + const KernelCreateInfo** out) const { + auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider)); + if (out) *out = nullptr; + + std::vector verify_kernel_def_error_strs; + + for (auto i = range.first; i != range.second; ++i) { + std::string error_str; + if (KernelDefCompatible(version, *i->second.kernel_def, type_constraints, error_str)) { + if (out) { + *out = &i->second; + } + return Status::OK(); + } + + verify_kernel_def_error_strs.push_back(error_str); + } + + if (!verify_kernel_def_error_strs.empty()) { + std::ostringstream oss; + oss << "Op type (" << op_type << ")" + << " domain (" << domain << ")" + << " kernel is not supported in " << exec_provider << "." + << " Encountered following errors: ("; + std::copy(verify_kernel_def_error_strs.begin(), verify_kernel_def_error_strs.end(), + std::ostream_iterator(oss, "\n")); + oss << ")"; + + VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); + } + + return Status(common::ONNXRUNTIME, common::FAIL, "Kernel not found"); +} + Status KernelRegistry::Register(KernelDefBuilder& kernel_builder, const KernelCreateFn& kernel_creator) { return Register(KernelCreateInfo(kernel_builder.Build(), kernel_creator)); diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 1a7fabdbe7..7c087ec77d 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -84,9 +84,7 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { // going to a node that will need a Cast. // // Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32. -static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph) { - bool isolated_fp16_node = false; - +static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { // we can check if it's an isolated fp16 node // if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't), // does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node), @@ -96,70 +94,135 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: !node.ContainsSubgraph() && !graph.NodeProducesGraphOutput(node) && node.GetExecutionProviderType() == kCpuExecutionProvider) { - do { - // find the number of fp16 inputs as we need to make sure they're all coming from nodes that will be cast - const auto& input_defs = node.InputDefs(); - size_t num_fp16_inputs = std::count_if(input_defs.cbegin(), input_defs.cend(), - [](const NodeArg* input_def) { - return IsMLFloat16Tensor(*input_def); - }); + // + // Three tasks here: + // 1. make sure all tensor(float16) inputs and first output coming from or + // going to nodes that will be cast to fp32 + // 2. check the current node is float16 node. + // 3. check the current node has a float32 implementation + // Only return true when all three are satisfied + // + const auto* schema = node.Op(); + if (!schema) { + // no way to know whether it is safe to convert this to fp32, give up + return false; + } - if (num_fp16_inputs == 0) { - break; + const TypeConstraintMap& type_schema = schema->typeConstraintMap(); + InlinedHashMap type_constraint_map; + type_constraint_map.reserve(type_schema.size()); + + // For each formal parameters, there might be 0-n + // actual inputs, this makes it very tricky to find out which + // actual input should map to which formal parameter + + const auto& input_arg_counts = node.InputArgCount(); + const auto& input_defs = node.InputDefs(); + const auto& formal_inputs = schema->inputs(); + const size_t num_inputs = std::min(formal_inputs.size(), input_arg_counts.size()); + + InlinedHashSet fp16_args; + int input_idx_start = 0; + for (size_t formal_idx = 0; + formal_idx < num_inputs; + input_idx_start += input_arg_counts[formal_idx], formal_idx++) { + const auto& type_str = formal_inputs[formal_idx].GetTypeStr(); + TypeConstraintMap::const_iterator it = type_schema.find(type_str); + if (it == type_schema.end()) { + // Don't care about parameter that does not have a type constraint. + continue; } - size_t num_fp16_input_edges = 0; + // type_str is like T, T1 or T2 ... + for (int input_idx = 0; input_idx < input_arg_counts[formal_idx]; input_idx++) { + const size_t idx = static_cast(input_idx_start) + static_cast(input_idx); + ORT_ENFORCE(idx < input_defs.size()); + const NodeArg* input_def = input_defs[idx]; + if (!input_def || !input_def->Exists()) { + continue; + } + if (IsMLFloat16Tensor(*input_def)) { + fp16_args.emplace(static_cast(idx)); + type_constraint_map[type_str] = DataTypeImpl::GetTensorType(); + break; // we don't have multiple tensors feeding into one input + } + type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(input_def->TypeAsProto())); + break; // we don't have multiple tensors feeding into one input + } + } - // check if all nodes providing our fp16 input need to be cast to fp32 - for (auto input_edge = node.InputEdgesBegin(), end = node.InputEdgesEnd(); input_edge != end; ++input_edge) { - const NodeArg& input_def = *input_defs[input_edge->GetDstArgIndex()]; + if (fp16_args.empty()) { + return false; + } - if (IsMLFloat16Tensor(input_def)) { - // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 - if (!NodeNeedsInputCastToFp32(input_edge->GetNode())) { - break; - } - - ++num_fp16_input_edges; + // check if all nodes providing our fp16 input need to be cast to fp32 + for (auto input_edge = node.InputEdgesBegin(), end = node.InputEdgesEnd(); input_edge != end; ++input_edge) { + const int arg_idx = input_edge->GetDstArgIndex(); + if (fp16_args.find(arg_idx) != fp16_args.end()) { + // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 + if (!NodeNeedsInputCastToFp32(input_edge->GetNode())) { + return false; } } + } - // one or more fp16 inputs are coming from a graph input or initializer - if (num_fp16_inputs != num_fp16_input_edges) { - break; + // if we got here all nodes providing our fp16 input/s will be cast to fp32. + // check if the same applies to the nodes consuming our fp16 output. + fp16_args.clear(); + const auto& output_defs = node.OutputDefs(); + const auto& formal_outputs = schema->outputs(); + const size_t num_outputs = std::min(formal_outputs.size(), output_defs.size()); + for (size_t idx = 0; idx < num_outputs; idx++) { + const auto& type_str = formal_outputs[idx].GetTypeStr(); + TypeConstraintMap::const_iterator it = type_schema.find(type_str); + if (it == type_schema.end()) { + // Don't care about parameter that does not have a type constraint. + continue; } - // if we got here all nodes providing our fp16 input/s will be cast to fp32. - // check if the same applies to all nodes consuming our fp16 output. + const NodeArg* output_def = output_defs[idx]; + if (!output_def || !output_def->Exists()) { + continue; + } + if (IsMLFloat16Tensor(*output_def)) { + fp16_args.emplace((int)idx); + type_constraint_map[type_str] = DataTypeImpl::GetTensorType(); + } else { + type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(output_def->TypeAsProto())); + } + } - bool node_has_fp16_output = false; + if (fp16_args.empty()) { + return false; // no fp16 output + } - for (auto output_edge = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge != end; ++output_edge) { - const NodeArg& output_def = *node.OutputDefs()[output_edge->GetSrcArgIndex()]; - if (IsMLFloat16Tensor(output_def)) { - node_has_fp16_output = true; - - // if the node consuming our fp16 output does not need a cast, we should run in fp16 - if (!NodeNeedsInputCastToFp32(output_edge->GetNode())) { - break; - } + for (auto output_edge = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge != end; ++output_edge) { + const int arg_idx = output_edge->GetSrcArgIndex(); + if (fp16_args.find(arg_idx) != fp16_args.end()) { + // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 + if (!NodeNeedsInputCastToFp32(output_edge->GetNode())) { + return false; } } + } - if (node_has_fp16_output) { - // all nodes providing our fp16 input/s will be cast to fp32, and - // we produce one or more fp16 outputs, and all nodes consuming those outputs will be cast to fp32 - isolated_fp16_node = true; - } - } while (false); + // now all fp16 inputs and outputs would have a cast + // make sure fp32 version of the kernel is available. + const KernelCreateInfo* kernel_create_info{}; + const auto lookup_status = cpu_kernel_registry.TryFindKernel( + kCpuExecutionProvider, node.OpType(), node.Domain(), + node.SinceVersion(), type_constraint_map, &kernel_create_info); + if (lookup_status.IsOK() && kernel_create_info != nullptr) { + return true; + } } - return isolated_fp16_node; + return false; } -Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph) { +static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph)) { + if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 node.SetExecutionProviderType(""); } @@ -338,7 +401,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { if (force_cpu_fp32_) - ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph)); + ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_)); GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.h b/onnxruntime/core/optimizer/insert_cast_transformer.h index 86d3a3a960..8be08d5158 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.h +++ b/onnxruntime/core/optimizer/insert_cast_transformer.h @@ -6,6 +6,8 @@ #include "core/graph/graph_viewer.h" #include "core/framework/op_kernel.h" #include "core/optimizer/graph_transformer.h" +#include "core/framework/kernel_registry_manager.h" +#include "core/framework/kernel_registry.h" namespace onnxruntime { @@ -16,19 +18,26 @@ Transformer to insert cast node that casts float16 to float for cpu nodes */ class InsertCastTransformer : public onnxruntime::GraphTransformer { public: - InsertCastTransformer(const std::string& name) + /** + * @brief Initializer + * @param name for logging purpose + * @param cpu_kernel_registry used to query whether an op node can be safely created + */ + InsertCastTransformer(const std::string& name, const KernelRegistry* cpu_kernel_registry) : onnxruntime::GraphTransformer(name), - force_cpu_fp32_(true) { - } + cpu_kernel_registries_(cpu_kernel_registry), + force_cpu_fp32_(cpu_kernel_registry != nullptr) {} private: Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const; + const KernelRegistry* cpu_kernel_registries_; + // Currently because we only have very few cpu kernels support float16, place those nodes on float16 // will introduce many cast between fp32 and fp16, which will slow the execution. // A better solution is to have a cost model to evaluate does it works to place the node on float16. // Here for simplify, we only force the single-node-float16 sub-graph to float32 - bool force_cpu_fp32_; + const bool force_cpu_fp32_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5ae8ab8cf7..b5709698f7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -976,7 +976,13 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // Insert cast node/s. { - InsertCastTransformer insert_cast_transformer{"CastFloat16Transformer"}; + const InlinedVector> kernel_regs = + kernel_registry_manager_.GetKernelRegistriesByProviderType(kCpuExecutionProvider); + const KernelRegistry* cpu_regs = nullptr; + if (!kernel_regs.empty()) { + cpu_regs = kernel_regs[0]; + } + InsertCastTransformer insert_cast_transformer{"CastFloat16Transformer", cpu_regs}; ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(insert_cast_transformer, *session_logger_, graph)); } diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 3f8d37d3ed..0196e3385f 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -8,6 +8,7 @@ #include "gtest/gtest.h" #include "test_utils.h" #include "test/test_environment.h" +#include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/asserts.h" @@ -38,7 +39,7 @@ TEST(TransformerTest, InsertCastGPUTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = true; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -86,7 +87,7 @@ TEST(TransformerTest, InsertCastAllCPUTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = true; EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK()); @@ -123,7 +124,7 @@ TEST(TransformerTest, ThreeInARowRemoval) { // we want to remove 2 of the first 3 ASSERT_TRUE(op_to_count["Cast"] == 4); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -146,7 +147,7 @@ TEST(TransformerTest, RandomNormalLikeWithFloat16Inputs) { ASSERT_TRUE(status.IsOK()) << status; Graph& graph = model->MainGraph(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -166,7 +167,7 @@ TEST(TransformerTest, MultinomialWithFloat16Input) { ASSERT_TRUE(status.IsOK()) << status; Graph& graph = model->MainGraph(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); @@ -186,7 +187,7 @@ TEST(TransformerTest, InsertCastNodeTwice) { ASSERT_TRUE(status.IsOK()) << status; Graph& graph = model->MainGraph(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); // First insert bool modified = false; @@ -279,7 +280,7 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); bool modified = true; EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK()); diff --git a/onnxruntime/test/providers/compare_provider_test_utils.cc b/onnxruntime/test/providers/compare_provider_test_utils.cc index 540b9c0592..ebf06009c4 100644 --- a/onnxruntime/test/providers/compare_provider_test_utils.cc +++ b/onnxruntime/test/providers/compare_provider_test_utils.cc @@ -63,7 +63,7 @@ void CompareOpTester::CompareWithCPU(const std::string& target_provider_type, // the function body is instead used for CPU pass. This option allows the comparison with // the CPU kernel by adding the input/output casts before looking for a registered CPU kernel. if (need_cpu_cast) { - InsertCastTransformer transformer("Test"); + InsertCastTransformer transformer("Test", GetExecutionProvider(kCpuExecutionProvider)->GetKernelRegistry().get()); bool modified = false; status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); ASSERT_TRUE(status.IsOK());