diff --git a/onnxruntime/core/optimizer/free_dim_override_transformer.cc b/onnxruntime/core/optimizer/free_dim_override_transformer.cc index 2a16526585..ac6a79a372 100644 --- a/onnxruntime/core/optimizer/free_dim_override_transformer.cc +++ b/onnxruntime/core/optimizer/free_dim_override_transformer.cc @@ -40,6 +40,7 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, // Construct a new shape for this input, replacing free dimensions with their overrides onnx::TensorShapeProto new_shape; + bool shape_modified = false; for (int32_t dim_index = 0; dim_index < input_shape->dim_size(); ++dim_index) { const auto& dimension = input_shape->dim(dim_index); @@ -56,28 +57,32 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int64_t dimension_override = it->second; - // If this dimension actually has a value but it doesn't match the override value, return an - // error. - if (dimension.has_dim_value() && dimension.dim_value() != dimension_override) { + if (dimension.has_dim_value()) { + // If this dimension actually has a value but it doesn't match the override value, return an + // error. + if (dimension.dim_value() != dimension_override) { LOGS(logger, ERROR) << "The model has input '" << graph_input->Name() << "' " << "with a fixed dimension denotation '" << dimension.denotation() << "' " << "but the size of this dimension " << dimension.dim_value() << " " << "does not equal the specified override of" << dimension_override << "."; - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override."); + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override."); + } + } else { + // Set the dimension override + new_dimension->set_dim_value(dimension_override); + shape_modified = true; } - - // Set the dimension override - new_dimension->clear_dim_param(); - new_dimension->set_dim_value(dimension_override); } } - // Set the new shape - auto* mutable_graph_input = graph.GetNodeArg(graph_input->Name()); - assert(mutable_graph_input != nullptr); - mutable_graph_input->SetShape(new_shape); - modified = true; + if (shape_modified) { + // Set the new shape + auto* mutable_graph_input = graph.GetNodeArg(graph_input->Name()); + assert(mutable_graph_input != nullptr); + mutable_graph_input->SetShape(new_shape); + modified = true; + } } return Status::OK(); diff --git a/onnxruntime/test/optimizer/free_dimension_override_test.cc b/onnxruntime/test/optimizer/free_dimension_override_test.cc index 71dbc9f8b9..653ab75804 100644 --- a/onnxruntime/test/optimizer/free_dimension_override_test.cc +++ b/onnxruntime/test/optimizer/free_dimension_override_test.cc @@ -58,6 +58,12 @@ TEST(FreeDimensionOverrideTransformerTest, Test) { ASSERT_TRUE(input_shape->dim(1).denotation() == onnx::DATA_CHANNEL); ASSERT_TRUE(input_shape->dim(1).has_dim_value()); ASSERT_TRUE(input_shape->dim(1).dim_value() == 42); + + graph_transformer = onnxruntime::make_unique(overrides); + bool modified = false; + ASSERT_TRUE(graph_transformer->Apply(graph, modified, + DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_FALSE(modified); // no overrides apply anymore } } // namespace test