Avoid "infinite" loop in optimizer (#3321)

* Avoid "infinite" loop in optimizer

When symbolic dimensions are present and can be overridden,
FreeDimensionOverrideTransformer always sets modified flag to true. As a
consequence, the optimizer loops until the iteration limit is reached.
This commit is contained in:
Maxim Kalinin 2020-03-30 15:37:00 -07:00 committed by GitHub
parent 06fc9506fd
commit f2ca2b2981
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 13 deletions

View file

@ -40,6 +40,7 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified,
// Construct a new shape for this input, replacing free dimensions with their overrides
onnx::TensorShapeProto new_shape;
bool shape_modified = false;
for (int32_t dim_index = 0; dim_index < input_shape->dim_size(); ++dim_index) {
const auto& dimension = input_shape->dim(dim_index);
@ -56,28 +57,32 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified,
int64_t dimension_override = it->second;
// If this dimension actually has a value but it doesn't match the override value, return an
// error.
if (dimension.has_dim_value() && dimension.dim_value() != dimension_override) {
if (dimension.has_dim_value()) {
// If this dimension actually has a value but it doesn't match the override value, return an
// error.
if (dimension.dim_value() != dimension_override) {
LOGS(logger, ERROR) << "The model has input '" << graph_input->Name() << "' "
<< "with a fixed dimension denotation '" << dimension.denotation() << "' "
<< "but the size of this dimension " << dimension.dim_value() << " "
<< "does not equal the specified override of" << dimension_override << ".";
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
}
} else {
// Set the dimension override
new_dimension->set_dim_value(dimension_override);
shape_modified = true;
}
// Set the dimension override
new_dimension->clear_dim_param();
new_dimension->set_dim_value(dimension_override);
}
}
// Set the new shape
auto* mutable_graph_input = graph.GetNodeArg(graph_input->Name());
assert(mutable_graph_input != nullptr);
mutable_graph_input->SetShape(new_shape);
modified = true;
if (shape_modified) {
// Set the new shape
auto* mutable_graph_input = graph.GetNodeArg(graph_input->Name());
assert(mutable_graph_input != nullptr);
mutable_graph_input->SetShape(new_shape);
modified = true;
}
}
return Status::OK();

View file

@ -58,6 +58,12 @@ TEST(FreeDimensionOverrideTransformerTest, Test) {
ASSERT_TRUE(input_shape->dim(1).denotation() == onnx::DATA_CHANNEL);
ASSERT_TRUE(input_shape->dim(1).has_dim_value());
ASSERT_TRUE(input_shape->dim(1).dim_value() == 42);
graph_transformer = onnxruntime::make_unique<FreeDimensionOverrideTransformer>(overrides);
bool modified = false;
ASSERT_TRUE(graph_transformer->Apply(graph, modified,
DefaultLoggingManager().DefaultLogger()).IsOK());
ASSERT_FALSE(modified); // no overrides apply anymore
}
} // namespace test