diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index 903563d364..f265dfc7f5 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -90,6 +90,36 @@ static bool CheckSecondAdd(Graph& graph, Node& add, ProviderType providertype) { add_input1_shape->dim(2).dim_value() == add_input2_shape->dim(0).dim_value(); } +// Add a Cast to convert input from float16/bfloat16 to float when input type is different fromm output type +static NodeArg* CastToFloat(Graph& graph, NodeArg* input, int32_t output_data_type, ProviderType provider_type) { + if (nullptr == input->Type() || + input->TypeAsProto()->tensor_type().elem_type() == output_data_type || + output_data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return input; + } + + auto input_shape = input->Shape(); + TypeProto input_float; + input_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + for (auto i = 0; i < input_shape->dim_size(); ++i) { + auto dim = input_float.mutable_tensor_type()->mutable_shape()->add_dim(); + *dim = input_shape->dim(i); + } + auto& cast_float = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(input->Name() + "_Float"), &input_float); + + auto& node = graph.AddNode(graph.GenerateNodeName(input->Name() + "_Cast"), + "Cast", + "Cast Input to float", + std::array{input}, + std::array{&cast_float}, + nullptr, + kOnnxDomain); + + node.AddAttribute("to", int64_t{ONNX_NAMESPACE::TensorProto_DataType_FLOAT}); + node.SetExecutionProviderType(provider_type); + return &cast_float; +} + /** Skip Layer Normalization will fuse Add + LayerNormalization into one node, and another Add if applicable @@ -243,6 +273,14 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le nodes_to_remove.push_back(*p_add1); nodes_to_remove.push_back(ln_node); + // If input types are different than output type and output type is float, insert cast node after inputs. + for (auto& input_def: skip_layer_norm_input_defs) { + input_def = CastToFloat(graph, + input_def, + ln_node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(), + ln_node.GetExecutionProviderType()); + } + Node& skip_layer_norm_node = graph.AddNode(graph.GenerateNodeName("SkipLayerNormalization"), "SkipLayerNormalization", "fused SkipLayerNorm subgraphs ", diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index da9cd3caac..0c4a685371 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4906,7 +4906,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTestCudaEp) { } static void TestSkipLayerNormFusion(const std::basic_string& file_path, int add_count, int ln_count, - int skip_ln_count, logging::Logger* logger) { + int skip_ln_count, int cast_count, logging::Logger* logger) { std::shared_ptr p_model; ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); Graph& graph = p_model->MainGraph(); @@ -4925,43 +4925,57 @@ static void TestSkipLayerNormFusion(const std::basic_string& file_pat ASSERT_TRUE(op_to_count["Sqrt"] == 0); ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count); ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == skip_ln_count); + ASSERT_TRUE(op_to_count["Cast"] == cast_count); } TEST_F(GraphTransformationTests, SkipLayerNormFusionTest) { - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1, 0, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0, 0, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output.onnx", 1, 0, 1, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output.onnx", 1, 0, 1, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output.onnx", 1, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output.onnx", 1, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, 0, logger_.get()); } -TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/skip_layer_norm_input_output_check.onnx"; +TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) { + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_with_cast.onnx", 0, 0, 1, 2, logger_.get()); + + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial_with_cast.onnx", 1, 0, 1, 2, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial_with_cast.onnx", 1, 0, 1, 2, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx", 1, 1, 0, 0, logger_.get()); + + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output_with_cast.onnx", 1, 0, 1, 2, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output_with_cast.onnx", 1, 0, 1, 2, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output_with_cast.onnx", 1, 1, 0, 0, logger_.get()); +} + +static void TestSkipLayerNormFusionInputOutputCheck(const std::basic_string& model_uri, bool with_cast, logging::Logger* logger) { std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); for (Node& node : graph.Nodes()) { if (node.OpType() == "SkipLayerNormalization") { // check inputs std::vector& input_defs = node.MutableInputDefs(); EXPECT_EQ(input_defs.size(), 5u) << "SkipLayerNormalization number of inputs does not equal to 5. Got:" << node.InputDefs().size(); - EXPECT_EQ(input_defs[0]->Name(), "input.1"); - EXPECT_EQ(input_defs[1]->Name(), "6"); + EXPECT_EQ(input_defs[0]->Name(), ((with_cast) ? "input.1_Float" : "input.1")); + EXPECT_EQ(input_defs[1]->Name(), ((with_cast) ? "6_Float" : "6")); EXPECT_EQ(input_defs[2]->Name(), "1"); EXPECT_EQ(input_defs[3]->Name(), "2"); - EXPECT_EQ(input_defs[4]->Name(), "4"); + EXPECT_EQ(input_defs[4]->Name(), ((with_cast) ? "4_Float" : "4")); // check outputs std::vector& output_defs = node.MutableOutputDefs(); @@ -4971,26 +4985,38 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) { EXPECT_EQ(node.OutputDefs().size(), 1u) << "SkipLayerNormalization number of outputs does not equal to 1. Got:" << node.OutputDefs().size(); #endif EXPECT_EQ(output_defs[0]->Name(), "19"); + } else if (node.OpType() == "Cast") { + EXPECT_TRUE(with_cast) << "Unexpected node: " << node.OpType() << "," << node.Name(); } else { EXPECT_EQ(node.OpType(), "MatMul") << "Unexpected node: " << node.OpType() << "," << node.Name(); } } } -TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx"; +TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) { + TestSkipLayerNormFusionInputOutputCheck(MODEL_FOLDER "fusion/skip_layer_norm_input_output_check.onnx", false, logger_.get()); + TestSkipLayerNormFusionInputOutputCheck(MODEL_FOLDER "fusion/skip_layer_norm_input_output_with_cast_check.onnx", true, logger_.get()); +} + +static void TestSkipLayerNormFusionNoBeta(const std::basic_string& model_uri, bool with_cast, logging::Logger* logger) { std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 0); ASSERT_TRUE(op_to_count["LayerNormalization"] == 0); ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 1); + ASSERT_TRUE(op_to_count["Cast"] == ((with_cast) ? 2 : 0)); +} + +TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) { + TestSkipLayerNormFusionNoBeta(MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx", false, logger_.get()); + TestSkipLayerNormFusionNoBeta(MODEL_FOLDER "fusion/skip_layer_norm_no_beta_with_cast.onnx", true, logger_.get()); } TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) { diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_graph_output_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_graph_output_with_cast.onnx new file mode 100644 index 0000000000..043c0545eb Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_graph_output_with_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial_with_cast.onnx new file mode 100644 index 0000000000..58a5417aca Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial_with_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_with_cast.onnx new file mode 100644 index 0000000000..69fb3b08d7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_with_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_graph_output_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_graph_output_with_cast.onnx new file mode 100644 index 0000000000..96683f2cd4 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_graph_output_with_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial_with_cast.onnx new file mode 100644 index 0000000000..ff19625a27 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial_with_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_with_cast.onnx new file mode 100644 index 0000000000..566db31e9b Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_with_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_graph_output_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_graph_output_with_cast.onnx new file mode 100644 index 0000000000..840fc7ecda Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_graph_output_with_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx new file mode 100644 index 0000000000..173be529cd Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_with_cast.onnx new file mode 100644 index 0000000000..5646faf09e Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_with_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py index 0ebc5b11e0..a95b825c9f 100644 --- a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py @@ -1,7 +1,7 @@ from enum import Enum import onnx -from onnx import TensorProto, helper +from onnx import OperatorSetIdProto, TensorProto, helper class Format(Enum): @@ -10,19 +10,36 @@ class Format(Enum): Format3 = 3 -def GenerateModel(format, model_name, multi_output_add=False, add_output_in_graph_output=False): # noqa: N802 - nodes = [ # LayerNorm subgraph - helper.make_node("ReduceMean", ["ln_in"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1), - helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb1_out"], "sub1"), - helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb2_out"], "sub2"), - helper.make_node("Pow", ["sb2_out", "pow_in_2"], ["pow_out"], "pow"), - helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1), - helper.make_node("Add", ["rd2_out", "const_e12"], ["add1_out"], "add1"), - helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"), - helper.make_node("Div", ["sb1_out", "sqrt_out"], ["div_out"], "div1"), - helper.make_node("Mul", ["gamma", "div_out"], ["mul_out"], "mul"), - helper.make_node("Add", ["mul_out", "beta"], ["C"], "add0"), - ] +def generate_model(model_format, model_name, multi_output_add=False, add_output_in_graph_output=False, with_cast=False): + nodes = [] # LayerNorm subgraph + if with_cast: + nodes.extend( + [ + helper.make_node("Cast", ["ln_in"], ["c_out"], "cast", to=1), + helper.make_node("ReduceMean", ["c_out"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1), + helper.make_node("Sub", ["c_out", "rd1_out"], ["sb1_out"], "sub1"), + helper.make_node("Sub", ["c_out", "rd1_out"], ["sb2_out"], "sub2"), + ] + ) + else: + nodes.extend( + [ + helper.make_node("ReduceMean", ["ln_in"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1), + helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb1_out"], "sub1"), + helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb2_out"], "sub2"), + ] + ) + nodes.extend( + [ # LayerNorm subgraph + helper.make_node("Pow", ["sb2_out", "pow_in_2"], ["pow_out"], "pow"), + helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1), + helper.make_node("Add", ["rd2_out", "const_e12"], ["add1_out"], "add1"), + helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"), + helper.make_node("Div", ["sb1_out", "sqrt_out"], ["div_out"], "div1"), + helper.make_node("Mul", ["gamma", "div_out"], ["mul_out"], "mul"), + helper.make_node("Add", ["mul_out", "beta"], ["C"], "add0"), + ] + ) initializers = [ # initializers helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]), @@ -31,7 +48,7 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap helper.make_tensor("beta", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), ] - if format is Format.Format1: + if model_format is Format.Format1: nodes.extend( [ helper.make_node("Add", ["A", "bias"], ["add3_out"], "add3"), @@ -40,10 +57,12 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap ) initializers.extend( [ - helper.make_tensor("bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor( + "bias", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4] + ), ] ) - elif format is Format.Format2: + elif model_format is Format.Format2: nodes.extend( [ helper.make_node("Add", ["B", "bias"], ["add3_out"], "add3"), @@ -52,10 +71,12 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap ) initializers.extend( [ - helper.make_tensor("bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor( + "bias", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4] + ), ] ) - elif format is Format.Format3: + elif model_format is Format.Format3: nodes.extend( [ helper.make_node("Add", ["A", "B"], ["ln_in"], "add2"), @@ -63,15 +84,15 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap ) if multi_output_add: - neg_input = "ln_in" if format is Format.Format3 else "add3_out" + neg_input = "ln_in" if model_format is Format.Format3 else "add3_out" nodes.extend([helper.make_node("Neg", [neg_input], ["neg_out"], "neg")]) graph = helper.make_graph( nodes, "SkipLayerNorm_format3", # name [ # inputs - helper.make_tensor_value_info("A", TensorProto.FLOAT, [16, 32, 4]), - helper.make_tensor_value_info("B", TensorProto.FLOAT, [16, 32, 4]), + helper.make_tensor_value_info("A", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [16, 32, 4]), + helper.make_tensor_value_info("B", TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [16, 32, 4]), ], [ # outputs helper.make_tensor_value_info("C", TensorProto.FLOAT, [16, 32, 4]), @@ -80,32 +101,62 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap ) if add_output_in_graph_output: - extra_output = "ln_in" if format is Format.Format3 else "add3_out" - graph.output.extend([helper.make_tensor_value_info(extra_output, TensorProto.FLOAT, [16, 32, 4])]) + extra_output = "ln_in" if model_format is Format.Format3 else "add3_out" + graph.output.extend( + [ + helper.make_tensor_value_info( + extra_output, TensorProto.FLOAT16 if with_cast else TensorProto.FLOAT, [16, 32, 4] + ) + ] + ) - model = helper.make_model(graph) + onnxdomain = OperatorSetIdProto() + onnxdomain.version = 12 + # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. + onnxdomain.domain = "" + msdomain = OperatorSetIdProto() + msdomain.version = 1 + msdomain.domain = "com.microsoft" + opsets = [onnxdomain, msdomain] + + model = helper.make_model(graph, opset_imports=opsets) onnx.save(model, model_name) -GenerateModel(Format.Format1, "skip_layer_norm_format1.onnx") -GenerateModel(Format.Format2, "skip_layer_norm_format2.onnx") -GenerateModel(Format.Format3, "skip_layer_norm_format3.onnx") -GenerateModel(Format.Format1, "skip_layer_norm_format1_partial.onnx", multi_output_add=True) -GenerateModel(Format.Format2, "skip_layer_norm_format2_partial.onnx", multi_output_add=True) -GenerateModel(Format.Format3, "skip_layer_norm_format3_no_fusion.onnx", multi_output_add=True) +def generate_skip_layer_norm(with_cast=False): + suffix = "_with_cast" if with_cast else "" -GenerateModel( - Format.Format1, - "skip_layer_norm_format1_graph_output.onnx", - add_output_in_graph_output=True, -) -GenerateModel( - Format.Format2, - "skip_layer_norm_format2_graph_output.onnx", - add_output_in_graph_output=True, -) -GenerateModel( - Format.Format3, - "skip_layer_norm_format3_graph_output.onnx", - add_output_in_graph_output=True, -) + generate_model(Format.Format1, f"skip_layer_norm_format1{suffix}.onnx", with_cast=with_cast) + generate_model(Format.Format2, f"skip_layer_norm_format2{suffix}.onnx", with_cast=with_cast) + generate_model(Format.Format3, f"skip_layer_norm_format3{suffix}.onnx", with_cast=with_cast) + generate_model( + Format.Format1, f"skip_layer_norm_format1_partial{suffix}.onnx", multi_output_add=True, with_cast=with_cast + ) + generate_model( + Format.Format2, f"skip_layer_norm_format2_partial{suffix}.onnx", multi_output_add=True, with_cast=with_cast + ) + generate_model( + Format.Format3, f"skip_layer_norm_format3_no_fusion{suffix}.onnx", multi_output_add=True, with_cast=with_cast + ) + generate_model( + Format.Format1, + f"skip_layer_norm_format1_graph_output{suffix}.onnx", + add_output_in_graph_output=True, + with_cast=with_cast, + ) + generate_model( + Format.Format2, + f"skip_layer_norm_format2_graph_output{suffix}.onnx", + add_output_in_graph_output=True, + with_cast=with_cast, + ) + generate_model( + Format.Format3, + f"skip_layer_norm_format3_graph_output{suffix}.onnx", + add_output_in_graph_output=True, + with_cast=with_cast, + ) + + +generate_skip_layer_norm(with_cast=False) +generate_skip_layer_norm(with_cast=True) diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_input_output_with_cast_check.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_input_output_with_cast_check.onnx new file mode 100644 index 0000000000..38768b88ef Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_input_output_with_cast_check.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta_with_cast.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta_with_cast.onnx new file mode 100644 index 0000000000..f97debaa7c Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta_with_cast.onnx differ