Fix bug in the transformer that removes unnecessary Cast nodes where it was re-processing removed nodes leading to multiple calls to RemoveNode for the same node. (#1291)

Description:
The remove duplicate Cast logic was processing a node already removed, leading to multiple calls to remove the same node causing an error. Add a check so that nodes marked for removal are skipped.

Motivation and Context
If a model has 3 Cast nodes in a row the bug would cause an exception to be thrown due to multiple calls to remove the same node. This causes the latest optimized tf2onnx conversion of ssd_mobilenet to break.
This commit is contained in:
Scott McKay 2019-06-25 15:17:08 +10:00 committed by GitHub
parent c9d83a52a8
commit 86dc3b4360
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 68 additions and 20 deletions

View file

@ -99,10 +99,14 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override {
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacement_defs;
std::vector<onnxruntime::NodeIndex> removed_nodes;
for (auto& node : graph.Nodes()) {
if (std::find(removed_nodes.cbegin(), removed_nodes.cend(), node.Index()) != removed_nodes.cend()) {
// node has already been marked for removal, and any following node updated so we need to ignore it here
continue;
}
if (node.OpType() == "Cast") {
// if cast's next node is also cast and next cast's output type equal to cast's input type
// remove those two cast.
@ -138,8 +142,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
num_child++;
}
if (child_removed == num_child &&
child_removed > 0 &&
if (child_removed == num_child &&
child_removed > 0 &&
graph_outputs.find(node.OutputDefs()[0]) == graph_outputs.end()) {
removed_nodes.push_back(node.Index());
}
@ -158,7 +162,6 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
};
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const {
if (force_cpu_fp32_)
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph));
@ -241,7 +244,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
if (modified) {
ORT_RETURN_IF_ERROR(graph.Resolve());
}
RemoveDuplicateCastTransformer remover;
// RemoveDuplicateCastTransformer is a special transformer required for correctness.
// It is provider agnostic so simply send an empty vector.

View file

@ -11,6 +11,9 @@
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {
static const std::string MODEL_FOLDER = "testdata/transform/";
typedef std::vector<onnxruntime::NodeArg*> ArgMap;
TEST(TransformerTest, InsertCastGPUTest) {
auto model = std::make_shared<onnxruntime::Model>("test");
@ -103,5 +106,32 @@ TEST(TransformerTest, InsertCastAllCPUTest) {
EXPECT_EQ((*it).OpType(), "Cast");
}
}
// test that when there are 3 Cast ops in a row we remove the correct ones
TEST(TransformerTest, ThreeInARowRemoval) {
std::string model_uri = MODEL_FOLDER + "triple-cast.onnx";
std::shared_ptr<Model> model;
auto status = Model::Load(model_uri, model);
ASSERT_TRUE(status.IsOK()) << status;
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
// there are 3 in a row prior to a Transpose, and one post-Transpose.
// we want to remove 2 of the first 3
ASSERT_TRUE(op_to_count["Cast"] == 4);
InsertCastTransformer transformer("Test");
bool modified = false;
status = transformer.Apply(graph, modified);
EXPECT_TRUE(status.IsOK()) << status;
EXPECT_TRUE(modified) << "Transformer should have removed some Cast nodes";
status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status;
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Cast"] == 2);
}
} // namespace test
} // namespace onnxruntime

View file

@ -2,6 +2,8 @@
// Licensed under the MIT License.
#include "test_utils.h"
#include "core/graph/graph.h"
namespace onnxruntime {
namespace test {
IExecutionProvider* TestCPUExecutionProvider() {
@ -27,10 +29,22 @@ IExecutionProvider* TestTensorrtExecutionProvider() {
#ifdef USE_OPENVINO
IExecutionProvider* TestOpenVINOExecutionProvider() {
static OpenVINOExecutionProviderInfo info;
static OpenVINOExecutionProvider openvino_provider(info);
return &openvino_provider;
static OpenVINOExecutionProviderInfo info;
static OpenVINOExecutionProvider openvino_provider(info);
return &openvino_provider;
}
#endif
// Returns a map with the number of occurrences of each operator in the graph.
// Helper function to check that the graph transformations have been successfully applied.
std::map<std::string, int> CountOpsInGraph(const Graph& graph) {
std::map<std::string, int> op_to_count;
for (auto& node : graph.Nodes()) {
op_to_count[node.OpType()] =
op_to_count.count(node.OpType()) == 0 ? 1 : ++op_to_count[node.OpType()];
}
return op_to_count;
}
} // namespace test
} // namespace onnxruntime

View file

@ -2,14 +2,18 @@
// Licensed under the MIT License.
#pragma once
#include <map>
#include <string>
#include "core/framework/allocatormgr.h"
#include "core/framework/execution_provider.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/framework/ml_value.h"
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_execution_provider.h"
#endif
#ifdef USE_TENSORRT
#ifdef USE_TENSORRT
#include "core/providers/tensorrt/tensorrt_execution_provider.h"
#endif
#ifdef USE_OPENVINO
@ -17,6 +21,8 @@
#endif
namespace onnxruntime {
class Graph;
namespace test {
// Doesn't work with ExecutionProviders class and KernelRegistryManager
IExecutionProvider* TestCPUExecutionProvider();
@ -62,5 +68,10 @@ void AllocateMLValue(AllocatorPtr alloc, const std::vector<int64_t>& dims, OrtVa
DataTypeImpl::GetType<Tensor>(),
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
}
// Returns a map with the number of occurrences of each operator in the graph.
// Helper function to check that the graph transformations have been successfully applied.
std::map<std::string, int> CountOpsInGraph(const Graph& graph);
} // namespace test
} // namespace onnxruntime

View file

@ -38,16 +38,6 @@ namespace test {
static const std::string MODEL_FOLDER = "testdata/transform/";
// Returns a map with the number of occurrences of each operator in the graph.
// Helper function to check that the graph transformations have been successfully applied.
std::map<std::string, int> CountOpsInGraph(const Graph& graph) {
std::map<std::string, int> op_to_count;
for (auto& node : graph.Nodes()) {
op_to_count[node.OpType()] =
op_to_count.count(node.OpType()) == 0 ? 1 : ++op_to_count[node.OpType()];
}
return op_to_count;
}
TEST(GraphTransformationTests, IdentityElimination) {
string model_uri = MODEL_FOLDER + "abs-id-max.onnx";
std::shared_ptr<Model> model;
@ -141,7 +131,7 @@ TEST(GraphTransformationTests, ShapeToInitializer) {
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
op_to_count = CountOpsInGraph(graph);
// Two of the Shapes are not eliminated because:
// Two of the Shapes are not eliminated because:
// One includes a symbolic dimension.
// Another one includes a negative dimension
ASSERT_TRUE(op_to_count["Shape"] == 2);

Binary file not shown.