mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Layer norm fusion deepspeed stage3 changes (#17614)
### Description <!-- Describe your changes. --> Layer norm fusion changes required for deepspeed stage 3, also includes test case. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> It helps fusing layer norm for Deepspeed Stage 3. Added a test case scenario which ensures that the fusion is working properly for the scenario.
This commit is contained in:
parent
f299016cbe
commit
d56fc7ebf5
4 changed files with 138 additions and 23 deletions
|
|
@ -414,20 +414,20 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
NodeArg* scale = nullptr;
|
||||
NodeArg* bias = nullptr;
|
||||
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
|
||||
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
|
||||
graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
|
||||
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
|
||||
scale = mul_node.MutableInputDefs()[i];
|
||||
}
|
||||
if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
|
||||
scale = mul_node.MutableInputDefs()[i];
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) {
|
||||
if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) ||
|
||||
graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) {
|
||||
if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
|
||||
bias = last_add_node.MutableInputDefs()[i];
|
||||
}
|
||||
if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
|
||||
bias = last_add_node.MutableInputDefs()[i];
|
||||
}
|
||||
}
|
||||
if (scale == nullptr || bias == nullptr) {
|
||||
|
|
@ -667,20 +667,20 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
// because SkipLayerNorm kernel, for example, has dependency on single dim size
|
||||
NodeArg* scale = nullptr;
|
||||
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
|
||||
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
|
||||
graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
|
||||
#ifdef ENABLE_TRAINING_CORE
|
||||
if (axes_values.empty() ||
|
||||
mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
|
||||
scale = mul_node.MutableInputDefs()[i];
|
||||
}
|
||||
#else
|
||||
// Scale must be 1d.
|
||||
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
|
||||
scale = mul_node.MutableInputDefs()[i];
|
||||
}
|
||||
#endif
|
||||
if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
|
||||
continue;
|
||||
}
|
||||
#ifdef ENABLE_TRAINING_CORE
|
||||
if (axes_values.empty() ||
|
||||
mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
|
||||
scale = mul_node.MutableInputDefs()[i];
|
||||
}
|
||||
#else
|
||||
// Scale must be 1d.
|
||||
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
|
||||
scale = mul_node.MutableInputDefs()[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
if (scale == nullptr) {
|
||||
|
|
|
|||
|
|
@ -429,6 +429,40 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
|
|||
}
|
||||
}
|
||||
|
||||
// It tests the scenario when scale or bias are not Graph Inputs and not initialized in Graph
|
||||
// To test this added a Identity node after Scale and Bias terms to ensure LayerNormFusion works properly
|
||||
TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_fusion_scale_bias.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["ReduceMean"], 0);
|
||||
ASSERT_EQ(op_to_count["Sub"], 0);
|
||||
ASSERT_EQ(op_to_count["Cast"], 0);
|
||||
ASSERT_EQ(op_to_count["Pow"], 0);
|
||||
ASSERT_EQ(op_to_count["Add"], 0);
|
||||
ASSERT_EQ(op_to_count["Sqrt"], 0);
|
||||
ASSERT_EQ(op_to_count["Div"], 0);
|
||||
ASSERT_EQ(op_to_count["Mul"], 0);
|
||||
ASSERT_EQ(op_to_count["LayerNormalization"], 1);
|
||||
|
||||
for (const Node& node : graph.Nodes()) {
|
||||
if (node.OpType() == "LayerNormalization") {
|
||||
// LayerNormalization should have three inputs.
|
||||
EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size();
|
||||
// LayerNormalization input "scale" and "bias" should have the same dimension.
|
||||
const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape();
|
||||
EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization
|
||||
// doesn't support input and scale having different data types.
|
||||
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) {
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx
vendored
Normal file
Binary file not shown.
81
onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py
vendored
Normal file
81
onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py
vendored
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import onnx
|
||||
from onnx import OperatorSetIdProto, TensorProto, helper
|
||||
|
||||
|
||||
def GenerateModel(model_name, has_casts=False, has_identity=False): # noqa: N802
|
||||
nodes = [ # LayerNorm subgraph
|
||||
helper.make_node("ReduceMean", ["A"], ["rd_out"], "reduce1", axes=[-1], keepdims=1),
|
||||
helper.make_node("Sub", ["A", "rd_out"], ["sub_out"], "sub"),
|
||||
helper.make_node("Pow", ["cast_sub_out" if has_casts else "sub_out", "pow_in_2"], ["pow_out"], "pow"),
|
||||
helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1),
|
||||
helper.make_node("Add", ["rd2_out", "const_e12_f32"], ["add1_out"], "add1"),
|
||||
helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"),
|
||||
helper.make_node("Div", ["cast_sub_out" if has_casts else "sub_out", "sqrt_out"], ["div_out"], "div"),
|
||||
helper.make_node(
|
||||
"Mul",
|
||||
["gamma_id_out" if has_identity else "gamma", "cast_div_out" if has_casts else "div_out"],
|
||||
["mul_out"],
|
||||
"mul",
|
||||
),
|
||||
helper.make_node("Add", ["mul_out", "const_e6_f16_out" if has_identity else "const_e6_f16"], ["C"], "add2"),
|
||||
]
|
||||
|
||||
if has_casts:
|
||||
nodes.extend(
|
||||
[
|
||||
helper.make_node("Cast", ["sub_out"], ["cast_sub_out"], "cast_sub", to=1),
|
||||
helper.make_node("Cast", ["div_out"], ["cast_div_out"], "cast_2", to=10),
|
||||
]
|
||||
)
|
||||
|
||||
if has_identity:
|
||||
nodes.extend(
|
||||
[
|
||||
helper.make_node("Identity", ["gamma"], ["gamma_id_out"], "gamma_identity"),
|
||||
helper.make_node("Identity", ["const_e6_f16"], ["const_e6_f16_out"], "const_e6_f16_identity"),
|
||||
]
|
||||
)
|
||||
|
||||
initializers = [ # initializers
|
||||
helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]),
|
||||
helper.make_tensor("const_e12_f32", TensorProto.FLOAT, [], [1e-12]),
|
||||
helper.make_tensor("const_e6_f16", TensorProto.FLOAT16, [4], [1e-6, 1e-6, 1e-6, 1e-6]),
|
||||
helper.make_tensor(
|
||||
"gamma",
|
||||
TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT,
|
||||
[4],
|
||||
[1, 2, 3, 4],
|
||||
),
|
||||
]
|
||||
|
||||
input_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT
|
||||
output_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"LayerNorm", # name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info("A", input_type, [16, 32, 4]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info("C", output_type, [16, 32, 4]),
|
||||
],
|
||||
initializers,
|
||||
)
|
||||
|
||||
onnxdomain = OperatorSetIdProto()
|
||||
onnxdomain.version = 12
|
||||
# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
|
||||
onnxdomain.domain = ""
|
||||
msdomain = OperatorSetIdProto()
|
||||
msdomain.version = 1
|
||||
msdomain.domain = "com.microsoft"
|
||||
opsets = [onnxdomain, msdomain]
|
||||
|
||||
model = helper.make_model(graph, opset_imports=opsets)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
|
||||
GenerateModel("layer_norm_fusion_scale_bias.onnx", True, True)
|
||||
Loading…
Reference in a new issue