Refine node selection logic in ShapeToInitializer optimizer (#1219)

* Initial commit

* Modify existing test to cover the change

* PR feedback

* PR feedback
This commit is contained in:
Hariharan Seshadri 2019-06-17 19:17:53 -07:00 committed by GitHub
parent e84cb7b579
commit 723e30b361
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 19 deletions

View file

@ -59,15 +59,16 @@ bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node)
return false;
}
// The shape of the input has to be statically known. Moreover, each dimension should have a specific value
// (the rule cannot be applied if one of the dimension is a symbolic variable).
// The shape of the input has to be statically known. Moreover, each dimension should have a meaningful value
// (the rule cannot be applied if one of the dimensions has a negative value or if it is a symbolic variable).
const auto* input_shape = node.InputDefs()[0]->Shape();
if (!input_shape) {
return false;
}
for (int i = 0, num_dims = input_shape->dim_size(); i < num_dims; i++) {
if (!input_shape->dim(i).has_dim_value()) {
const auto& input_dim = input_shape->dim(i);
if (!input_dim.has_dim_value() || input_dim.dim_value() < 0) {
return false;
}
}

View file

@ -131,7 +131,7 @@ TEST(GraphTransformationTests, ShapeToInitializer) {
ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Shape"] == 3);
ASSERT_TRUE(op_to_count["Shape"] == 4);
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
@ -141,8 +141,10 @@ TEST(GraphTransformationTests, ShapeToInitializer) {
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
op_to_count = CountOpsInGraph(graph);
// One of the Shapes is not eliminated because it inlcludes a symbolic dimension.
ASSERT_TRUE(op_to_count["Shape"] == 1);
// Two of the Shapes are not eliminated because:
// One includes a symbolic dimension.
// Another one includes a negative dimension
ASSERT_TRUE(op_to_count["Shape"] == 2);
}
// Check transformations in the case of a subgraph with constant inputs.

View file

@ -1,15 +1,20 @@
lotus-transfomrs:â
lotus-transfomrs:Å
AC"Shape
AD"Shape
BD"Shape
BE"Shape
CG"Shape

C
DE"Add
D
EF"Add

F
GH"Add
EF"Shape
HI"Shape

FG"Identity shape-addZ
IJ"Identity shape-addZ
A


@ -19,24 +24,37 @@


N
b
G
Z
C



ÿÿÿÿÿÿÿÿÿb
J

j
C

j
D

j
E

j
G

j
F

j
H

j
I

B