mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Fix fusion for two LayerNorm sharing same input but with different weights (#15919)
In gpt_j_residual (https://arxiv.org/pdf/2204.06745.pdf), two LayerNorm (LN) subgraphs share the same input. ORT runs the common subexpression elimination (CSE) graph optimization before LN fusion, which modifies the LN graph pattern and thus makes LN fusion fail.
This commit is contained in:
parent
5607a7151a
commit
4dc4470cc7
4 changed files with 88 additions and 1 deletions
|
|
@ -72,6 +72,21 @@ TEST_F(GraphTransformationTests, LayerNormFusionTest) {
|
|||
}
|
||||
}
|
||||
|
||||
// Loads the "layer_norm_shared_input" test model, runs LayerNormFusion alone at Level2,
// and checks that the resulting graph contains nothing but two LayerNormalization nodes.
TEST_F(GraphTransformationTests, TwoLayerNormShareSameInput) {
  constexpr const ORTCHAR_T* shared_input_model_path = MODEL_FOLDER "fusion/layer_norm_shared_input.onnx";

  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(shared_input_model_path, model, nullptr, *logger_));
  Graph& main_graph = model->MainGraph();

  // Only LayerNormFusion is registered, so no other transformer touches the graph.
  onnxruntime::GraphTransformerManager transformer_manager{5};
  ASSERT_STATUS_OK(transformer_manager.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
  ASSERT_STATUS_OK(transformer_manager.ApplyTransformers(main_graph, TransformerLevel::Level2, *logger_));

  // A single op type must remain, and it must occur exactly twice.
  std::map<std::string, int> op_counts = CountOpsInGraph(main_graph);
  ASSERT_TRUE(op_counts.size() == 1);
  ASSERT_TRUE(op_counts["LayerNormalization"] == 2);
}
|
||||
|
||||
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.onnx
vendored
Normal file
Binary file not shown.
71
onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.py
vendored
Normal file
71
onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.py
vendored
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import onnx
|
||||
from onnx import OperatorSetIdProto, TensorProto, helper
|
||||
|
||||
|
||||
# In gpt_j_residual, two LayerNorm (LN) subgraphs share the same input.
|
||||
def GenerateModel(model_name):  # noqa: N802
    """Build an ONNX model with two unfused LayerNorm subgraphs and save it to `model_name`.

    Both subgraphs ("LN1" and "LN2") read the same graph input "A" and each owns a
    separate set of initializers. Their results are the graph outputs "LN1/C" and "LN2/C".
    """
    shared_input = "A"
    io_shape = [16, 32, 4]
    prefixes = ("LN1", "LN2")

    def subgraph_nodes(p):
        # Decomposed LayerNorm over the last axis:
        #   beta + gamma * (A - mean(A)) / sqrt(mean((A - mean(A)) ** pow_in_2) + const_0)
        return [
            helper.make_node("ReduceMean", [shared_input], [f"{p}/rd1_out"], f"{p}/reduce", axes=[-1]),
            helper.make_node("Sub", [shared_input, f"{p}/rd1_out"], [f"{p}/sub1_out"], f"{p}/sub"),
            helper.make_node("Pow", [f"{p}/sub1_out", f"{p}/pow_in_2"], [f"{p}/pow_out"], f"{p}/pow"),
            helper.make_node("ReduceMean", [f"{p}/pow_out"], [f"{p}/rd2_out"], f"{p}/reduce2", axes=[-1]),
            helper.make_node("Add", [f"{p}/rd2_out", f"{p}/const_0"], [f"{p}/add1_out"], f"{p}/add"),
            helper.make_node("Sqrt", [f"{p}/add1_out"], [f"{p}/sqrt_out"], f"{p}/sqrt"),
            helper.make_node("Div", [f"{p}/sub1_out", f"{p}/sqrt_out"], [f"{p}/div_out"], f"{p}/div"),
            helper.make_node("Mul", [f"{p}/gamma", f"{p}/div_out"], [f"{p}/mul_out"], f"{p}/mul"),
            helper.make_node("Add", [f"{p}/beta", f"{p}/mul_out"], [f"{p}/C"], f"{p}/add2"),
        ]

    def subgraph_initializers(p):
        # Scalar exponent and additive constant, followed by the 4-element scale and bias.
        return [
            helper.make_tensor(f"{p}/pow_in_2", TensorProto.FLOAT, [], [2]),
            helper.make_tensor(f"{p}/const_0", TensorProto.FLOAT, [], [0]),
            helper.make_tensor(f"{p}/gamma", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
            helper.make_tensor(f"{p}/beta", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
        ]

    def opset(domain, version):
        entry = OperatorSetIdProto()
        entry.domain = domain
        entry.version = version
        return entry

    # LN1 entries come first, then LN2, for both nodes and initializers.
    nodes = [node for p in prefixes for node in subgraph_nodes(p)]
    initializers = [tensor for p in prefixes for tensor in subgraph_initializers(p)]

    graph_inputs = [helper.make_tensor_value_info(shared_input, TensorProto.FLOAT, io_shape)]
    graph_outputs = [helper.make_tensor_value_info(f"{p}/C", TensorProto.FLOAT, io_shape) for p in prefixes]

    graph = helper.make_graph(
        nodes,
        "2LayerNormShareSameInput",  # name
        graph_inputs,
        graph_outputs,
        initializers,
    )

    # The empty string ("") domain denotes the operator set defined by the ONNX specification itself.
    opset_imports = [opset("", 12), opset("com.microsoft", 1)]

    onnx.save(helper.make_model(graph, opset_imports=opset_imports), model_name)
|
||||
|
||||
|
||||
# Generate the test model; the file name is relative, so it is written relative to the current working directory.
GenerateModel("layer_norm_shared_input.onnx")
|
||||
|
|
@ -109,6 +109,8 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
// CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
|
||||
// default, CSE will not merge them, because the different initializers are represented by different NodeArg.
|
||||
transformers.emplace_back(std::make_unique<ConstantSharing>(compatible_eps));
|
||||
// LayerNormFusion must be applied before CommonSubexpressionElimination, as the latter will break the pattern when 2 LayerNorm subgraphs share the same input.
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
|
||||
// Remove duplicate nodes. Must be applied before any recompute transformations.
|
||||
if (config.gelu_recompute || config.attn_dropout_recompute || config.transformer_layer_recompute) {
|
||||
transformers.emplace_back(std::make_unique<CommonSubexpressionEliminationApplyOnce>(compatible_eps));
|
||||
|
|
@ -117,7 +119,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
}
|
||||
|
||||
transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps,
|
||||
true /* skip_device_check*/));
|
||||
|
|
|
|||
Loading…
Reference in a new issue