diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index e592b1f098..3bc74f57ae 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -3,11 +3,14 @@ #pragma once +#include + #include "core/optimizer/graph_transformer.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/rewrite_rule.h" namespace onnxruntime { +struct FreeDimensionOverride; namespace transformer_utils { @@ -21,6 +24,7 @@ std::vector> GenerateRewriteRules(TransformerLevel If transformers_and_rules_to_enable is not empty, it returns the intersection between the predefined transformers/rules and the transformers_and_rules_to_enable. */ std::vector> GenerateTransformers(TransformerLevel level, + gsl::span free_dimension_overrides, const std::vector& rules_and_transformers_to_enable = {}); /** Given a TransformerLevel, this method generates a name for the rule-based graph transformer of that level. */ diff --git a/onnxruntime/core/optimizer/free_dim_override_transformer.cc b/onnxruntime/core/optimizer/free_dim_override_transformer.cc new file mode 100644 index 0000000000..1804e9cae4 --- /dev/null +++ b/onnxruntime/core/optimizer/free_dim_override_transformer.cc @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/logging/logging.h" +#include "core/session/inference_session.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/free_dim_override_transformer.h" + +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::common; +namespace onnxruntime { + +static std::string ToLower(std::string s) { + std::transform(s.begin(), s.end(), s.begin(), [](char c) { + return static_cast(::tolower(c)); + }); + + return s; +} + +/*explicit*/ FreeDimensionOverrideTransformer::FreeDimensionOverrideTransformer(gsl::span overrides_to_apply) + : GraphTransformer("FreeDimensionOverrideTransformer") { + for (const auto& o : overrides_to_apply) { + // Convert to lowercase to perform case-insensitive comparisons later + std::string denotation = ToLower(o.dimension_denotation); + + dimension_override_by_denotation_.emplace(denotation, o.dimension_override); + } +} + +Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/) const { + for (const onnxruntime::NodeArg* graph_input : graph.GetInputs()) { + // Get the current input's type and shape + const auto* input_type = graph_input->TypeAsProto(); + const auto* input_shape = graph_input->Shape(); + + if (!input_type || !input_shape || !input_type->has_tensor_type()) { + continue; + } + + // Construct a new shape for this input, replacing free dimensions with their overrides + onnx::TensorShapeProto new_shape; + for (int32_t dim_index = 0; dim_index < input_shape->dim_size(); ++dim_index) { + const auto& dimension = input_shape->dim(dim_index); + + // By default just make a copy of the dimension + auto* new_dimension = new_shape.add_dim(); + *new_dimension = dimension; + + if (dimension.has_denotation()) { + // Convert to lowercase to perform case-insensitive comparison + auto it = dimension_override_by_denotation_.find(ToLower(dimension.denotation())); + if (it == dimension_override_by_denotation_.end()) { + continue; + } + + int64_t dimension_override = it->second; + + // If this dimension actually has a value but it doesn't match the override value, return an + // error. + if (dimension.has_dim_value() && dimension.dim_value() != dimension_override) { + LOGS_DEFAULT(ERROR) << "The model has input '" << graph_input->Name() << "' " + << "with a fixed dimension denotation '" << dimension.denotation() << "' " + << "but the size of this dimension " << dimension.dim_value() << " " + << "does not equal the specified override of" << dimension_override << "."; + + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override."); + } + + // Set the dimension override + new_dimension->clear_dim_param(); + new_dimension->set_dim_value(dimension_override); + } + } + + // Set the new shape + auto* mutable_graph_input = graph.GetNodeArg(graph_input->Name()); + assert(mutable_graph_input != nullptr); + mutable_graph_input->SetShape(new_shape); + modified = true; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/free_dim_override_transformer.h b/onnxruntime/core/optimizer/free_dim_override_transformer.h new file mode 100644 index 0000000000..d92ac47a08 --- /dev/null +++ b/onnxruntime/core/optimizer/free_dim_override_transformer.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/common/common.h" +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +struct FreeDimensionOverride; + +/** +@Class FreeDimensionOverrideTransformer + +Transformer that overrides free dimensions in the graph with the specific value +that matches the denotation for that dimension. +*/ +class FreeDimensionOverrideTransformer : public GraphTransformer { + public: + explicit FreeDimensionOverrideTransformer(gsl::span overrides_to_apply); + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + + std::map dimension_override_by_denotation_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 173bd6fc6e..c163b609fd 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -17,7 +17,9 @@ #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/shape_to_initializer.h" #include "core/optimizer/nchwc_transformer.h" +#include "core/optimizer/free_dim_override_transformer.h" #include "core/mlas/inc/mlas.h" +#include "core/session/inference_session.h" namespace onnxruntime { @@ -87,6 +89,7 @@ std::unique_ptr GenerateRuleBasedGraphTransformer(Tra } std::vector> GenerateTransformers(TransformerLevel level, + gsl::span free_dimension_overrides, const std::vector& transformers_and_rules_to_enable) { std::vector> transformers; std::unique_ptr rule_transformer = nullptr; @@ -95,6 +98,7 @@ std::vector> GenerateTransformers(TransformerL std::unordered_set l1_execution_providers = {}; transformers.emplace_back(std::make_unique(l1_execution_providers)); + transformers.emplace_back(std::make_unique(free_dimension_overrides)); rule_transformer = GenerateRuleBasedGraphTransformer(level, transformers_and_rules_to_enable, l1_execution_providers); } break; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5a02b4393a..999a16a7fb 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1037,7 +1037,7 @@ void InferenceSession::AddPredefinedTransformers(GraphTransformerManager& transf const std::vector& custom_list) { auto add_transformers = [&](TransformerLevel level) { // Generate and register transformers for level - auto transformers_to_register = transformer_utils::GenerateTransformers(level, custom_list); + auto transformers_to_register = transformer_utils::GenerateTransformers(level, session_options_.free_dimension_overrides, custom_list); for (auto& entry : transformers_to_register) { transformer_manager.Register(std::move(entry), level); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 5dc5b26fc6..8f704c8b91 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -46,6 +46,12 @@ namespace logging { class LoggingManager; } +struct FreeDimensionOverride +{ + std::string dimension_denotation; + int64_t dimension_override; +}; + /** * Configuration information for a session. */ @@ -93,6 +99,10 @@ struct SessionOptions { // controls the size of the thread pool used to parallelize the execution of nodes (ops) // configuring this makes sense only when you're using parallel executor int inter_op_num_threads = 0; + + // For models with free input dimensions (most commonly batch size), specifies a set of values to override those + // free dimensions with, keyed by dimension denotation. + std::vector free_dimension_overrides; }; /** diff --git a/onnxruntime/test/optimizer/free_dimension_override_test.cc b/onnxruntime/test/optimizer/free_dimension_override_test.cc new file mode 100644 index 0000000000..c898d1fc30 --- /dev/null +++ b/onnxruntime/test/optimizer/free_dimension_override_test.cc @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/optimizer/graph_transformer.h" +#include "core/optimizer/graph_transformer_mgr.h" +#include "test/framework/test_utils.h" +#include "test/test_environment.h" +#include "gtest/gtest.h" +#include "core/optimizer/free_dim_override_transformer.h" +#include "core/session/inference_session.h" + +using namespace std; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace test { + +TEST(FreeDimensionOverrideTransformerTest, Test) { + string model_uri = "testdata/abs_free_dimensions.onnx"; + + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + Graph& graph = model->MainGraph(); + + // The model's input shape has two free dimensions, which have the denotation of DATA_BATCH + // and DATA_CHANNEL. Supplying these overrides to the transformer should replace those free + // dimensions with values of 1 and 42, respectively. + std::vector overrides = + { + FreeDimensionOverride{ onnx::DATA_BATCH, 1 }, + FreeDimensionOverride{ onnx::DATA_CHANNEL, 42 }, + }; + + auto graph_transformer = std::make_unique(overrides); + + onnxruntime::GraphTransformerManager graph_transformation_mgr(5); + graph_transformation_mgr.Register(std::move(graph_transformer), TransformerLevel::Level1); + + graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1); + + // Verify that the shape of the input graph has the correct values + + const auto& graph_inputs = graph.GetInputs(); + ASSERT_TRUE(graph_inputs.size() == 1); // This model only has a single input ('x') + + const auto* input_shape = graph_inputs[0]->Shape(); + ASSERT_TRUE(input_shape->dim_size() == 3); // Model takes a 3D tensor as input; two of those dimensions are (were) free dimensions + + ASSERT_TRUE(input_shape->dim(0).denotation() == onnx::DATA_BATCH); + ASSERT_TRUE(input_shape->dim(0).has_dim_value()); + ASSERT_TRUE(input_shape->dim(0).dim_value() == 1); + + ASSERT_TRUE(input_shape->dim(1).denotation() == onnx::DATA_CHANNEL); + ASSERT_TRUE(input_shape->dim(1).has_dim_value()); + ASSERT_TRUE(input_shape->dim(1).dim_value() == 42); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index b554224b0b..5f7d0f747f 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -40,7 +40,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { std::string l2_rule1 = "ConvBNFusion"; std::vector custom_list = {l1_rule1, l1_transformer, l2_rule1}; - auto transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level1, custom_list); + auto transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level1, {}, custom_list); ASSERT_TRUE(transformers.size() == 2); auto l1_rule_transformer_name = transformer_utils::GenerateRuleBasedTransformerName(TransformerLevel::Level1); RuleBasedGraphTransformer* rule_transformer = nullptr; @@ -51,7 +51,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { } ASSERT_TRUE(rule_transformer && rule_transformer->RulesCount() == 1); - transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, custom_list); + transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, {}, custom_list); ASSERT_TRUE(transformers.size() == 1); rule_transformer = dynamic_cast(transformers[0].get()); ASSERT_TRUE(rule_transformer->RulesCount() == 1); diff --git a/onnxruntime/test/testdata/abs_free_dimensions.onnx b/onnxruntime/test/testdata/abs_free_dimensions.onnx new file mode 100644 index 0000000000..7d4c041a01 --- /dev/null +++ b/onnxruntime/test/testdata/abs_free_dimensions.onnx @@ -0,0 +1,14 @@ + backend-test:s + +xy"Abstest_absZ9 +x4 +2. +None +DATA_BATCH +None DATA_CHANNEL +b +y + +None +None +B \ No newline at end of file