mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Move Gelu and LayerNorm fusion to L1 optimization (#21332)
According to https://github.com/microsoft/onnxruntime/issues/20915, we move the Gelu and LayerNorm fusions to L1, conditioned on the ONNX opset the model imports (LayerNorm requires opset 17+ and Gelu requires opset 20+). If the opset version doesn't meet the requirement, the fusion is delayed to L2 optimization, since the internal contrib op doesn't require any specific ONNX opset. --------- Co-authored-by: Scott McKay <Scott.McKay@microsoft.com> Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
parent
de7a02beef
commit
2cdc05f189
11 changed files with 225 additions and 41 deletions
|
|
@ -44,6 +44,22 @@ static bool IsSupportedDataType(const Node& node) {
|
|||
[root]--> Gelu ==>
|
||||
*/
|
||||
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
const auto& version_map = graph.DomainToVersionMap();
|
||||
const auto& onnx_version = version_map.find(kOnnxDomain);
|
||||
// Gelu is an official ONNX operator as of opset 20, so we can fuse in level 1 if it is available
|
||||
const bool onnx_gelu_available = (onnx_version != version_map.end() && onnx_version->second >= 20);
|
||||
const bool fuse_in_level_1 = onnx_gelu_available || allow_contrib_op_in_level_1_;
|
||||
const auto op_domain = fuse_in_level_1 && onnx_gelu_available ? kOnnxDomain : kMSDomain;
|
||||
|
||||
if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
|
||||
// The following check assumes that there is a GeluFusion instance registered in Level1 that may have
|
||||
// already done this fusion, in which case we don't need to do it again.
|
||||
(optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto compatible_providers = GetCompatibleExecutionProviders();
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
@ -162,7 +178,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
"Gelu",
|
||||
"fused Gelu subgraphs ",
|
||||
gelu_input_defs,
|
||||
{}, {}, kMSDomain);
|
||||
{}, {}, op_domain);
|
||||
|
||||
// Assign provider to this new node. Provider should be same as the provider for old node.
|
||||
gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());
|
||||
|
|
|
|||
|
|
@ -17,9 +17,26 @@ x * 0.5 * (1.0 + erf(x / sqrt(2.0))), where x is the input.
|
|||
|
||||
*/
|
||||
class GeluFusion : public GraphTransformer {
|
||||
private:
|
||||
TransformerLevel optimization_level_ = TransformerLevel::Level1;
|
||||
bool allow_contrib_op_in_level_1_ = false;
|
||||
std::string GetGeluFusionName(TransformerLevel level) {
|
||||
switch (level) {
|
||||
case TransformerLevel::Level1:
|
||||
return "GeluFusionL1";
|
||||
case TransformerLevel::Level2:
|
||||
return "GeluFusionL2";
|
||||
default:
|
||||
return "GeluFusion";
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("GeluFusion", compatible_execution_providers) {}
|
||||
GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
|
||||
TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
|
||||
: GraphTransformer(GetGeluFusionName(level), compatible_execution_providers),
|
||||
optimization_level_(level),
|
||||
allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -235,6 +235,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
|
||||
session_options.free_dimension_overrides));
|
||||
|
||||
transformers.emplace_back(std::make_unique<GeluFusion>());
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>());
|
||||
|
||||
if (!disable_quant_qdq) {
|
||||
transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());
|
||||
|
||||
|
|
@ -325,8 +328,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
|
||||
transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_rocm_acl_armnn_js_eps));
|
||||
|
||||
transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps, level));
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps, level));
|
||||
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps));
|
||||
|
|
|
|||
|
|
@ -139,6 +139,21 @@ data are casted to float/double to calculate for precision, so if there is any C
|
|||
Such a Cast Op can be the input of the sub-graph, or a Cast Op between the Div and Mul nodes.
|
||||
*/
|
||||
Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
const auto& version_map = graph.DomainToVersionMap();
|
||||
const auto& onnx_version = version_map.find(kOnnxDomain);
|
||||
// LayerNorm is an official ONNX operator as of opset 17, so we can fuse in level 1 if it is available
|
||||
const bool onnx_layernorm_available = (onnx_version != version_map.end() && onnx_version->second >= 17);
|
||||
const bool fuse_in_level_1 = onnx_layernorm_available || allow_contrib_op_in_level_1_;
|
||||
|
||||
if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
|
||||
// The following check assumes that there is a LayerNormFusion instance registered in Level1 that may have
|
||||
// already done this fusion, in which case we don't need to do it again.
|
||||
(optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto compatible_providers = GetCompatibleExecutionProviders();
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
InlinedVector<std::reference_wrapper<Node>> nodes_to_remove;
|
||||
|
|
|
|||
|
|
@ -17,9 +17,26 @@ The formula corresponding to LayerNorm activation subgraph:
|
|||
|
||||
*/
|
||||
class LayerNormFusion : public GraphTransformer {
|
||||
private:
|
||||
TransformerLevel optimization_level_ = TransformerLevel::Level1;
|
||||
bool allow_contrib_op_in_level_1_ = false;
|
||||
std::string GetLayerNormFusionName(TransformerLevel level) {
|
||||
switch (level) {
|
||||
case TransformerLevel::Level1:
|
||||
return "LayerNormFusionL1";
|
||||
case TransformerLevel::Level2:
|
||||
return "LayerNormFusionL2";
|
||||
default:
|
||||
return "LayerNormFusion";
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("LayerNormFusion", compatible_execution_providers) {}
|
||||
LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
|
||||
TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
|
||||
: GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers),
|
||||
optimization_level_(level),
|
||||
allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -4434,7 +4434,11 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -4445,6 +4449,28 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
|
|||
ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, GeluFusionTest_Opset20) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_opset20.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Div"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Erf"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Mul"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Gelu"] == 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
@ -4452,7 +4478,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -4470,7 +4500,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -4488,7 +4522,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphInput) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -4506,8 +4544,12 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphOutput) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -4522,8 +4564,12 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
|
|||
|
|
@ -41,7 +41,11 @@ TEST_F(GraphTransformationTests, LayerNormFusionTest) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -79,7 +83,11 @@ TEST_F(GraphTransformationTests, TwoLayerNormShareSameInput) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -94,7 +102,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -115,7 +127,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_2) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -131,7 +147,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_3) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -147,7 +167,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_4) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -169,7 +193,11 @@ TEST_F(GraphTransformationTests, LayerNormWithSubDupFusionTest) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -290,9 +318,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
|
|||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
|
||||
1, pre_graph_checker, post_graph_checker));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
|
||||
std::unique_ptr<GraphTransformer> transformer_2 =
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
|
||||
TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_6) {
|
||||
|
|
@ -314,9 +347,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_6) {
|
|||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
|
||||
1, nullptr, post_graph_checker));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
|
||||
std::unique_ptr<GraphTransformer> transformer_2 =
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
|
||||
TransformerLevel::Level1, 1, nullptr, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
|
||||
TransformerLevel::Level2, 1, nullptr, post_graph_checker));
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_7) {
|
||||
|
|
@ -341,9 +379,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_7) {
|
|||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
|
||||
1, nullptr, post_graph_checker));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
|
||||
std::unique_ptr<GraphTransformer> transformer_2 =
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
|
||||
TransformerLevel::Level1, 1, nullptr, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
|
||||
TransformerLevel::Level2, 1, nullptr, post_graph_checker));
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_8) {
|
||||
|
|
@ -365,9 +408,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_8) {
|
|||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
|
||||
1, nullptr, post_graph_checker));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
|
||||
std::unique_ptr<GraphTransformer> transformer_2 =
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
|
||||
TransformerLevel::Level1, 1, nullptr, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
|
||||
TransformerLevel::Level2, 1, nullptr, post_graph_checker));
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_9) {
|
||||
|
|
@ -393,9 +441,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_9) {
|
|||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
|
||||
1, nullptr, post_graph_checker));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
std::unique_ptr<GraphTransformer> transformer_1 = std::make_unique<LayerNormFusion>();
|
||||
std::unique_ptr<GraphTransformer> transformer_2 =
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2);
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1),
|
||||
TransformerLevel::Level1, 1, nullptr, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2),
|
||||
TransformerLevel::Level2, 1, nullptr, post_graph_checker));
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
|
||||
|
|
@ -438,7 +491,11 @@ TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -529,8 +586,12 @@ static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_pat
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -579,8 +640,12 @@ static void TestSkipLayerNormFusionInputOutputCheck(const std::basic_string<ORTC
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger));
|
||||
|
||||
for (Node& node : graph.Nodes()) {
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/gelu_opset20.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/gelu_opset20.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -285,6 +285,7 @@ void IterateSubgraphFromNode(Graph& graph,
|
|||
PushAllOutputNode(graph, to_visit, cur, visited);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "FastGelu", {1}, kMSDomain) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {20}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "QuickGelu", {1}, kMSDomain) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sqrt", {6, 13})) {
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
// CSE will not merge them.
|
||||
transformers.emplace_back(std::make_unique<ConstantSharing>(compatible_eps));
|
||||
// LayerNormFusion must be applied before CommonSubexpressionElimination as the latter will break the pattern when 2 LayerNormFusion share the same input.
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps, level, true));
|
||||
// Remove duplicate nodes. Must be applied before any recompute transformations.
|
||||
if (config.gelu_recompute || config.attn_dropout_recompute || config.transformer_layer_recompute) {
|
||||
transformers.emplace_back(std::make_unique<CommonSubexpressionEliminationApplyOnce>(compatible_eps));
|
||||
|
|
@ -129,7 +129,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>(compatible_eps));
|
||||
}
|
||||
|
||||
transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps, level, true));
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps,
|
||||
true /* skip_device_check*/));
|
||||
|
|
|
|||
|
|
@ -919,9 +919,13 @@ TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluRecompute>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
|
|||
Loading…
Reference in a new issue