mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Implement L1 graph transformer for free dimension override (#1825)
* Implement FreeDimensionOverrideTransformer * Add test * Fix compiler warnings * Update comment * LOGS_DEFAULT * Merge from master
This commit is contained in:
parent
561f2c4a9a
commit
a7beed798e
9 changed files with 213 additions and 3 deletions
|
|
@ -3,11 +3,14 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gsl/span>
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
struct FreeDimensionOverride;
|
||||
|
||||
namespace transformer_utils {
|
||||
|
||||
|
|
@ -21,6 +24,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
|
|||
If transformers_and_rules_to_enable is not empty, it returns the intersection between the predefined transformers/rules
|
||||
and the transformers_and_rules_to_enable. */
|
||||
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerLevel level,
|
||||
gsl::span<const FreeDimensionOverride> free_dimension_overrides,
|
||||
const std::vector<std::string>& rules_and_transformers_to_enable = {});
|
||||
|
||||
/** Given a TransformerLevel, this method generates a name for the rule-based graph transformer of that level. */
|
||||
|
|
|
|||
86
onnxruntime/core/optimizer/free_dim_override_transformer.cc
Normal file
86
onnxruntime/core/optimizer/free_dim_override_transformer.cc
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/free_dim_override_transformer.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
static std::string ToLower(std::string s) {
|
||||
std::transform(s.begin(), s.end(), s.begin(), [](char c) {
|
||||
return static_cast<char>(::tolower(c));
|
||||
});
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
/*explicit*/ FreeDimensionOverrideTransformer::FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply)
|
||||
: GraphTransformer("FreeDimensionOverrideTransformer") {
|
||||
for (const auto& o : overrides_to_apply) {
|
||||
// Convert to lowercase to perform case-insensitive comparisons later
|
||||
std::string denotation = ToLower(o.dimension_denotation);
|
||||
|
||||
dimension_override_by_denotation_.emplace(denotation, o.dimension_override);
|
||||
}
|
||||
}
|
||||
|
||||
Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/) const {
|
||||
for (const onnxruntime::NodeArg* graph_input : graph.GetInputs()) {
|
||||
// Get the current input's type and shape
|
||||
const auto* input_type = graph_input->TypeAsProto();
|
||||
const auto* input_shape = graph_input->Shape();
|
||||
|
||||
if (!input_type || !input_shape || !input_type->has_tensor_type()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Construct a new shape for this input, replacing free dimensions with their overrides
|
||||
onnx::TensorShapeProto new_shape;
|
||||
for (int32_t dim_index = 0; dim_index < input_shape->dim_size(); ++dim_index) {
|
||||
const auto& dimension = input_shape->dim(dim_index);
|
||||
|
||||
// By default just make a copy of the dimension
|
||||
auto* new_dimension = new_shape.add_dim();
|
||||
*new_dimension = dimension;
|
||||
|
||||
if (dimension.has_denotation()) {
|
||||
// Convert to lowercase to perform case-insensitive comparison
|
||||
auto it = dimension_override_by_denotation_.find(ToLower(dimension.denotation()));
|
||||
if (it == dimension_override_by_denotation_.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t dimension_override = it->second;
|
||||
|
||||
// If this dimension actually has a value but it doesn't match the override value, return an
|
||||
// error.
|
||||
if (dimension.has_dim_value() && dimension.dim_value() != dimension_override) {
|
||||
LOGS_DEFAULT(ERROR) << "The model has input '" << graph_input->Name() << "' "
|
||||
<< "with a fixed dimension denotation '" << dimension.denotation() << "' "
|
||||
<< "but the size of this dimension " << dimension.dim_value() << " "
|
||||
<< "does not equal the specified override of" << dimension_override << ".";
|
||||
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
|
||||
}
|
||||
|
||||
// Set the dimension override
|
||||
new_dimension->clear_dim_param();
|
||||
new_dimension->set_dim_value(dimension_override);
|
||||
}
|
||||
}
|
||||
|
||||
// Set the new shape
|
||||
auto* mutable_graph_input = graph.GetNodeArg(graph_input->Name());
|
||||
assert(mutable_graph_input != nullptr);
|
||||
mutable_graph_input->SetShape(new_shape);
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
31
onnxruntime/core/optimizer/free_dim_override_transformer.h
Normal file
31
onnxruntime/core/optimizer/free_dim_override_transformer.h
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gsl/span>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct FreeDimensionOverride;
|
||||
|
||||
/**
|
||||
@Class FreeDimensionOverrideTransformer
|
||||
|
||||
Transformer that overrides free dimensions in the graph with the specific value
|
||||
that matches the denotation for that dimension.
|
||||
*/
|
||||
class FreeDimensionOverrideTransformer : public GraphTransformer {
|
||||
public:
|
||||
explicit FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply);
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
|
||||
std::map<std::string, int64_t> dimension_override_by_denotation_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -17,7 +17,9 @@
|
|||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/optimizer/nchwc_transformer.h"
|
||||
#include "core/optimizer/free_dim_override_transformer.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -87,6 +89,7 @@ std::unique_ptr<RuleBasedGraphTransformer> GenerateRuleBasedGraphTransformer(Tra
|
|||
}
|
||||
|
||||
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerLevel level,
|
||||
gsl::span<const FreeDimensionOverride> free_dimension_overrides,
|
||||
const std::vector<std::string>& transformers_and_rules_to_enable) {
|
||||
std::vector<std::unique_ptr<GraphTransformer>> transformers;
|
||||
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
|
||||
|
|
@ -95,6 +98,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
std::unordered_set<std::string> l1_execution_providers = {};
|
||||
|
||||
transformers.emplace_back(std::make_unique<ConstantFolding>(l1_execution_providers));
|
||||
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(free_dimension_overrides));
|
||||
|
||||
rule_transformer = GenerateRuleBasedGraphTransformer(level, transformers_and_rules_to_enable, l1_execution_providers);
|
||||
} break;
|
||||
|
|
|
|||
|
|
@ -1037,7 +1037,7 @@ void InferenceSession::AddPredefinedTransformers(GraphTransformerManager& transf
|
|||
const std::vector<std::string>& custom_list) {
|
||||
auto add_transformers = [&](TransformerLevel level) {
|
||||
// Generate and register transformers for level
|
||||
auto transformers_to_register = transformer_utils::GenerateTransformers(level, custom_list);
|
||||
auto transformers_to_register = transformer_utils::GenerateTransformers(level, session_options_.free_dimension_overrides, custom_list);
|
||||
for (auto& entry : transformers_to_register) {
|
||||
transformer_manager.Register(std::move(entry), level);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,6 +46,12 @@ namespace logging {
|
|||
class LoggingManager;
|
||||
}
|
||||
|
||||
struct FreeDimensionOverride
|
||||
{
|
||||
std::string dimension_denotation;
|
||||
int64_t dimension_override;
|
||||
};
|
||||
|
||||
/**
|
||||
* Configuration information for a session.
|
||||
*/
|
||||
|
|
@ -93,6 +99,10 @@ struct SessionOptions {
|
|||
// controls the size of the thread pool used to parallelize the execution of nodes (ops)
|
||||
// configuring this makes sense only when you're using parallel executor
|
||||
int inter_op_num_threads = 0;
|
||||
|
||||
// For models with free input dimensions (most commonly batch size), specifies a set of values to override those
|
||||
// free dimensions with, keyed by dimension denotation.
|
||||
std::vector<FreeDimensionOverride> free_dimension_overrides;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
61
onnxruntime/test/optimizer/free_dimension_override_test.cc
Normal file
61
onnxruntime/test/optimizer/free_dimension_override_test.cc
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/graph_transformer_mgr.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "core/optimizer/free_dim_override_transformer.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(FreeDimensionOverrideTransformerTest, Test) {
|
||||
string model_uri = "testdata/abs_free_dimensions.onnx";
|
||||
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
|
||||
Graph& graph = model->MainGraph();
|
||||
|
||||
// The model's input shape has two free dimensions, which have the denotation of DATA_BATCH
|
||||
// and DATA_CHANNEL. Supplying these overrides to the transformer should replace those free
|
||||
// dimensions with values of 1 and 42, respectively.
|
||||
std::vector<FreeDimensionOverride> overrides =
|
||||
{
|
||||
FreeDimensionOverride{ onnx::DATA_BATCH, 1 },
|
||||
FreeDimensionOverride{ onnx::DATA_CHANNEL, 42 },
|
||||
};
|
||||
|
||||
auto graph_transformer = std::make_unique<FreeDimensionOverrideTransformer>(overrides);
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr(5);
|
||||
graph_transformation_mgr.Register(std::move(graph_transformer), TransformerLevel::Level1);
|
||||
|
||||
graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1);
|
||||
|
||||
// Verify that the shape of the input graph has the correct values
|
||||
|
||||
const auto& graph_inputs = graph.GetInputs();
|
||||
ASSERT_TRUE(graph_inputs.size() == 1); // This model only has a single input ('x')
|
||||
|
||||
const auto* input_shape = graph_inputs[0]->Shape();
|
||||
ASSERT_TRUE(input_shape->dim_size() == 3); // Model takes a 3D tensor as input; two of those dimensions are (were) free dimensions
|
||||
|
||||
ASSERT_TRUE(input_shape->dim(0).denotation() == onnx::DATA_BATCH);
|
||||
ASSERT_TRUE(input_shape->dim(0).has_dim_value());
|
||||
ASSERT_TRUE(input_shape->dim(0).dim_value() == 1);
|
||||
|
||||
ASSERT_TRUE(input_shape->dim(1).denotation() == onnx::DATA_CHANNEL);
|
||||
ASSERT_TRUE(input_shape->dim(1).has_dim_value());
|
||||
ASSERT_TRUE(input_shape->dim(1).dim_value() == 42);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -40,7 +40,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
|
|||
std::string l2_rule1 = "ConvBNFusion";
|
||||
std::vector<std::string> custom_list = {l1_rule1, l1_transformer, l2_rule1};
|
||||
|
||||
auto transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level1, custom_list);
|
||||
auto transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level1, {}, custom_list);
|
||||
ASSERT_TRUE(transformers.size() == 2);
|
||||
auto l1_rule_transformer_name = transformer_utils::GenerateRuleBasedTransformerName(TransformerLevel::Level1);
|
||||
RuleBasedGraphTransformer* rule_transformer = nullptr;
|
||||
|
|
@ -51,7 +51,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
|
|||
}
|
||||
ASSERT_TRUE(rule_transformer && rule_transformer->RulesCount() == 1);
|
||||
|
||||
transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, custom_list);
|
||||
transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, {}, custom_list);
|
||||
ASSERT_TRUE(transformers.size() == 1);
|
||||
rule_transformer = dynamic_cast<RuleBasedGraphTransformer*>(transformers[0].get());
|
||||
ASSERT_TRUE(rule_transformer->RulesCount() == 1);
|
||||
|
|
|
|||
14
onnxruntime/test/testdata/abs_free_dimensions.onnx
vendored
Normal file
14
onnxruntime/test/testdata/abs_free_dimensions.onnx
vendored
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
backend-test:s
|
||||
|
||||
xy"Abstest_absZ9
|
||||
x4
|
||||
2.
|
||||
None
|
||||
DATA_BATCH
|
||||
NoneDATA_CHANNEL
|
||||
b
|
||||
y
|
||||
|
||||
None
|
||||
None
|
||||
B
|
||||
Loading…
Reference in a new issue