mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Refine node selection logic in ShapeToInitializer optimizer (#1219)
* Initial commit * Modify existing test to cover the change * PR feedback * PR feedback
This commit is contained in:
parent
e84cb7b579
commit
723e30b361
3 changed files with 40 additions and 19 deletions
|
|
@ -59,15 +59,16 @@ bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node)
|
|||
return false;
|
||||
}
|
||||
|
||||
// The shape of the input has to be statically known. Moreover, each dimension should have a specific value
|
||||
// (the rule cannot be applied if one of the dimension is a symbolic variable).
|
||||
// The shape of the input has to be statically known. Moreover, each dimension should have a meaningful value
|
||||
// (the rule cannot be applied if one of the dimensions has a negative value or if it is a symbolic variable).
|
||||
const auto* input_shape = node.InputDefs()[0]->Shape();
|
||||
if (!input_shape) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0, num_dims = input_shape->dim_size(); i < num_dims; i++) {
|
||||
if (!input_shape->dim(i).has_dim_value()) {
|
||||
const auto& input_dim = input_shape->dim(i);
|
||||
if (!input_dim.has_dim_value() || input_dim.dim_value() < 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ TEST(GraphTransformationTests, ShapeToInitializer) {
|
|||
ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
|
||||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 3);
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 4);
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
|
|
@ -141,8 +141,10 @@ TEST(GraphTransformationTests, ShapeToInitializer) {
|
|||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
// One of the Shapes is not eliminated because it inlcludes a symbolic dimension.
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 1);
|
||||
// Two of the Shapes are not eliminated because:
|
||||
// One includes a symbolic dimension.
|
||||
// Another one includes a negative dimension
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 2);
|
||||
}
|
||||
|
||||
// Check transformations in the case of a subgraph with constant inputs.
|
||||
|
|
|
|||
|
|
@ -1,15 +1,20 @@
|
|||
lotus-transfomrs:â
|
||||
lotus-transfomrs:Å
|
||||
|
||||
AC"Shape
|
||||
AD"Shape
|
||||
|
||||
BD"Shape
|
||||
BE"Shape
|
||||
|
||||
CG"Shape
|
||||
|
||||
C
|
||||
DE"Add
|
||||
D
|
||||
EF"Add
|
||||
|
||||
F
|
||||
GH"Add
|
||||
|
||||
EF"Shape
|
||||
HI"Shape
|
||||
|
||||
FG"Identity shape-addZ
|
||||
IJ"Identity shape-addZ
|
||||
A
|
||||
|
||||
|
||||
|
|
@ -19,24 +24,37 @@
|
|||
|
||||
|
||||
N
|
||||
b
|
||||
G
|
||||
Z
|
||||
C
|
||||
|
||||
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿb
|
||||
J
|
||||
|
||||
|
||||
j
|
||||
C
|
||||
|
||||
|
||||
j
|
||||
D
|
||||
|
||||
|
||||
j
|
||||
E
|
||||
|
||||
|
||||
j
|
||||
G
|
||||
|
||||
|
||||
j
|
||||
F
|
||||
|
||||
|
||||
j
|
||||
H
|
||||
|
||||
|
||||
j
|
||||
I
|
||||
|
||||
|
||||
B
|
||||
|
|
|
|||
Loading…
Reference in a new issue