SkipLayerNorm fusion with different input and output type (#15500)

SkipLayerNorm fusion fuses LayerNorm and one or more Add kernels now.
While LayerNormalization kernel allows different input and output type
by definition, SkipLayerNormalization must have the same input and
output type.

This graph is valid as the output of Add node is float16 and two inputs
from initializers are float.


![image](https://user-images.githubusercontent.com/35605090/231874079-3f3b03cc-f751-4ad9-a002-31116a35117f.png)

However, when Add and LayerNormalization are fused, the fusion fails because the two
inputs of the Add node are float16 while SkipLayerNormalization requires its inputs
to have the same type. To avoid this failure, this PR adds a Cast node
before the inputs of SkipLayerNormalization when the input and output types are
different and the output type is float. The above graph is fused as follows:


![image](https://user-images.githubusercontent.com/35605090/231874097-6405713a-7c95-4b5b-a293-1305976edc94.png)

For performance, it would be better for SkipLayerNormalization to support
different input and output types, but this PR is meant to unblock the Turing NLR v5
base model in Babel. When we have more cases, we can add that support.
This commit is contained in:
Sunghoon 2023-04-13 23:07:47 -07:00 committed by GitHub
parent d76cf374c4
commit fda0aa14c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 182 additions and 67 deletions

View file

@ -90,6 +90,36 @@ static bool CheckSecondAdd(Graph& graph, Node& add, ProviderType providertype) {
add_input1_shape->dim(2).dim_value() == add_input2_shape->dim(0).dim_value();
}
// Adds a Cast node that converts `input` (e.g. float16/bfloat16) to float when the input element type
// differs from the output type and the output type is float.
//
// @param graph            Graph that receives the new NodeArg and Cast node.
// @param input            Candidate input; must not be null.
// @param output_data_type Element type (TensorProto_DataType value) of the output to match.
// @param provider_type    Execution provider assigned to the inserted Cast node.
// @return `input` unchanged when it has no type information, when its element type already equals
//         `output_data_type`, or when `output_data_type` is not float; otherwise the NodeArg produced by
//         the newly added Cast node.
static NodeArg* CastToFloat(Graph& graph, NodeArg* input, int32_t output_data_type, ProviderType provider_type) {
  if (nullptr == input->Type() ||
      input->TypeAsProto()->tensor_type().elem_type() == output_data_type ||
      output_data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
    return input;
  }

  TypeProto input_float;
  input_float.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);

  // A NodeArg can carry a type without a shape, in which case Shape() is null. Only copy the dimensions
  // when they are available; otherwise the Cast output is left without shape information.
  const auto* input_shape = input->Shape();
  if (input_shape != nullptr) {
    for (int i = 0; i < input_shape->dim_size(); ++i) {
      auto dim = input_float.mutable_tensor_type()->mutable_shape()->add_dim();
      *dim = input_shape->dim(i);
    }
  }

  auto& cast_float = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(input->Name() + "_Float"), &input_float);
  auto& node = graph.AddNode(graph.GenerateNodeName(input->Name() + "_Cast"),
                             "Cast",
                             "Cast Input to float",
                             std::array{input},
                             std::array{&cast_float},
                             nullptr,
                             kOnnxDomain);
  node.AddAttribute("to", int64_t{ONNX_NAMESPACE::TensorProto_DataType_FLOAT});
  node.SetExecutionProviderType(provider_type);
  return &cast_float;
}
/**
Skip Layer Normalization will fuse Add + LayerNormalization into one node, and another Add if applicable
@ -243,6 +273,14 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
nodes_to_remove.push_back(*p_add1);
nodes_to_remove.push_back(ln_node);
// If input types are different than output type and output type is float, insert cast node after inputs.
for (auto& input_def: skip_layer_norm_input_defs) {
input_def = CastToFloat(graph,
input_def,
ln_node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(),
ln_node.GetExecutionProviderType());
}
Node& skip_layer_norm_node = graph.AddNode(graph.GenerateNodeName("SkipLayerNormalization"),
"SkipLayerNormalization",
"fused SkipLayerNorm subgraphs ",

View file

@ -4906,7 +4906,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTestCudaEp) {
}
static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_path, int add_count, int ln_count,
int skip_ln_count, logging::Logger* logger) {
int skip_ln_count, int cast_count, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();
@ -4925,43 +4925,57 @@ static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_pat
ASSERT_TRUE(op_to_count["Sqrt"] == 0);
ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count);
ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == skip_ln_count);
ASSERT_TRUE(op_to_count["Cast"] == cast_count);
}
TEST_F(GraphTransformationTests, SkipLayerNormFusionTest) {
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output.onnx", 1, 0, 1, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output.onnx", 1, 0, 1, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output.onnx", 1, 0, 1, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output.onnx", 1, 0, 1, 0, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, 0, logger_.get());
}
TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/skip_layer_norm_input_output_check.onnx";
// Runs TestSkipLayerNormFusion over the "_with_cast" model variants. Each table row holds the model
// file followed by the add / ln / skip_ln / cast counts handed to the helper.
TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) {
  struct FusionCase {
    const ORTCHAR_T* model_path;
    int add_count;
    int ln_count;
    int skip_ln_count;
    int cast_count;
  };

  const FusionCase fusion_cases[] = {
      {MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3},
      {MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3},
      {MODEL_FOLDER "fusion/skip_layer_norm_format3_with_cast.onnx", 0, 0, 1, 2},
      {MODEL_FOLDER "fusion/skip_layer_norm_format1_partial_with_cast.onnx", 1, 0, 1, 2},
      {MODEL_FOLDER "fusion/skip_layer_norm_format2_partial_with_cast.onnx", 1, 0, 1, 2},
      {MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx", 1, 1, 0, 0},
      {MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output_with_cast.onnx", 1, 0, 1, 2},
      {MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output_with_cast.onnx", 1, 0, 1, 2},
      {MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output_with_cast.onnx", 1, 1, 0, 0},
  };

  // Rows are visited in declaration order, matching the original call sequence.
  for (const FusionCase& fusion_case : fusion_cases) {
    TestSkipLayerNormFusion(fusion_case.model_path,
                            fusion_case.add_count,
                            fusion_case.ln_count,
                            fusion_case.skip_ln_count,
                            fusion_case.cast_count,
                            logger_.get());
  }
}
static void TestSkipLayerNormFusionInputOutputCheck(const std::basic_string<ORTCHAR_T>& model_uri, bool with_cast, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger));
for (Node& node : graph.Nodes()) {
if (node.OpType() == "SkipLayerNormalization") {
// check inputs
std::vector<NodeArg*>& input_defs = node.MutableInputDefs();
EXPECT_EQ(input_defs.size(), 5u) << "SkipLayerNormalization number of inputs does not equal to 5. Got:" << node.InputDefs().size();
EXPECT_EQ(input_defs[0]->Name(), "input.1");
EXPECT_EQ(input_defs[1]->Name(), "6");
EXPECT_EQ(input_defs[0]->Name(), ((with_cast) ? "input.1_Float" : "input.1"));
EXPECT_EQ(input_defs[1]->Name(), ((with_cast) ? "6_Float" : "6"));
EXPECT_EQ(input_defs[2]->Name(), "1");
EXPECT_EQ(input_defs[3]->Name(), "2");
EXPECT_EQ(input_defs[4]->Name(), "4");
EXPECT_EQ(input_defs[4]->Name(), ((with_cast) ? "4_Float" : "4"));
// check outputs
std::vector<NodeArg*>& output_defs = node.MutableOutputDefs();
@ -4971,26 +4985,38 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) {
EXPECT_EQ(node.OutputDefs().size(), 1u) << "SkipLayerNormalization number of outputs does not equal to 1. Got:" << node.OutputDefs().size();
#endif
EXPECT_EQ(output_defs[0]->Name(), "19");
} else if (node.OpType() == "Cast") {
EXPECT_TRUE(with_cast) << "Unexpected node: " << node.OpType() << "," << node.Name();
} else {
EXPECT_EQ(node.OpType(), "MatMul") << "Unexpected node: " << node.OpType() << "," << node.Name();
}
}
}
TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx";
// Runs the input/output name check on the base model first, then on its "_with_cast" counterpart.
TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) {
  constexpr const ORTCHAR_T* base_model_uri = MODEL_FOLDER "fusion/skip_layer_norm_input_output_check.onnx";
  constexpr const ORTCHAR_T* cast_model_uri = MODEL_FOLDER "fusion/skip_layer_norm_input_output_with_cast_check.onnx";

  TestSkipLayerNormFusionInputOutputCheck(base_model_uri, /*with_cast=*/false, logger_.get());
  TestSkipLayerNormFusionInputOutputCheck(cast_model_uri, /*with_cast=*/true, logger_.get());
}
static void TestSkipLayerNormFusionNoBeta(const std::basic_string<ORTCHAR_T>& model_uri, bool with_cast, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["LayerNormalization"] == 0);
ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 1);
ASSERT_TRUE(op_to_count["Cast"] == ((with_cast) ? 2 : 0));
}
// Runs the no-beta fusion check on the base model first, then on its "_with_cast" counterpart.
TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) {
  constexpr const ORTCHAR_T* base_model_uri = MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx";
  constexpr const ORTCHAR_T* cast_model_uri = MODEL_FOLDER "fusion/skip_layer_norm_no_beta_with_cast.onnx";

  TestSkipLayerNormFusionNoBeta(base_model_uri, /*with_cast=*/false, logger_.get());
  TestSkipLayerNormFusionNoBeta(cast_model_uri, /*with_cast=*/true, logger_.get());
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) {

View file

@ -1,7 +1,7 @@
from enum import Enum
import onnx
from onnx import TensorProto, helper
from onnx import OperatorSetIdProto, TensorProto, helper
class Format(Enum):
@ -10,19 +10,36 @@ class Format(Enum):
Format3 = 3
def GenerateModel(format, model_name, multi_output_add=False, add_output_in_graph_output=False): # noqa: N802
nodes = [ # LayerNorm subgraph
helper.make_node("ReduceMean", ["ln_in"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1),
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb1_out"], "sub1"),
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb2_out"], "sub2"),
helper.make_node("Pow", ["sb2_out", "pow_in_2"], ["pow_out"], "pow"),
helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1),
helper.make_node("Add", ["rd2_out", "const_e12"], ["add1_out"], "add1"),
helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"),
helper.make_node("Div", ["sb1_out", "sqrt_out"], ["div_out"], "div1"),
helper.make_node("Mul", ["gamma", "div_out"], ["mul_out"], "mul"),
helper.make_node("Add", ["mul_out", "beta"], ["C"], "add0"),
]
def generate_model(model_format, model_name, multi_output_add=False, add_output_in_graph_output=False, with_cast=False):
nodes = [] # LayerNorm subgraph
if with_cast:
nodes.extend(
[
helper.make_node("Cast", ["ln_in"], ["c_out"], "cast", to=1),
helper.make_node("ReduceMean", ["c_out"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1),
helper.make_node("Sub", ["c_out", "rd1_out"], ["sb1_out"], "sub1"),
helper.make_node("Sub", ["c_out", "rd1_out"], ["sb2_out"], "sub2"),
]
)
else:
nodes.extend(
[
helper.make_node("ReduceMean", ["ln_in"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1),
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb1_out"], "sub1"),
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb2_out"], "sub2"),
]
)
nodes.extend(
[ # LayerNorm subgraph
helper.make_node("Pow", ["sb2_out", "pow_in_2"], ["pow_out"], "pow"),
helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1),
helper.make_node("Add", ["rd2_out", "const_e12"], ["add1_out"], "add1"),
helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"),
helper.make_node("Div", ["sb1_out", "sqrt_out"], ["div_out"], "div1"),
helper.make_node("Mul", ["gamma", "div_out"], ["mul_out"], "mul"),
helper.make_node("Add", ["mul_out", "beta"], ["C"], "add0"),
]
)
initializers = [ # initializers
helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]),
@ -31,7 +48,7 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
helper.make_tensor("beta", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
]
if format is Format.Format1:
if model_format is Format.Format1:
nodes.extend(
[
helper.make_node("Add", ["A", "bias"], ["add3_out"], "add3"),
@ -40,10 +57,12 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
)
initializers.extend(
[
helper.make_tensor("bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor(
"bias", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]
),
]
)
elif format is Format.Format2:
elif model_format is Format.Format2:
nodes.extend(
[
helper.make_node("Add", ["B", "bias"], ["add3_out"], "add3"),
@ -52,10 +71,12 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
)
initializers.extend(
[
helper.make_tensor("bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor(
"bias", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]
),
]
)
elif format is Format.Format3:
elif model_format is Format.Format3:
nodes.extend(
[
helper.make_node("Add", ["A", "B"], ["ln_in"], "add2"),
@ -63,15 +84,15 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
)
if multi_output_add:
neg_input = "ln_in" if format is Format.Format3 else "add3_out"
neg_input = "ln_in" if model_format is Format.Format3 else "add3_out"
nodes.extend([helper.make_node("Neg", [neg_input], ["neg_out"], "neg")])
graph = helper.make_graph(
nodes,
"SkipLayerNorm_format3", # name
[ # inputs
helper.make_tensor_value_info("A", TensorProto.FLOAT, [16, 32, 4]),
helper.make_tensor_value_info("B", TensorProto.FLOAT, [16, 32, 4]),
helper.make_tensor_value_info("A", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [16, 32, 4]),
helper.make_tensor_value_info("B", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [16, 32, 4]),
],
[ # outputs
helper.make_tensor_value_info("C", TensorProto.FLOAT, [16, 32, 4]),
@ -80,32 +101,62 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
)
if add_output_in_graph_output:
extra_output = "ln_in" if format is Format.Format3 else "add3_out"
graph.output.extend([helper.make_tensor_value_info(extra_output, TensorProto.FLOAT, [16, 32, 4])])
extra_output = "ln_in" if model_format is Format.Format3 else "add3_out"
graph.output.extend(
[
helper.make_tensor_value_info(
extra_output, TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [16, 32, 4]
)
]
)
model = helper.make_model(graph)
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
onnxdomain.domain = ""
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets = [onnxdomain, msdomain]
model = helper.make_model(graph, opset_imports=opsets)
onnx.save(model, model_name)
GenerateModel(Format.Format1, "skip_layer_norm_format1.onnx")
GenerateModel(Format.Format2, "skip_layer_norm_format2.onnx")
GenerateModel(Format.Format3, "skip_layer_norm_format3.onnx")
GenerateModel(Format.Format1, "skip_layer_norm_format1_partial.onnx", multi_output_add=True)
GenerateModel(Format.Format2, "skip_layer_norm_format2_partial.onnx", multi_output_add=True)
GenerateModel(Format.Format3, "skip_layer_norm_format3_no_fusion.onnx", multi_output_add=True)
def generate_skip_layer_norm(with_cast=False):
suffix = "_with_cast" if with_cast else ""
GenerateModel(
Format.Format1,
"skip_layer_norm_format1_graph_output.onnx",
add_output_in_graph_output=True,
)
GenerateModel(
Format.Format2,
"skip_layer_norm_format2_graph_output.onnx",
add_output_in_graph_output=True,
)
GenerateModel(
Format.Format3,
"skip_layer_norm_format3_graph_output.onnx",
add_output_in_graph_output=True,
)
generate_model(Format.Format1, f"skip_layer_norm_format1{suffix}.onnx", with_cast=with_cast)
generate_model(Format.Format2, f"skip_layer_norm_format2{suffix}.onnx", with_cast=with_cast)
generate_model(Format.Format3, f"skip_layer_norm_format3{suffix}.onnx", with_cast=with_cast)
generate_model(
Format.Format1, f"skip_layer_norm_format1_partial{suffix}.onnx", multi_output_add=True, with_cast=with_cast
)
generate_model(
Format.Format2, f"skip_layer_norm_format2_partial{suffix}.onnx", multi_output_add=True, with_cast=with_cast
)
generate_model(
Format.Format3, f"skip_layer_norm_format3_no_fusion{suffix}.onnx", multi_output_add=True, with_cast=with_cast
)
generate_model(
Format.Format1,
f"skip_layer_norm_format1_graph_output{suffix}.onnx",
add_output_in_graph_output=True,
with_cast=with_cast,
)
generate_model(
Format.Format2,
f"skip_layer_norm_format2_graph_output{suffix}.onnx",
add_output_in_graph_output=True,
with_cast=with_cast,
)
generate_model(
Format.Format3,
f"skip_layer_norm_format3_graph_output{suffix}.onnx",
add_output_in_graph_output=True,
with_cast=with_cast,
)
generate_skip_layer_norm(with_cast=False)
generate_skip_layer_norm(with_cast=True)