mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
SkipLayerNorm fusion with different input and output type (#15500)
SkipLayerNorm fusion now fuses LayerNorm with one or more Add kernels. While the LayerNormalization kernel allows different input and output types by definition, SkipLayerNormalization must have the same input and output type. This graph is valid because the output of the Add node is float16 and the two inputs from initializers are float. But when Add and LayerNormalization are fused, it fails because the two inputs of the Add node are float16 and SkipLayerNormalization must have the same input types. To avoid this failure, this PR adds a Cast node before the inputs of SkipLayerNormalization when the input and output types are different and the output type is float. The above graph is fused as follows. For performance, it would be better for SkipLayerNormalization to support different input and output types, but this PR is to unblock Turing NLR v5 base mode in Babel. When we have more cases, we can support it.
This commit is contained in:
parent
d76cf374c4
commit
fda0aa14c8
14 changed files with 182 additions and 67 deletions
|
|
@ -90,6 +90,36 @@ static bool CheckSecondAdd(Graph& graph, Node& add, ProviderType providertype) {
|
|||
add_input1_shape->dim(2).dim_value() == add_input2_shape->dim(0).dim_value();
|
||||
}
|
||||
|
||||
// Add a Cast to convert input from float16/bfloat16 to float when input type is different fromm output type
|
||||
static NodeArg* CastToFloat(Graph& graph, NodeArg* input, int32_t output_data_type, ProviderType provider_type) {
|
||||
if (nullptr == input->Type() ||
|
||||
input->TypeAsProto()->tensor_type().elem_type() == output_data_type ||
|
||||
output_data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
return input;
|
||||
}
|
||||
|
||||
auto input_shape = input->Shape();
|
||||
TypeProto input_float;
|
||||
input_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
for (auto i = 0; i < input_shape->dim_size(); ++i) {
|
||||
auto dim = input_float.mutable_tensor_type()->mutable_shape()->add_dim();
|
||||
*dim = input_shape->dim(i);
|
||||
}
|
||||
auto& cast_float = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(input->Name() + "_Float"), &input_float);
|
||||
|
||||
auto& node = graph.AddNode(graph.GenerateNodeName(input->Name() + "_Cast"),
|
||||
"Cast",
|
||||
"Cast Input to float",
|
||||
std::array{input},
|
||||
std::array{&cast_float},
|
||||
nullptr,
|
||||
kOnnxDomain);
|
||||
|
||||
node.AddAttribute("to", int64_t{ONNX_NAMESPACE::TensorProto_DataType_FLOAT});
|
||||
node.SetExecutionProviderType(provider_type);
|
||||
return &cast_float;
|
||||
}
|
||||
|
||||
/**
|
||||
Skip Layer Normalization will fuse Add + LayerNormalization into one node, and another Add if applicable
|
||||
|
||||
|
|
@ -243,6 +273,14 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
nodes_to_remove.push_back(*p_add1);
|
||||
nodes_to_remove.push_back(ln_node);
|
||||
|
||||
// If input types are different than output type and output type is float, insert cast node after inputs.
|
||||
for (auto& input_def: skip_layer_norm_input_defs) {
|
||||
input_def = CastToFloat(graph,
|
||||
input_def,
|
||||
ln_node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(),
|
||||
ln_node.GetExecutionProviderType());
|
||||
}
|
||||
|
||||
Node& skip_layer_norm_node = graph.AddNode(graph.GenerateNodeName("SkipLayerNormalization"),
|
||||
"SkipLayerNormalization",
|
||||
"fused SkipLayerNorm subgraphs ",
|
||||
|
|
|
|||
|
|
@ -4906,7 +4906,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTestCudaEp) {
|
|||
}
|
||||
|
||||
static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_path, int add_count, int ln_count,
|
||||
int skip_ln_count, logging::Logger* logger) {
|
||||
int skip_ln_count, int cast_count, logging::Logger* logger) {
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
|
@ -4925,43 +4925,57 @@ static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_pat
|
|||
ASSERT_TRUE(op_to_count["Sqrt"] == 0);
|
||||
ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == skip_ln_count);
|
||||
ASSERT_TRUE(op_to_count["Cast"] == cast_count);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, SkipLayerNormFusionTest) {
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, 0, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1, 0, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1, 0, logger_.get());
|
||||
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1, 0, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1, 0, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0, 0, logger_.get());
|
||||
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output.onnx", 1, 0, 1, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output.onnx", 1, 0, 1, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output.onnx", 1, 0, 1, 0, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output.onnx", 1, 0, 1, 0, logger_.get());
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, 0, logger_.get());
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/skip_layer_norm_input_output_check.onnx";
|
||||
// Runs TestSkipLayerNormFusion over every "*_with_cast" model and checks the expected number of
// Add, LayerNormalization, SkipLayerNormalization and Cast nodes for each one.
TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) {
  // One row per model; the counts are forwarded to TestSkipLayerNormFusion in this order.
  struct FusionCase {
    const ORTCHAR_T* model_path;
    int add_count;
    int ln_count;
    int skip_ln_count;
    int cast_count;
  };

  const FusionCase fusion_cases[] = {
      {MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3},
      {MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3},
      {MODEL_FOLDER "fusion/skip_layer_norm_format3_with_cast.onnx", 0, 0, 1, 2},

      {MODEL_FOLDER "fusion/skip_layer_norm_format1_partial_with_cast.onnx", 1, 0, 1, 2},
      {MODEL_FOLDER "fusion/skip_layer_norm_format2_partial_with_cast.onnx", 1, 0, 1, 2},
      {MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx", 1, 1, 0, 0},

      {MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output_with_cast.onnx", 1, 0, 1, 2},
      {MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output_with_cast.onnx", 1, 0, 1, 2},
      {MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output_with_cast.onnx", 1, 1, 0, 0},
  };

  auto* const test_logger = logger_.get();
  for (const FusionCase& fusion_case : fusion_cases) {
    TestSkipLayerNormFusion(fusion_case.model_path,
                            fusion_case.add_count,
                            fusion_case.ln_count,
                            fusion_case.skip_ln_count,
                            fusion_case.cast_count,
                            test_logger);
  }
}
|
||||
|
||||
static void TestSkipLayerNormFusionInputOutputCheck(const std::basic_string<ORTCHAR_T>& model_uri, bool with_cast, logging::Logger* logger) {
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger));
|
||||
|
||||
for (Node& node : graph.Nodes()) {
|
||||
if (node.OpType() == "SkipLayerNormalization") {
|
||||
// check inputs
|
||||
std::vector<NodeArg*>& input_defs = node.MutableInputDefs();
|
||||
EXPECT_EQ(input_defs.size(), 5u) << "SkipLayerNormalization number of inputs does not equal to 5. Got:" << node.InputDefs().size();
|
||||
EXPECT_EQ(input_defs[0]->Name(), "input.1");
|
||||
EXPECT_EQ(input_defs[1]->Name(), "6");
|
||||
EXPECT_EQ(input_defs[0]->Name(), ((with_cast) ? "input.1_Float" : "input.1"));
|
||||
EXPECT_EQ(input_defs[1]->Name(), ((with_cast) ? "6_Float" : "6"));
|
||||
EXPECT_EQ(input_defs[2]->Name(), "1");
|
||||
EXPECT_EQ(input_defs[3]->Name(), "2");
|
||||
EXPECT_EQ(input_defs[4]->Name(), "4");
|
||||
EXPECT_EQ(input_defs[4]->Name(), ((with_cast) ? "4_Float" : "4"));
|
||||
|
||||
// check outputs
|
||||
std::vector<NodeArg*>& output_defs = node.MutableOutputDefs();
|
||||
|
|
@ -4971,26 +4985,38 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) {
|
|||
EXPECT_EQ(node.OutputDefs().size(), 1u) << "SkipLayerNormalization number of outputs does not equal to 1. Got:" << node.OutputDefs().size();
|
||||
#endif
|
||||
EXPECT_EQ(output_defs[0]->Name(), "19");
|
||||
} else if (node.OpType() == "Cast") {
|
||||
EXPECT_TRUE(with_cast) << "Unexpected node: " << node.OpType() << "," << node.Name();
|
||||
} else {
|
||||
EXPECT_EQ(node.OpType(), "MatMul") << "Unexpected node: " << node.OpType() << "," << node.Name();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx";
|
||||
// Runs the SkipLayerNormalization input/output check twice: once on the baseline model with
// with_cast == false, and once on the "_with_cast" model with with_cast == true.
TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) {
  constexpr const ORTCHAR_T* baseline_model = MODEL_FOLDER "fusion/skip_layer_norm_input_output_check.onnx";
  constexpr const ORTCHAR_T* cast_model = MODEL_FOLDER "fusion/skip_layer_norm_input_output_with_cast_check.onnx";

  auto* const test_logger = logger_.get();
  TestSkipLayerNormFusionInputOutputCheck(baseline_model, /*with_cast=*/false, test_logger);
  TestSkipLayerNormFusionInputOutputCheck(cast_model, /*with_cast=*/true, test_logger);
}
|
||||
|
||||
static void TestSkipLayerNormFusionNoBeta(const std::basic_string<ORTCHAR_T>& model_uri, bool with_cast, logging::Logger* logger) {
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["LayerNormalization"] == 0);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Cast"] == ((with_cast) ? 2 : 0));
|
||||
}
|
||||
|
||||
// Runs the no-beta SkipLayerNorm fusion check twice: once on the baseline model with
// with_cast == false, and once on the "_with_cast" model with with_cast == true.
TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) {
  constexpr const ORTCHAR_T* baseline_model = MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx";
  constexpr const ORTCHAR_T* cast_model = MODEL_FOLDER "fusion/skip_layer_norm_no_beta_with_cast.onnx";

  auto* const test_logger = logger_.get();
  TestSkipLayerNormFusionNoBeta(baseline_model, /*with_cast=*/false, test_logger);
  TestSkipLayerNormFusionNoBeta(cast_model, /*with_cast=*/true, test_logger);
}
|
||||
|
||||
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) {
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_graph_output_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_graph_output_with_cast.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial_with_cast.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_with_cast.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_graph_output_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_graph_output_with_cast.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial_with_cast.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_with_cast.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_graph_output_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_graph_output_with_cast.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_with_cast.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -1,7 +1,7 @@
|
|||
from enum import Enum
|
||||
|
||||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
from onnx import OperatorSetIdProto, TensorProto, helper
|
||||
|
||||
|
||||
class Format(Enum):
|
||||
|
|
@ -10,19 +10,36 @@ class Format(Enum):
|
|||
Format3 = 3
|
||||
|
||||
|
||||
def GenerateModel(format, model_name, multi_output_add=False, add_output_in_graph_output=False): # noqa: N802
|
||||
nodes = [ # LayerNorm subgraph
|
||||
helper.make_node("ReduceMean", ["ln_in"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1),
|
||||
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb1_out"], "sub1"),
|
||||
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb2_out"], "sub2"),
|
||||
helper.make_node("Pow", ["sb2_out", "pow_in_2"], ["pow_out"], "pow"),
|
||||
helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1),
|
||||
helper.make_node("Add", ["rd2_out", "const_e12"], ["add1_out"], "add1"),
|
||||
helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"),
|
||||
helper.make_node("Div", ["sb1_out", "sqrt_out"], ["div_out"], "div1"),
|
||||
helper.make_node("Mul", ["gamma", "div_out"], ["mul_out"], "mul"),
|
||||
helper.make_node("Add", ["mul_out", "beta"], ["C"], "add0"),
|
||||
]
|
||||
def generate_model(model_format, model_name, multi_output_add=False, add_output_in_graph_output=False, with_cast=False):
|
||||
nodes = [] # LayerNorm subgraph
|
||||
if with_cast:
|
||||
nodes.extend(
|
||||
[
|
||||
helper.make_node("Cast", ["ln_in"], ["c_out"], "cast", to=1),
|
||||
helper.make_node("ReduceMean", ["c_out"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1),
|
||||
helper.make_node("Sub", ["c_out", "rd1_out"], ["sb1_out"], "sub1"),
|
||||
helper.make_node("Sub", ["c_out", "rd1_out"], ["sb2_out"], "sub2"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
nodes.extend(
|
||||
[
|
||||
helper.make_node("ReduceMean", ["ln_in"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1),
|
||||
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb1_out"], "sub1"),
|
||||
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb2_out"], "sub2"),
|
||||
]
|
||||
)
|
||||
nodes.extend(
|
||||
[ # LayerNorm subgraph
|
||||
helper.make_node("Pow", ["sb2_out", "pow_in_2"], ["pow_out"], "pow"),
|
||||
helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1),
|
||||
helper.make_node("Add", ["rd2_out", "const_e12"], ["add1_out"], "add1"),
|
||||
helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"),
|
||||
helper.make_node("Div", ["sb1_out", "sqrt_out"], ["div_out"], "div1"),
|
||||
helper.make_node("Mul", ["gamma", "div_out"], ["mul_out"], "mul"),
|
||||
helper.make_node("Add", ["mul_out", "beta"], ["C"], "add0"),
|
||||
]
|
||||
)
|
||||
|
||||
initializers = [ # initializers
|
||||
helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]),
|
||||
|
|
@ -31,7 +48,7 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
|
|||
helper.make_tensor("beta", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
|
||||
]
|
||||
|
||||
if format is Format.Format1:
|
||||
if model_format is Format.Format1:
|
||||
nodes.extend(
|
||||
[
|
||||
helper.make_node("Add", ["A", "bias"], ["add3_out"], "add3"),
|
||||
|
|
@ -40,10 +57,12 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
|
|||
)
|
||||
initializers.extend(
|
||||
[
|
||||
helper.make_tensor("bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor(
|
||||
"bias", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]
|
||||
),
|
||||
]
|
||||
)
|
||||
elif format is Format.Format2:
|
||||
elif model_format is Format.Format2:
|
||||
nodes.extend(
|
||||
[
|
||||
helper.make_node("Add", ["B", "bias"], ["add3_out"], "add3"),
|
||||
|
|
@ -52,10 +71,12 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
|
|||
)
|
||||
initializers.extend(
|
||||
[
|
||||
helper.make_tensor("bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor(
|
||||
"bias", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]
|
||||
),
|
||||
]
|
||||
)
|
||||
elif format is Format.Format3:
|
||||
elif model_format is Format.Format3:
|
||||
nodes.extend(
|
||||
[
|
||||
helper.make_node("Add", ["A", "B"], ["ln_in"], "add2"),
|
||||
|
|
@ -63,15 +84,15 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
|
|||
)
|
||||
|
||||
if multi_output_add:
|
||||
neg_input = "ln_in" if format is Format.Format3 else "add3_out"
|
||||
neg_input = "ln_in" if model_format is Format.Format3 else "add3_out"
|
||||
nodes.extend([helper.make_node("Neg", [neg_input], ["neg_out"], "neg")])
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"SkipLayerNorm_format3", # name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info("A", TensorProto.FLOAT, [16, 32, 4]),
|
||||
helper.make_tensor_value_info("B", TensorProto.FLOAT, [16, 32, 4]),
|
||||
helper.make_tensor_value_info("A", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [16, 32, 4]),
|
||||
helper.make_tensor_value_info("B", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [16, 32, 4]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info("C", TensorProto.FLOAT, [16, 32, 4]),
|
||||
|
|
@ -80,32 +101,62 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap
|
|||
)
|
||||
|
||||
if add_output_in_graph_output:
|
||||
extra_output = "ln_in" if format is Format.Format3 else "add3_out"
|
||||
graph.output.extend([helper.make_tensor_value_info(extra_output, TensorProto.FLOAT, [16, 32, 4])])
|
||||
extra_output = "ln_in" if model_format is Format.Format3 else "add3_out"
|
||||
graph.output.extend(
|
||||
[
|
||||
helper.make_tensor_value_info(
|
||||
extra_output, TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [16, 32, 4]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnxdomain = OperatorSetIdProto()
|
||||
onnxdomain.version = 12
|
||||
# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
|
||||
onnxdomain.domain = ""
|
||||
msdomain = OperatorSetIdProto()
|
||||
msdomain.version = 1
|
||||
msdomain.domain = "com.microsoft"
|
||||
opsets = [onnxdomain, msdomain]
|
||||
|
||||
model = helper.make_model(graph, opset_imports=opsets)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
|
||||
GenerateModel(Format.Format1, "skip_layer_norm_format1.onnx")
|
||||
GenerateModel(Format.Format2, "skip_layer_norm_format2.onnx")
|
||||
GenerateModel(Format.Format3, "skip_layer_norm_format3.onnx")
|
||||
GenerateModel(Format.Format1, "skip_layer_norm_format1_partial.onnx", multi_output_add=True)
|
||||
GenerateModel(Format.Format2, "skip_layer_norm_format2_partial.onnx", multi_output_add=True)
|
||||
GenerateModel(Format.Format3, "skip_layer_norm_format3_no_fusion.onnx", multi_output_add=True)
|
||||
def generate_skip_layer_norm(with_cast=False):
|
||||
suffix = "_with_cast" if with_cast else ""
|
||||
|
||||
GenerateModel(
|
||||
Format.Format1,
|
||||
"skip_layer_norm_format1_graph_output.onnx",
|
||||
add_output_in_graph_output=True,
|
||||
)
|
||||
GenerateModel(
|
||||
Format.Format2,
|
||||
"skip_layer_norm_format2_graph_output.onnx",
|
||||
add_output_in_graph_output=True,
|
||||
)
|
||||
GenerateModel(
|
||||
Format.Format3,
|
||||
"skip_layer_norm_format3_graph_output.onnx",
|
||||
add_output_in_graph_output=True,
|
||||
)
|
||||
generate_model(Format.Format1, f"skip_layer_norm_format1{suffix}.onnx", with_cast=with_cast)
|
||||
generate_model(Format.Format2, f"skip_layer_norm_format2{suffix}.onnx", with_cast=with_cast)
|
||||
generate_model(Format.Format3, f"skip_layer_norm_format3{suffix}.onnx", with_cast=with_cast)
|
||||
generate_model(
|
||||
Format.Format1, f"skip_layer_norm_format1_partial{suffix}.onnx", multi_output_add=True, with_cast=with_cast
|
||||
)
|
||||
generate_model(
|
||||
Format.Format2, f"skip_layer_norm_format2_partial{suffix}.onnx", multi_output_add=True, with_cast=with_cast
|
||||
)
|
||||
generate_model(
|
||||
Format.Format3, f"skip_layer_norm_format3_no_fusion{suffix}.onnx", multi_output_add=True, with_cast=with_cast
|
||||
)
|
||||
generate_model(
|
||||
Format.Format1,
|
||||
f"skip_layer_norm_format1_graph_output{suffix}.onnx",
|
||||
add_output_in_graph_output=True,
|
||||
with_cast=with_cast,
|
||||
)
|
||||
generate_model(
|
||||
Format.Format2,
|
||||
f"skip_layer_norm_format2_graph_output{suffix}.onnx",
|
||||
add_output_in_graph_output=True,
|
||||
with_cast=with_cast,
|
||||
)
|
||||
generate_model(
|
||||
Format.Format3,
|
||||
f"skip_layer_norm_format3_graph_output{suffix}.onnx",
|
||||
add_output_in_graph_output=True,
|
||||
with_cast=with_cast,
|
||||
)
|
||||
|
||||
|
||||
generate_skip_layer_norm(with_cast=False)
|
||||
generate_skip_layer_norm(with_cast=True)
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_input_output_with_cast_check.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_input_output_with_cast_check.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta_with_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta_with_cast.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue