mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Only fuse when output count of add is 1 (#2884)
* Only fuse when output count of add is 1 * add unit test for add with multi output
This commit is contained in:
parent
a92e924ab2
commit
cd876720d9
9 changed files with 23 additions and 11 deletions
|
|
@ -25,7 +25,8 @@ static bool IsSupportedDataType(const Node& node) {
|
|||
|
||||
static bool CheckFirstAdd(Node& add, ProviderType providertype) {
|
||||
if (providertype != add.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(add)) {
|
||||
!IsSupportedDataType(add) ||
|
||||
add.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -58,7 +59,8 @@ static bool CheckFirstAdd(Node& add, ProviderType providertype) {
|
|||
// The 2nd input should be a 1D constant value
|
||||
static bool CheckSecondAdd(Node& add, ProviderType providertype) {
|
||||
if (providertype != add.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(add)) {
|
||||
!IsSupportedDataType(add) ||
|
||||
add.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1272,7 +1272,7 @@ TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) {
|
|||
}
|
||||
}
|
||||
|
||||
static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_path) {
|
||||
static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_path, int add_count, int ln_count, int skip_ln_count) {
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
|
@ -1285,19 +1285,22 @@ static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_pat
|
|||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Div"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Add"] == add_count );
|
||||
ASSERT_TRUE(op_to_count["Sub"] == 0);
|
||||
ASSERT_TRUE(op_to_count["ReduceMean"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Pow"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Sqrt"] == 0);
|
||||
ASSERT_TRUE(op_to_count["LayerNormalization"] == 0);
|
||||
ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == 1);
|
||||
ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count );
|
||||
ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == skip_ln_count );
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, SkipLayerNormFusionTest) {
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx");
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx");
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx");
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1);
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1 );
|
||||
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1 );
|
||||
TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1 );
|
||||
TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1 );
|
||||
TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0 );
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, EmbedLayerNormFusionFormat1) {
|
||||
|
|
|
|||
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial.onnx
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial.onnx
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -8,7 +8,7 @@ class Format(Enum):
|
|||
Format2=2,
|
||||
Format3=3
|
||||
|
||||
def GenerateModel(format, model_name):
|
||||
def GenerateModel(format, model_name, multi_output_add = False):
|
||||
nodes = [ # LayerNorm subgraph
|
||||
helper.make_node("ReduceMean", ["ln_in"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1),
|
||||
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb1_out"], "sub1"),
|
||||
|
|
@ -42,6 +42,10 @@ def GenerateModel(format, model_name):
|
|||
elif format is Format.Format3:
|
||||
nodes.extend([helper.make_node("Add", ["A", "B"], ["ln_in"], "add2"),])
|
||||
|
||||
if multi_output_add:
|
||||
neg_input = "ln_in" if format is Format.Format3 else "add3_out"
|
||||
nodes.extend([helper.make_node("Neg", [neg_input], ["neg_out"], "neg")])
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"SkipLayerNorm_format3", #name
|
||||
|
|
@ -60,4 +64,7 @@ def GenerateModel(format, model_name):
|
|||
|
||||
GenerateModel(Format.Format1, 'skip_layer_norm_format1.onnx')
|
||||
GenerateModel(Format.Format2, 'skip_layer_norm_format2.onnx')
|
||||
GenerateModel(Format.Format3, 'skip_layer_norm_format3.onnx')
|
||||
GenerateModel(Format.Format3, 'skip_layer_norm_format3.onnx')
|
||||
GenerateModel(Format.Format1, 'skip_layer_norm_format1_partial.onnx', True)
|
||||
GenerateModel(Format.Format2, 'skip_layer_norm_format2_partial.onnx', True)
|
||||
GenerateModel(Format.Format3, 'skip_layer_norm_format3_no_fusion.onnx', True)
|
||||
Loading…
Reference in a new issue