mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
* Fix #9671 by running the level 1 rewrite rules first and allowing the transpose optimizer to run multiple times to ensure it completes in level 1. Removed unnecessary call to GenerateRuleBasedGraphTransformer as there are no level 2 rewrite rules.
This commit is contained in:
parent
03f9d77e17
commit
724009289b
4 changed files with 29 additions and 19 deletions
|
|
@ -146,14 +146,23 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
|
||||
const std::unordered_set<std::string>& rules_and_transformers_to_disable) {
|
||||
std::vector<std::unique_ptr<GraphTransformer>> transformers;
|
||||
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
|
||||
bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
|
||||
bool disable_quant_qdq =
|
||||
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
bool enable_gelu_approximation = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1";
|
||||
bool enable_gelu_approximation =
|
||||
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1";
|
||||
#endif
|
||||
|
||||
switch (level) {
|
||||
case TransformerLevel::Level1: {
|
||||
// RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run)
|
||||
// so run them first so there is potentially less for the more intensive optimizations like ConstantFolding,
|
||||
// CommonSubexpressionElimination and TransposeOptimizer to do.
|
||||
auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
|
||||
if (rule_transformer != nullptr) {
|
||||
transformers.emplace_back(std::move(rule_transformer));
|
||||
}
|
||||
|
||||
// no filtering on execution provider for L1 optimizations as they only use official ONNX operators
|
||||
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
|
||||
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq));
|
||||
|
|
@ -163,16 +172,11 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
session_options.free_dimension_overrides));
|
||||
auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault);
|
||||
transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator)));
|
||||
|
||||
rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
|
||||
} break;
|
||||
|
||||
case TransformerLevel::Level2: {
|
||||
std::unordered_set<std::string> cpu_ep = {onnxruntime::kCpuExecutionProvider};
|
||||
|
||||
// create rule based transformer consisting of all the level2 rewrite rules
|
||||
rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, cpu_ep);
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
const std::unordered_set<std::string> cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider,
|
||||
onnxruntime::kRocmExecutionProvider};
|
||||
|
|
@ -238,10 +242,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
ORT_THROW("Unsupported optimization level: ", static_cast<int>(level));
|
||||
}
|
||||
|
||||
if (rule_transformer != nullptr) {
|
||||
transformers.emplace_back(std::move(rule_transformer));
|
||||
}
|
||||
|
||||
FilterTransformers(transformers, rules_and_transformers_to_disable);
|
||||
|
||||
return transformers;
|
||||
|
|
|
|||
|
|
@ -21,10 +21,6 @@ class TransposeOptimizer : public GraphTransformer {
|
|||
: GraphTransformer("TransposeOptimizer"), cpu_allocator_(std::move(cpu_allocator)) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
// One run should be sufficient. Multiple runs should be ok but are prohibited to prevent any possibility of an
|
||||
// infinite loop.
|
||||
bool ShouldOnlyApplyOnce() const override { return true; }
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,11 +8,12 @@
|
|||
#include "graph_transform_test_builder.h"
|
||||
|
||||
#include "core/graph/graph.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
|
||||
void SetNodeArgShape(NodeArg* node_arg, const std::optional<std::vector<int64_t>>& shape) {
|
||||
if (shape == std::nullopt) {
|
||||
node_arg->ClearShape();
|
||||
|
|
@ -77,7 +78,6 @@ int EstimateTransposeCost(const Graph& graph) {
|
|||
return cost;
|
||||
}
|
||||
|
||||
|
||||
TEST(TransposeOptimizerTests, TestSplit) {
|
||||
auto build_test_case_1 = [&](ModelTestBuilder& builder) {
|
||||
auto* input0_arg = builder.MakeInput<float>({4, 6, 10}, 0.0, 1.0);
|
||||
|
|
@ -3849,8 +3849,22 @@ TEST(TransposeOptimizerTests, TestOmitIdentityTranspose) {
|
|||
/*opset_version*/ 15);
|
||||
}
|
||||
|
||||
// regression test for a model where the transpose optimizations were not completed in a single pass in level 1.
|
||||
// fixed by
|
||||
// a) moving the RewriteRule level 1 optimizations so they run prior to the transpose optimizer; and
|
||||
// b) not returning `true` from TransposeOptimizer::ShouldOnlyApplyOnce as it should be safe to run the
|
||||
// transpose optimizer multiple times to ensure it completes in level 1.
|
||||
// either of those changes would have fixed the issue.
|
||||
// see https://github.com/microsoft/onnxruntime/issues/9671 for more details.
|
||||
TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue9671) {
|
||||
auto model_uri = ORT_TSTR("testdata/gh_issue_9671.onnx");
|
||||
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue9671";
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.Load(model_uri));
|
||||
ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/gh_issue_9671.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/gh_issue_9671.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue