mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Add session option configuration to enable GeluApproximation (#7131)
This commit is contained in:
parent
8e54b76e2d
commit
53c123dcee
4 changed files with 41 additions and 32 deletions
|
|
@ -47,6 +47,9 @@ static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.se
|
|||
// Its default value is "1"
|
||||
static const char* const kOrtSessionOptionsEnableQuantQDQ = "session.enable_quant_qdq";
|
||||
|
||||
// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
|
||||
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
|
||||
|
||||
// Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking
|
||||
// "0": thread will block if found no job to run
|
||||
// "1": default, thread will spin a number of times before blocking
|
||||
|
|
|
|||
|
|
@ -119,7 +119,11 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
const std::vector<std::string>& transformers_and_rules_to_enable) {
|
||||
std::vector<std::unique_ptr<GraphTransformer>> transformers;
|
||||
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
|
||||
bool enable_quant_qdq = session_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQ, "1") == "1";
|
||||
bool enable_quant_qdq = session_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQ, "1") == "1";
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
bool enable_gelu_approximation = session_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1";
|
||||
#endif
|
||||
|
||||
switch (level) {
|
||||
case TransformerLevel::Level1: {
|
||||
std::unordered_set<std::string> l1_execution_providers = {};
|
||||
|
|
@ -169,6 +173,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
|
||||
transformers.emplace_back(onnxruntime::make_unique<FastGeluFusion>(cpu_cuda_rocm_execution_providers));
|
||||
|
||||
if (enable_gelu_approximation){
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(cpu_cuda_rocm_execution_providers));
|
||||
}
|
||||
|
||||
transformers.emplace_back(onnxruntime::make_unique<MatMulScaleFusion>(cpu_cuda_rocm_execution_providers));
|
||||
#endif
|
||||
} break;
|
||||
|
|
@ -198,15 +206,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
return transformers;
|
||||
}
|
||||
|
||||
// Some transformers have side-effect like result is not exactly same.
|
||||
// These transformers could only be enabled by custom transformer list.
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
if (level == TransformerLevel::Level2) {
|
||||
std::unordered_set<std::string> cuda_rocm_execution_providers = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider};
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(cuda_rocm_execution_providers));
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<std::unique_ptr<GraphTransformer>> filtered_list;
|
||||
// If the rule-based transformer is not empty, it should be included in the custom transformer list below.
|
||||
if (rule_transformer != nullptr) {
|
||||
|
|
|
|||
|
|
@ -2324,6 +2324,35 @@ TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) {
|
|||
EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second;
|
||||
}
|
||||
|
||||
static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_options) {
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
|
||||
bool has_gelu_approximation = false;
|
||||
auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), {});
|
||||
for (auto& transformer : transformers) {
|
||||
if (transformer->Name() == "GeluApproximation") {
|
||||
has_gelu_approximation = true;
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ(has_gelu_approximation, is_enabled);
|
||||
}
|
||||
|
||||
// Test session option configuration for GeluApproximation
|
||||
TEST_F(GraphTransformationTests, GeluApproximation_SessionOptionConfig) {
|
||||
SessionOptions session_options;
|
||||
|
||||
// GeluApproximation is not enabled by default.
|
||||
VerifyGeluApproximation(false, session_options);
|
||||
|
||||
session_options.AddConfigEntry(kOrtSessionOptionsEnableGeluApproximation, "1");
|
||||
VerifyGeluApproximation(true, session_options);
|
||||
|
||||
session_options.AddConfigEntry(kOrtSessionOptionsEnableGeluApproximation, "0");
|
||||
VerifyGeluApproximation(false, session_options);
|
||||
}
|
||||
|
||||
// Test Gelu -> FastGelu
|
||||
TEST_F(GraphTransformationTests, GeluApproximation_Gelu) {
|
||||
auto model_uri = MODEL_FOLDER "approximation/gelu.onnx";
|
||||
|
|
|
|||
|
|
@ -63,27 +63,5 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
|
|||
#endif
|
||||
}
|
||||
|
||||
TEST(GraphTransformerUtilsTests, TestCustomOnlyTransformers) {
|
||||
// Transformers that are disabled by default. They can only be enabled by custom list.
|
||||
std::string l2_transformer = "GeluApproximation";
|
||||
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
|
||||
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
|
||||
std::vector<std::string> default_list = {};
|
||||
auto default_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, *cpu_execution_provider.get(), default_list);
|
||||
for (auto& transformer : default_transformers) {
|
||||
ASSERT_TRUE(transformer->Name() != l2_transformer);
|
||||
}
|
||||
|
||||
std::vector<std::string> custom_list = {l2_transformer};
|
||||
auto custom_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, *cpu_execution_provider.get(), custom_list);
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
ASSERT_TRUE(custom_transformers.size() == 1);
|
||||
ASSERT_TRUE(custom_transformers[0]->Name() == l2_transformer);
|
||||
#else
|
||||
ASSERT_TRUE(custom_transformers.size() == 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue