diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 2f9e6fbfd6..6e6ef53a35 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -99,10 +99,14 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override { - std::map replacement_defs; std::vector removed_nodes; for (auto& node : graph.Nodes()) { + if (std::find(removed_nodes.cbegin(), removed_nodes.cend(), node.Index()) != removed_nodes.cend()) { + // node has already been marked for removal, and any following node updated so we need to ignore it here + continue; + } + if (node.OpType() == "Cast") { // if cast's next node is also cast and next cast's output type equal to cast's input type // remove those two cast. @@ -138,8 +142,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { num_child++; } - if (child_removed == num_child && - child_removed > 0 && + if (child_removed == num_child && + child_removed > 0 && graph_outputs.find(node.OutputDefs()[0]) == graph_outputs.end()) { removed_nodes.push_back(node.Index()); } @@ -158,7 +162,6 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { }; Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const { - if (force_cpu_fp32_) ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph)); @@ -241,7 +244,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie if (modified) { ORT_RETURN_IF_ERROR(graph.Resolve()); } - + RemoveDuplicateCastTransformer remover; // RemoveDuplicateCastTransformer is a special transformer required for correctness. // It is provider agnostic so simply send an empty vector. diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 79ddf5f1b3..981fed9e21 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -11,6 +11,9 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { + +static const std::string MODEL_FOLDER = "testdata/transform/"; + typedef std::vector ArgMap; TEST(TransformerTest, InsertCastGPUTest) { auto model = std::make_shared("test"); @@ -103,5 +106,32 @@ TEST(TransformerTest, InsertCastAllCPUTest) { EXPECT_EQ((*it).OpType(), "Cast"); } } + +// test that when there are 3 Cast ops in a row we remove the correct ones +TEST(TransformerTest, ThreeInARowRemoval) { + std::string model_uri = MODEL_FOLDER + "triple-cast.onnx"; + std::shared_ptr model; + auto status = Model::Load(model_uri, model); + ASSERT_TRUE(status.IsOK()) << status; + + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + // there are 3 in a row prior to a Transpose, and one post-Transpose. + // we want to remove 2 of the first 3 + ASSERT_TRUE(op_to_count["Cast"] == 4); + + InsertCastTransformer transformer("Test"); + + bool modified = false; + status = transformer.Apply(graph, modified); + EXPECT_TRUE(status.IsOK()) << status; + EXPECT_TRUE(modified) << "Transformer should have removed some Cast nodes"; + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status; + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Cast"] == 2); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/test_utils.cc b/onnxruntime/test/framework/test_utils.cc index b9d09975c0..f39498999c 100644 --- a/onnxruntime/test/framework/test_utils.cc +++ b/onnxruntime/test/framework/test_utils.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "test_utils.h" +#include "core/graph/graph.h" + namespace onnxruntime { namespace test { IExecutionProvider* TestCPUExecutionProvider() { @@ -27,10 +29,22 @@ IExecutionProvider* TestTensorrtExecutionProvider() { #ifdef USE_OPENVINO IExecutionProvider* TestOpenVINOExecutionProvider() { - static OpenVINOExecutionProviderInfo info; - static OpenVINOExecutionProvider openvino_provider(info); - return &openvino_provider; + static OpenVINOExecutionProviderInfo info; + static OpenVINOExecutionProvider openvino_provider(info); + return &openvino_provider; } #endif + +// Returns a map with the number of occurrences of each operator in the graph. +// Helper function to check that the graph transformations have been successfully applied. +std::map CountOpsInGraph(const Graph& graph) { + std::map op_to_count; + for (auto& node : graph.Nodes()) { + op_to_count[node.OpType()] = + op_to_count.count(node.OpType()) == 0 ? 1 : ++op_to_count[node.OpType()]; + } + return op_to_count; +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/test_utils.h b/onnxruntime/test/framework/test_utils.h index fb5ca035cd..82dc329406 100644 --- a/onnxruntime/test/framework/test_utils.h +++ b/onnxruntime/test/framework/test_utils.h @@ -2,14 +2,18 @@ // Licensed under the MIT License. #pragma once +#include +#include + #include "core/framework/allocatormgr.h" #include "core/framework/execution_provider.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/framework/ml_value.h" + #ifdef USE_CUDA #include "core/providers/cuda/cuda_execution_provider.h" #endif -#ifdef USE_TENSORRT +#ifdef USE_TENSORRT #include "core/providers/tensorrt/tensorrt_execution_provider.h" #endif #ifdef USE_OPENVINO @@ -17,6 +21,8 @@ #endif namespace onnxruntime { +class Graph; + namespace test { // Doesn't work with ExecutionProviders class and KernelRegistryManager IExecutionProvider* TestCPUExecutionProvider(); @@ -62,5 +68,10 @@ void AllocateMLValue(AllocatorPtr alloc, const std::vector& dims, OrtVa DataTypeImpl::GetType(), DataTypeImpl::GetType()->GetDeleteFunc()); } + +// Returns a map with the number of occurrences of each operator in the graph. +// Helper function to check that the graph transformations have been successfully applied. +std::map CountOpsInGraph(const Graph& graph); + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e6fffbdb07..7cbdd59a16 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -38,16 +38,6 @@ namespace test { static const std::string MODEL_FOLDER = "testdata/transform/"; -// Returns a map with the number of occurrences of each operator in the graph. -// Helper function to check that the graph transformations have been successfully applied. -std::map CountOpsInGraph(const Graph& graph) { - std::map op_to_count; - for (auto& node : graph.Nodes()) { - op_to_count[node.OpType()] = - op_to_count.count(node.OpType()) == 0 ? 1 : ++op_to_count[node.OpType()]; - } - return op_to_count; -} TEST(GraphTransformationTests, IdentityElimination) { string model_uri = MODEL_FOLDER + "abs-id-max.onnx"; std::shared_ptr model; @@ -141,7 +131,7 @@ TEST(GraphTransformationTests, ShapeToInitializer) { ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); op_to_count = CountOpsInGraph(graph); - // Two of the Shapes are not eliminated because: + // Two of the Shapes are not eliminated because: // One includes a symbolic dimension. // Another one includes a negative dimension ASSERT_TRUE(op_to_count["Shape"] == 2); diff --git a/onnxruntime/test/testdata/transform/triple-cast.onnx b/onnxruntime/test/testdata/transform/triple-cast.onnx new file mode 100644 index 0000000000..4970e622d6 Binary files /dev/null and b/onnxruntime/test/testdata/transform/triple-cast.onnx differ