From 723e30b361101b24efae02855d479213ac24db47 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 17 Jun 2019 19:17:53 -0700 Subject: [PATCH] Refine node selection logic in ShapeToInitializer optimizer (#1219) * Initial commit * Modify existing test to cover the change * PR feedback * PR feedback --- .../core/optimizer/shape_to_initializer.cc | 7 +-- .../test/optimizer/graph_transform_test.cc | 8 ++-- .../test/testdata/transform/shape-add.onnx | 44 +++++++++++++------ 3 files changed, 40 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/optimizer/shape_to_initializer.cc b/onnxruntime/core/optimizer/shape_to_initializer.cc index 66a8a77575..a479d7816d 100644 --- a/onnxruntime/core/optimizer/shape_to_initializer.cc +++ b/onnxruntime/core/optimizer/shape_to_initializer.cc @@ -59,15 +59,16 @@ bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node) return false; } - // The shape of the input has to be statically known. Moreover, each dimension should have a specific value - // (the rule cannot be applied if one of the dimension is a symbolic variable). + // The shape of the input has to be statically known. Moreover, each dimension should have a meaningful value + // (the rule cannot be applied if one of the dimensions has a negative value or if it is a symbolic variable). const auto* input_shape = node.InputDefs()[0]->Shape(); if (!input_shape) { return false; } for (int i = 0, num_dims = input_shape->dim_size(); i < num_dims; i++) { - if (!input_shape->dim(i).has_dim_value()) { + const auto& input_dim = input_shape->dim(i); + if (!input_dim.has_dim_value() || input_dim.dim_value() < 0) { return false; } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 16febfd106..e6fffbdb07 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -131,7 +131,7 @@ TEST(GraphTransformationTests, ShapeToInitializer) { ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Shape"] == 3); + ASSERT_TRUE(op_to_count["Shape"] == 4); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); @@ -141,8 +141,10 @@ TEST(GraphTransformationTests, ShapeToInitializer) { ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); op_to_count = CountOpsInGraph(graph); - // One of the Shapes is not eliminated because it inlcludes a symbolic dimension. - ASSERT_TRUE(op_to_count["Shape"] == 1); + // Two of the Shapes are not eliminated because: + // One includes a symbolic dimension. + // Another one includes a negative dimension + ASSERT_TRUE(op_to_count["Shape"] == 2); } // Check transformations in the case of a subgraph with constant inputs. diff --git a/onnxruntime/test/testdata/transform/shape-add.onnx b/onnxruntime/test/testdata/transform/shape-add.onnx index bbe25966a2..6a718dae60 100644 --- a/onnxruntime/test/testdata/transform/shape-add.onnx +++ b/onnxruntime/test/testdata/transform/shape-add.onnx @@ -1,15 +1,20 @@ -lotus-transfomrs:â +lotus-transfomrs:Å -AC"Shape +AD"Shape -BD"Shape +BE"Shape + +CG"Shape  -C -DE"Add +D +EF"Add + +F +GH"Add -EF"Shape +HI"Shape  -FG"Identity shape-addZ +IJ"Identity shape-addZ A   @@ -19,24 +24,37 @@   N -b -G +Z +C + + + + ÿÿÿÿÿÿÿÿÿb +J  j -C - - -j D  j E + +j +G +  j F + +j +H + + +j +I +  B