mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Fix bug in the transformer that removes unnecessary Cast nodes where it was re-processing removed nodes leading to multiple calls to RemoveNode for the same node. (#1291)
Description: The remove duplicate Cast logic was processing a node already removed, leading to multiple calls to remove the same node causing an error. Add a check so that nodes marked for removal are skipped. Motivation and Context If a model has 3 Cast nodes in a row the bug would cause an exception to be thrown due to multiple calls to remove the same node. This causes the latest optimized tf2onnx conversion of ssd_mobilenet to break.
This commit is contained in:
parent
c9d83a52a8
commit
86dc3b4360
6 changed files with 68 additions and 20 deletions
|
|
@ -99,10 +99,14 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override {
|
||||
|
||||
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacement_defs;
|
||||
std::vector<onnxruntime::NodeIndex> removed_nodes;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (std::find(removed_nodes.cbegin(), removed_nodes.cend(), node.Index()) != removed_nodes.cend()) {
|
||||
// node has already been marked for removal, and any following node updated so we need to ignore it here
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node.OpType() == "Cast") {
|
||||
// if cast's next node is also cast and next cast's output type equal to cast's input type
|
||||
// remove those two cast.
|
||||
|
|
@ -138,8 +142,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
num_child++;
|
||||
}
|
||||
|
||||
if (child_removed == num_child &&
|
||||
child_removed > 0 &&
|
||||
if (child_removed == num_child &&
|
||||
child_removed > 0 &&
|
||||
graph_outputs.find(node.OutputDefs()[0]) == graph_outputs.end()) {
|
||||
removed_nodes.push_back(node.Index());
|
||||
}
|
||||
|
|
@ -158,7 +162,6 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
};
|
||||
|
||||
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const {
|
||||
|
||||
if (force_cpu_fp32_)
|
||||
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph));
|
||||
|
||||
|
|
@ -241,7 +244,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
|
|||
if (modified) {
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
}
|
||||
|
||||
|
||||
RemoveDuplicateCastTransformer remover;
|
||||
// RemoveDuplicateCastTransformer is a special transformer required for correctness.
|
||||
// It is provider agnostic so simply send an empty vector.
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@
|
|||
using namespace ONNX_NAMESPACE;
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
static const std::string MODEL_FOLDER = "testdata/transform/";
|
||||
|
||||
typedef std::vector<onnxruntime::NodeArg*> ArgMap;
|
||||
TEST(TransformerTest, InsertCastGPUTest) {
|
||||
auto model = std::make_shared<onnxruntime::Model>("test");
|
||||
|
|
@ -103,5 +106,32 @@ TEST(TransformerTest, InsertCastAllCPUTest) {
|
|||
EXPECT_EQ((*it).OpType(), "Cast");
|
||||
}
|
||||
}
|
||||
|
||||
// test that when there are 3 Cast ops in a row we remove the correct ones
|
||||
TEST(TransformerTest, ThreeInARowRemoval) {
|
||||
std::string model_uri = MODEL_FOLDER + "triple-cast.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
auto status = Model::Load(model_uri, model);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
// there are 3 in a row prior to a Transpose, and one post-Transpose.
|
||||
// we want to remove 2 of the first 3
|
||||
ASSERT_TRUE(op_to_count["Cast"] == 4);
|
||||
|
||||
InsertCastTransformer transformer("Test");
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified);
|
||||
EXPECT_TRUE(status.IsOK()) << status;
|
||||
EXPECT_TRUE(modified) << "Transformer should have removed some Cast nodes";
|
||||
status = graph.Resolve();
|
||||
EXPECT_TRUE(status.IsOK()) << status;
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Cast"] == 2);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "test_utils.h"
|
||||
#include "core/graph/graph.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
IExecutionProvider* TestCPUExecutionProvider() {
|
||||
|
|
@ -27,10 +29,22 @@ IExecutionProvider* TestTensorrtExecutionProvider() {
|
|||
|
||||
#ifdef USE_OPENVINO
|
||||
IExecutionProvider* TestOpenVINOExecutionProvider() {
|
||||
static OpenVINOExecutionProviderInfo info;
|
||||
static OpenVINOExecutionProvider openvino_provider(info);
|
||||
return &openvino_provider;
|
||||
static OpenVINOExecutionProviderInfo info;
|
||||
static OpenVINOExecutionProvider openvino_provider(info);
|
||||
return &openvino_provider;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Returns a map with the number of occurrences of each operator in the graph.
|
||||
// Helper function to check that the graph transformations have been successfully applied.
|
||||
std::map<std::string, int> CountOpsInGraph(const Graph& graph) {
|
||||
std::map<std::string, int> op_to_count;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
op_to_count[node.OpType()] =
|
||||
op_to_count.count(node.OpType()) == 0 ? 1 : ++op_to_count[node.OpType()];
|
||||
}
|
||||
return op_to_count;
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,14 +2,18 @@
|
|||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "core/framework/allocatormgr.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#endif
|
||||
#ifdef USE_TENSORRT
|
||||
#ifdef USE_TENSORRT
|
||||
#include "core/providers/tensorrt/tensorrt_execution_provider.h"
|
||||
#endif
|
||||
#ifdef USE_OPENVINO
|
||||
|
|
@ -17,6 +21,8 @@
|
|||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
class Graph;
|
||||
|
||||
namespace test {
|
||||
// Doesn't work with ExecutionProviders class and KernelRegistryManager
|
||||
IExecutionProvider* TestCPUExecutionProvider();
|
||||
|
|
@ -62,5 +68,10 @@ void AllocateMLValue(AllocatorPtr alloc, const std::vector<int64_t>& dims, OrtVa
|
|||
DataTypeImpl::GetType<Tensor>(),
|
||||
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
|
||||
}
|
||||
|
||||
// Returns a map with the number of occurrences of each operator in the graph.
|
||||
// Helper function to check that the graph transformations have been successfully applied.
|
||||
std::map<std::string, int> CountOpsInGraph(const Graph& graph);
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -38,16 +38,6 @@ namespace test {
|
|||
|
||||
static const std::string MODEL_FOLDER = "testdata/transform/";
|
||||
|
||||
// Returns a map with the number of occurrences of each operator in the graph.
|
||||
// Helper function to check that the graph transformations have been successfully applied.
|
||||
std::map<std::string, int> CountOpsInGraph(const Graph& graph) {
|
||||
std::map<std::string, int> op_to_count;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
op_to_count[node.OpType()] =
|
||||
op_to_count.count(node.OpType()) == 0 ? 1 : ++op_to_count[node.OpType()];
|
||||
}
|
||||
return op_to_count;
|
||||
}
|
||||
TEST(GraphTransformationTests, IdentityElimination) {
|
||||
string model_uri = MODEL_FOLDER + "abs-id-max.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
@ -141,7 +131,7 @@ TEST(GraphTransformationTests, ShapeToInitializer) {
|
|||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
// Two of the Shapes are not eliminated because:
|
||||
// Two of the Shapes are not eliminated because:
|
||||
// One includes a symbolic dimension.
|
||||
// Another one includes a negative dimension
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 2);
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/triple-cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/triple-cast.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue