Move Gelu and LayerNorm fusion to L1 optimization (#21332)

According to https://github.com/microsoft/onnxruntime/issues/20915, we
move the Gelu and LayerNorm fusion to L1 with a condition on the ONNX
opset the model imports (LayerNorm requires opset 17+ and Gelu requires
opset 20+.) If the opset version doesn't meet the requirements, the
fusion is delayed to L2 optimization since the internal contrib op
doesn't have a requirement for any specific ONNX opset.

---------

Co-authored-by: Scott McKay <Scott.McKay@microsoft.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
Peishen Yan 2024-09-09 11:27:52 +08:00 committed by GitHub
parent de7a02beef
commit 2cdc05f189
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 225 additions and 41 deletions

View file

@ -44,6 +44,22 @@ static bool IsSupportedDataType(const Node& node) {
[root]--> Gelu ==>
*/
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const auto& version_map = graph.DomainToVersionMap();
const auto& onnx_version = version_map.find(kOnnxDomain);
// Gelu is an official ONNX operator as of opset 20, so we can fuse in level 1 if it is available
const bool onnx_gelu_available = (onnx_version != version_map.end() && onnx_version->second >= 20);
const bool fuse_in_level_1 = onnx_gelu_available || allow_contrib_op_in_level_1_;
const auto op_domain = fuse_in_level_1 && onnx_gelu_available ? kOnnxDomain : kMSDomain;
if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
// The following check assumes that there is a GeluFusion instance registered in Level1 that may have
// already done this fusion, in which case we don't need to do it again.
(optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
return Status::OK();
}
const auto compatible_providers = GetCompatibleExecutionProviders();
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
@ -162,7 +178,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
"Gelu",
"fused Gelu subgraphs ",
gelu_input_defs,
{}, {}, kMSDomain);
{}, {}, op_domain);
// Assign provider to this new node. Provider should be same as the provider for old node.
gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());

View file

@ -17,9 +17,26 @@ x * 0.5 * (1.0 + erf(x / sqrt(2.0))), where x is the input.
*/
class GeluFusion : public GraphTransformer {
 private:
  // Level this instance is registered at. ApplyImpl uses it (together with the
  // model's ONNX opset) to decide whether this instance performs the fusion or
  // defers to the instance registered at the other level.
  TransformerLevel optimization_level_ = TransformerLevel::Level1;

  // When true, the com.microsoft Gelu contrib op may be emitted in Level1 even
  // if the model's ONNX opset is below the one that introduced the official
  // Gelu operator.
  bool allow_contrib_op_in_level_1_ = false;

  // Returns a level-specific transformer name so that the Level1 and Level2
  // instances can be registered and reported independently.
  // Static: it is called from the constructor's mem-initializer list and must
  // not touch any (not-yet-initialized) members.
  static std::string GetGeluFusionName(TransformerLevel level) {
    switch (level) {
      case TransformerLevel::Level1:
        return "GeluFusionL1";
      case TransformerLevel::Level2:
        return "GeluFusionL2";
      default:
        return "GeluFusion";
    }
  }

 public:
  // Backward compatible with the previous single-argument constructor: with
  // default arguments this behaves as a Level1 GeluFusion that only fuses when
  // the ONNX Gelu operator is available.
  GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
             TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
      : GraphTransformer(GetGeluFusionName(level), compatible_execution_providers),
        optimization_level_(level),
        allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}

  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

View file

@ -235,6 +235,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
session_options.free_dimension_overrides));
transformers.emplace_back(std::make_unique<GeluFusion>());
transformers.emplace_back(std::make_unique<LayerNormFusion>());
if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());
@ -325,8 +328,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_rocm_acl_armnn_js_eps));
transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps, level));
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps, level));
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps));

View file

@ -139,6 +139,21 @@ data are casted to float/double to calculate for precision, so if there is any C
Such a Cast op can be the input of the sub-graph, or a Cast op between the Div and Mul nodes.
*/
Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const auto& version_map = graph.DomainToVersionMap();
const auto& onnx_version = version_map.find(kOnnxDomain);
// LayerNorm is an official ONNX operator as of opset 17, so we can fuse in level 1 if it is available
const bool onnx_layernorm_available = (onnx_version != version_map.end() && onnx_version->second >= 17);
const bool fuse_in_level_1 = onnx_layernorm_available || allow_contrib_op_in_level_1_;
if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
// The following check assumes that there is a LayerNormFusion instance registered in Level1 that may have
// already done this fusion, in which case we don't need to do it again.
(optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
return Status::OK();
}
const auto compatible_providers = GetCompatibleExecutionProviders();
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedVector<std::reference_wrapper<Node>> nodes_to_remove;

View file

