Only fuse when output count of add is 1 (#2884)

* Only fuse when output count of add is 1

* add unit test for add with multi output
This commit is contained in:
Yufeng Li 2020-01-24 13:47:34 -08:00 committed by GitHub
parent a92e924ab2
commit cd876720d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 23 additions and 11 deletions

View file

@ -25,7 +25,8 @@ static bool IsSupportedDataType(const Node& node) {
static bool CheckFirstAdd(Node& add, ProviderType providertype) {
if (providertype != add.GetExecutionProviderType() ||
!IsSupportedDataType(add)) {
!IsSupportedDataType(add) ||
add.GetOutputEdgesCount() != 1) {
return false;
}
@ -58,7 +59,8 @@ static bool CheckFirstAdd(Node& add, ProviderType providertype) {
// The 2nd input should be a 1D constant value
static bool CheckSecondAdd(Node& add, ProviderType providertype) {
if (providertype != add.GetExecutionProviderType() ||
!IsSupportedDataType(add)) {
!IsSupportedDataType(add) ||
add.GetOutputEdgesCount() != 1) {
return false;
}

View file

@ -1272,7 +1272,7 @@ TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) {
}
}
static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_path) {
static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_path, int add_count, int ln_count, int skip_ln_count) {
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = p_model->MainGraph();
@ -1285,19 +1285,22 @@ static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_pat
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Add"] == add_count );
ASSERT_TRUE(op_to_count["Sub"] == 0);
ASSERT_TRUE(op_to_count["ReduceMean"] == 0);
ASSERT_TRUE(op_to_count["Pow"] == 0);
ASSERT_TRUE(op_to_count["Sqrt"] == 0);
ASSERT_TRUE(op_to_count["LayerNormalization"] == 0);
ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == 1);
ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count );
ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == skip_ln_count );
}
TEST(GraphTransformationTests, SkipLayerNormFusionTest) {
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx");
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx");
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx");
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1);
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1 );
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1 );
TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1 );
TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1 );
TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0 );
}
TEST(GraphTransformationTests, EmbedLayerNormFusionFormat1) {

View file

@ -8,7 +8,7 @@ class Format(Enum):
Format2=2,
Format3=3
def GenerateModel(format, model_name):
def GenerateModel(format, model_name, multi_output_add = False):
nodes = [ # LayerNorm subgraph
helper.make_node("ReduceMean", ["ln_in"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1),
helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb1_out"], "sub1"),
@ -42,6 +42,10 @@ def GenerateModel(format, model_name):
elif format is Format.Format3:
nodes.extend([helper.make_node("Add", ["A", "B"], ["ln_in"], "add2"),])
if multi_output_add:
neg_input = "ln_in" if format is Format.Format3 else "add3_out"
nodes.extend([helper.make_node("Neg", [neg_input], ["neg_out"], "neg")])
graph = helper.make_graph(
nodes,
"SkipLayerNorm_format3", #name
@ -60,4 +64,7 @@ def GenerateModel(format, model_name):
GenerateModel(Format.Format1, 'skip_layer_norm_format1.onnx')
GenerateModel(Format.Format2, 'skip_layer_norm_format2.onnx')
GenerateModel(Format.Format3, 'skip_layer_norm_format3.onnx')
GenerateModel(Format.Format3, 'skip_layer_norm_format3.onnx')
GenerateModel(Format.Format1, 'skip_layer_norm_format1_partial.onnx', True)
GenerateModel(Format.Format2, 'skip_layer_norm_format2_partial.onnx', True)
GenerateModel(Format.Format3, 'skip_layer_norm_format3_no_fusion.onnx', True)