diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index d2781bcbea..cf473eb83b 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -8,44 +8,30 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -class IdGenerator { - public: - int Next() { - return id++; - } - - private: - int id = 0; -}; - bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const { - //If the node's input is float16 and currently the node is not assigned to any XP. - //we need insert a cast to float, and put the node on CPU for default behavior. - //TODO: a better check is to check does the CPU kernel with float exist or not. + // If the node's input is float16 and currently the node is not assigned to any XP. + // we need insert a cast to float, and put the node on CPU for default behavior. + // TODO: a better check is to check does the CPU kernel with float exist or not. return input->Type() != nullptr && DataTypeImpl::TypeFromProto(*input->TypeAsProto()) == DataTypeImpl::GetTensorType() && node->GetExecutionProviderType().empty(); } onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, - IdGenerator& id_generator, onnxruntime::NodeArg* old_arg, TypeProto* new_type, bool new_on_input, int64_t to_type, onnxruntime::ProviderType providerType) { - //insert cast op to cast input - int id = id_generator.Next(); + // insert cast op to cast input + std::string node_name = graph.GenerateNodeName("Inserted_Cast"); - char str[32]; - snprintf(str, 32, "CastDef_%d", id); - - auto* new_arg = &graph.GetOrCreateNodeArg(str, new_type); + auto* new_arg = &graph.GetOrCreateNodeArg(node_name, new_type); std::vector input_defs = {new_on_input ? new_arg : old_arg}; std::vector output_defs = {new_on_input ? old_arg : new_arg}; - auto& cast_node = graph.AddNode(str, "Cast", "cast node to cast from float16 to float32 on cpu", input_defs, output_defs); + auto& cast_node = graph.AddNode(node_name, "Cast", "cast node to cast from float16 to float32 on cpu", input_defs, output_defs); cast_node.AddAttribute("to", to_type); cast_node.SetExecutionProviderType(providerType); return new_arg; @@ -84,7 +70,7 @@ Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph) { } for (auto& node : graph.Nodes()) { - if (IsSingleInputNodeFloat16Node(node)) { + if (node.OpType() != "Cast" && IsSingleInputNodeFloat16Node(node)) { node.SetExecutionProviderType(""); } } @@ -205,8 +191,8 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie TypeProto float_tensor_proto; float_16_tensor_proto.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16); float_tensor_proto.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); - IdGenerator id_generator; std::map input_def_updates; + for (onnxruntime::NodeIndex i : order) { auto node = graph.GetNode(i); if (!node) @@ -221,9 +207,8 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie if (input_def_updates.count(src_arg)) { replacement_defs[src_arg] = input_def_updates[src_arg]; } else { - //insert cast op to cast input + // insert cast op to cast input auto dst_arg = AddCastNode(graph, - id_generator, src_arg, &float_tensor_proto, false, @@ -271,10 +256,9 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie if (output->Type() && DataTypeImpl::TypeFromProto(*output->TypeAsProto()) == DataTypeImpl::GetTensorType() && casted) { - //insert cast op to cast output back to float16 + // insert cast op to cast output back to float16 auto dst_arg = output; auto src_arg = AddCastNode(graph, - id_generator, dst_arg, &float_tensor_proto, true, @@ -283,7 +267,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie replacement_defs[dst_arg] = src_arg; } } - + node->ReplaceDefs(replacement_defs); modified = modified || casted; diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.h b/onnxruntime/core/optimizer/insert_cast_transformer.h index 6eec898c0b..86d3a3a960 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.h +++ b/onnxruntime/core/optimizer/insert_cast_transformer.h @@ -23,7 +23,6 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer { private: Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; - bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const; // Currently because we only have very few cpu kernels support float16, place those nodes on float16 diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index e493784fa5..e6924acf0d 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -174,5 +174,36 @@ TEST(TransformerTest, MultinomialWithFloat16Input) { EXPECT_TRUE(status.IsOK()) << status; } +// This test is to test insert_cast_transform the same graph twice +// insert_cast_transform needs to detect existing Cast Node +// Prevent inserting the same Cast node twice +TEST(TransformerTest, InsertCastNodeTwice) { + auto model_uri = MODEL_FOLDER ORT_TSTR("insert_cast_twice.onnx"); + std::shared_ptr model; + auto status = Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(status.IsOK()) << status; + + Graph& graph = model->MainGraph(); + InsertCastTransformer transformer("Test"); + + // First insert + bool modified = false; + status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(status.IsOK()) << status; + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_TRUE(modified) << "Transformer should have added some Cast nodes"; + EXPECT_TRUE(op_to_count["Cast"] == 5) << "Insert 3 more Cast nodes."; + + // Second insert + modified = false; + status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(status.IsOK()) << status; + op_to_count = CountOpsInGraph(graph); + // Same graph without modification; The number of Cast node remains + EXPECT_TRUE(!modified) << "Transformer should not modify the modfied graph again"; + EXPECT_TRUE(op_to_count["Cast"] == 5) << "Remain the same number of Cast node"; + +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/transform/insert_cast_twice.onnx b/onnxruntime/test/testdata/transform/insert_cast_twice.onnx new file mode 100644 index 0000000000..6e39561a19 Binary files /dev/null and b/onnxruntime/test/testdata/transform/insert_cast_twice.onnx differ