@ -17,9 +17,26 @@ The formula corresponding to LayerNorm activation subgraph:
*/
class LayerNormFusion : public GraphTransformer {
 private:
  // Level this instance is registered at. ApplyImpl uses it (together with the
  // model's ONNX opset) to decide whether this instance performs the fusion or
  // defers to the instance registered at the other level.
  TransformerLevel optimization_level_ = TransformerLevel::Level1;

  // When true, the com.microsoft LayerNormalization contrib op may be emitted
  // in Level1 even if the model's ONNX opset is below the one that introduced
  // the official LayerNormalization operator.
  bool allow_contrib_op_in_level_1_ = false;

  // Returns a level-specific transformer name so that the Level1 and Level2
  // instances can be registered and reported independently.
  // Static: it is called from the constructor's mem-initializer list and must
  // not touch any (not-yet-initialized) members.
  static std::string GetLayerNormFusionName(TransformerLevel level) {
    switch (level) {
      case TransformerLevel::Level1:
        return "LayerNormFusionL1";
      case TransformerLevel::Level2:
        return "LayerNormFusionL2";
      default:
        return "LayerNormFusion";
    }
  }

 public:
  // Backward compatible with the previous single-argument constructor: with
  // default arguments this behaves as a Level1 LayerNormFusion that only fuses
  // when the ONNX LayerNormalization operator is available.
  LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
                  TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
      : GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers),
        optimization_level_(level),
        allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}

  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

View file

@ -4434,7 +4434,11 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -4445,6 +4449,28 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 1);
}
TEST_F(GraphTransformationTests, GeluFusionTest_Opset20) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_opset20.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}
TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx";
std::shared_ptr<Model> p_model;
@ -4452,7 +4478,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -4470,7 +4500,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -4488,7 +4522,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphInput) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -4506,8 +4544,12 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphOutput) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -4522,8 +4564,12 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);

View file

@ -41,7 +41,11 @@ TEST_F(GraphTransformationTests, LayerNormFusionTest) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -79,7 +83,11 @@ TEST_F(GraphTransformationTests, TwoLayerNormShareSameInput) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -94,7 +102,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -115,7 +127,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_2) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -131,7 +147,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_3) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -147,7 +167,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_4) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -169,7 +193,11 @@ TEST_F(GraphTransformationTests, LayerNormWithSubDupFusionTest) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -290,9 +318,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
std::unique_ptr<GraphTransformer> transformer_2 =
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker));
}
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_6) {
@ -314,9 +347,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_6) {
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
1, nullptr, post_graph_checker));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
std::unique_ptr<GraphTransformer> transformer_2 =
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
TransformerLevel::Level1, 1, nullptr, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
TransformerLevel::Level2, 1, nullptr, post_graph_checker));
}
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_7) {
@ -341,9 +379,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_7) {
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
1, nullptr, post_graph_checker));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
std::unique_ptr<GraphTransformer> transformer_2 =
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
TransformerLevel::Level1, 1, nullptr, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
TransformerLevel::Level2, 1, nullptr, post_graph_checker));
}
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_8) {
@ -365,9 +408,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_8) {
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
1, nullptr, post_graph_checker));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
std::unique_ptr<GraphTransformer> transformer_2 =
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
TransformerLevel::Level1, 1, nullptr, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
TransformerLevel::Level2, 1, nullptr, post_graph_checker));
}
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_9) {
@ -393,9 +441,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_9) {
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
1, nullptr, post_graph_checker));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
std::unique_ptr<GraphTransformer> transformer_2 =
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
TransformerLevel::Level1, 1, nullptr, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
TransformerLevel::Level2, 1, nullptr, post_graph_checker));
}
TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
@ -438,7 +491,11 @@ TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -529,8 +586,12 @@ static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_pat
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -579,8 +640,12 @@ static void TestSkipLayerNormFusionInputOutputCheck(const std::basic_string<ORTC
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger));
for (Node& node : graph.Nodes()) {

Binary file not shown.

View file

@ -285,6 +285,7 @@ void IterateSubgraphFromNode(Graph& graph,
PushAllOutputNode(graph, to_visit, cur, visited);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "FastGelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {20}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "QuickGelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sqrt", {6, 13})) {

View file

@ -121,7 +121,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
// CSE will not merge them.
transformers.emplace_back(std::make_unique<ConstantSharing>(compatible_eps));
// LayerNormFusion must be applied before CommonSubexpressionElimination as the latter will break the pattern when 2 LayerNormFusion share the same input.
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps, level, true));
// Remove duplicate nodes. Must be applied before any recompute transformations.
if (config.gelu_recompute || config.attn_dropout_recompute || config.transformer_layer_recompute) {
transformers.emplace_back(std::make_unique<CommonSubexpressionEliminationApplyOnce>(compatible_eps));
@ -129,7 +129,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>(compatible_eps));
}
transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps, level, true));
#if defined(USE_CUDA) || defined(USE_ROCM)
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps,
true /* skip_device_check*/));

View file

@ -919,9 +919,13 @@ TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluRecompute>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);