mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Avoid "infinite" loop in optimizer (#3321)
* Avoid "infinite" loop in optimizer When symbolic dimensions are present and can be overridden, FreeDimensionOverrideTransformer always sets modified flag to true. As a consequence, the optimizer loops until the iteration limit is reached.
This commit is contained in:
parent
06fc9506fd
commit
f2ca2b2981
2 changed files with 24 additions and 13 deletions
|
|
@ -40,6 +40,7 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified,
|
|||
|
||||
// Construct a new shape for this input, replacing free dimensions with their overrides
|
||||
onnx::TensorShapeProto new_shape;
|
||||
bool shape_modified = false;
|
||||
for (int32_t dim_index = 0; dim_index < input_shape->dim_size(); ++dim_index) {
|
||||
const auto& dimension = input_shape->dim(dim_index);
|
||||
|
||||
|
|
@ -56,28 +57,32 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified,
|
|||
|
||||
int64_t dimension_override = it->second;
|
||||
|
||||
// If this dimension actually has a value but it doesn't match the override value, return an
|
||||
// error.
|
||||
if (dimension.has_dim_value() && dimension.dim_value() != dimension_override) {
|
||||
if (dimension.has_dim_value()) {
|
||||
// If this dimension actually has a value but it doesn't match the override value, return an
|
||||
// error.
|
||||
if (dimension.dim_value() != dimension_override) {
|
||||
LOGS(logger, ERROR) << "The model has input '" << graph_input->Name() << "' "
|
||||
<< "with a fixed dimension denotation '" << dimension.denotation() << "' "
|
||||
<< "but the size of this dimension " << dimension.dim_value() << " "
|
||||
<< "does not equal the specified override of" << dimension_override << ".";
|
||||
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
|
||||
}
|
||||
} else {
|
||||
// Set the dimension override
|
||||
new_dimension->set_dim_value(dimension_override);
|
||||
shape_modified = true;
|
||||
}
|
||||
|
||||
// Set the dimension override
|
||||
new_dimension->clear_dim_param();
|
||||
new_dimension->set_dim_value(dimension_override);
|
||||
}
|
||||
}
|
||||
|
||||
// Set the new shape
|
||||
auto* mutable_graph_input = graph.GetNodeArg(graph_input->Name());
|
||||
assert(mutable_graph_input != nullptr);
|
||||
mutable_graph_input->SetShape(new_shape);
|
||||
modified = true;
|
||||
if (shape_modified) {
|
||||
// Set the new shape
|
||||
auto* mutable_graph_input = graph.GetNodeArg(graph_input->Name());
|
||||
assert(mutable_graph_input != nullptr);
|
||||
mutable_graph_input->SetShape(new_shape);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -58,6 +58,12 @@ TEST(FreeDimensionOverrideTransformerTest, Test) {
|
|||
ASSERT_TRUE(input_shape->dim(1).denotation() == onnx::DATA_CHANNEL);
|
||||
ASSERT_TRUE(input_shape->dim(1).has_dim_value());
|
||||
ASSERT_TRUE(input_shape->dim(1).dim_value() == 42);
|
||||
|
||||
graph_transformer = onnxruntime::make_unique<FreeDimensionOverrideTransformer>(overrides);
|
||||
bool modified = false;
|
||||
ASSERT_TRUE(graph_transformer->Apply(graph, modified,
|
||||
DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
ASSERT_FALSE(modified); // no overrides apply anymore
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue