New Pattern Support for LayerNormFusion (#14118)

The latest torch exporter changed the LayerNorm export code to add two
more Cast nodes (to make the computation logically correct), but the
current LayerNormFusion doesn't recognize the new pattern. This PR adds
support for it.
Vincent Wang 2023-01-04 17:51:14 +08:00 committed by GitHub
parent f864b54393
commit 15c1157ef2
2 changed files with 91 additions and 17 deletions

onnxruntime/core/optimizer/layer_norm_fusion.cc

@@ -52,9 +52,9 @@ due to restriction in older opsets. Therefore, Layer Normalization will also han
 |                     |
 |                     v
 X --> ReduceMean --> Sub --> Cast --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
-                      |                                                        ^
-                      |                                                        |
-                      +--------------------------------------------------------+
+                      |                                                ^
+                      |                                                |
+                      +------------------------------------------------+
 +---------------------+       Cast
 |                     |        |
 |                     v        v
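For context, the subgraph in the diagram is ordinary layer normalization; the exporter's extra Cast nodes only move the statistics computation to float32 when X is float16 and cast the result back afterwards. A scalar sketch of the equivalent math (plain standalone C++ with illustrative names, not ORT code):

#include <cmath>
#include <cstddef>
#include <vector>

// What the matched subgraph computes over one normalized row. The two
// exporter-inserted Casts change where the math runs (Sub output cast to
// float so Pow/ReduceMean/Div happen in full precision, then cast back),
// but numerically the fused LayerNormalization is equivalent to this:
std::vector<float> LayerNormRow(const std::vector<float>& x,
                                const std::vector<float>& scale,
                                const std::vector<float>& bias,
                                float epsilon = 1e-5f) {
  const size_t n = x.size();
  float mean = 0.0f;  // ReduceMean
  for (float v : x) mean += v;
  mean /= static_cast<float>(n);

  float var = 0.0f;  // Sub --> (Cast) --> Pow --> ReduceMean
  for (float v : x) var += (v - mean) * (v - mean);
  var /= static_cast<float>(n);

  const float inv_std = 1.0f / std::sqrt(var + epsilon);  // Add --> Sqrt --> Div
  std::vector<float> y(n);
  for (size_t i = 0; i < n; ++i) {
    y[i] = (x[i] - mean) * inv_std * scale[i] + bias[i];  // Mul --> Add
  }
  return y;
}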
@@ -134,7 +134,6 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     Node& sub_node = *graph.GetNode(p_sub_node->Index());
     if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13, 14}) ||
         sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
-        !optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) ||
         !IsSupportedDataType(sub_node)) {
       continue;
     }
@@ -184,9 +183,24 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     }
     // Apex O2 pattern specific match ends...
-    // Find the "Div" node after "Sub".
+    // Find the "Div" node after "Sub". It's possible that there is "Cast" node after "Sub" node.
+    const Node* p_cast1 = nullptr;
+    if (!p_sub_node_dup && sub_node.GetOutputEdgesCount() == 1) {
+      Node& cast_node = *graph.GetNode(sub_node.OutputNodesBegin()->Index());
+      if (graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) &&
+          cast_node.GetExecutionProviderType() == reduce_mean_node.GetExecutionProviderType() &&
+          optimizer_utils::CheckOutputEdges(graph, cast_node, 2u) && IsSupportedDataType(cast_node)) {
+        p_cast1 = &cast_node;
+        nodes_to_remove.push_back(cast_node);
+      }
+    }
+    if (!optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 && !p_cast1 ? 2u : 1u)) {
+      continue;
+    }
     const Node* p_div = nullptr;
-    p_div = graph_utils::FirstChildByType(sub_node, "Div");
+    p_div = graph_utils::FirstChildByType(p_cast1 ? *p_cast1 : sub_node, "Div");
     // Find the sub_dup node if exist
     if (p_sub_node_dup != nullptr) {
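The CheckOutputEdges call that the earlier hunk removed from the Sub check reappears here, after the optional Cast is matched, because the expected edge count now depends on that Cast: when a Cast follows Sub it takes over both of Sub's consumers (Pow and Div), leaving Sub a single output edge. A standalone restatement of the rule (hypothetical helper, not part of the PR):

// Mirrors `subCnt == 1 && !p_cast1 ? 2u : 1u` from the hunk above.
// single_sub:     the pattern has one Sub node (subCnt == 1), not the
//                 duplicated-Sub variant.
// cast_after_sub: a Cast was matched between Sub and Pow/Div (p_cast1).
unsigned ExpectedSubOutputEdges(bool single_sub, bool cast_after_sub) {
  // Without a Cast, the single Sub feeds Pow and Div directly: two edges.
  // With a Cast (or in the duplicated-Sub pattern), Sub has one consumer.
  return (single_sub && !cast_after_sub) ? 2u : 1u;
}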
@@ -269,23 +283,19 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     nodes_to_remove.push_back(pow_node);
     // check if Cast node exists: either between sub and pow, or as second input to pow
-    const Node* p_cast_node = graph_utils::FirstParentByType(pow_node, "Cast");
-    if (p_cast_node != nullptr) {
-      Node& cast_node = *graph.GetNode(p_cast_node->Index());
+    const Node* p_cast2 = graph_utils::FirstParentByType(pow_node, "Cast");
+    if (p_cast2 != nullptr && p_cast2 != p_cast1) {
+      Node& cast_node = *graph.GetNode(p_cast2->Index());
       if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) ||
           cast_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
           !optimizer_utils::CheckOutputEdges(graph, cast_node, 1)) {
         continue;
       }
       nodes_to_remove.push_back(cast_node);
-    // Traceback from the last node in vector to find sub --> pow or sub --> cast
-    const Node* p_sub2_node = graph_utils::FirstParentByType(nodes_to_remove.back(), "Sub");
-    if (p_sub2_node != nullptr) {
-      // Cast is between Sub and Pow
-      if ((p_sub2_node != p_sub_node && p_sub2_node != p_sub_node_dup) || !IsSupportedDataType(cast_node)) {
-        continue;
-      }
+    } else if (!p_cast2) {
+      const Node* p_sub2_node = graph_utils::FirstParentByType(pow_node, "Sub");
+      if (!p_sub2_node || (p_sub2_node != p_sub_node && p_sub2_node != p_sub_node_dup)) {
+        continue;
+      }
     }
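The rewritten traceback distinguishes three cases for Pow's input: it is the Cast already matched after Sub (p_cast1, nothing further to verify), it is a different Cast sitting between Sub and Pow (validate opset/EP/edge count and mark it for removal), or there is no Cast at all and Pow must be fed by one of the matched Sub nodes. A condensed sketch of that decision (stand-in types and parameter names, not the ORT graph API):

struct Node;  // stand-in for onnxruntime::Node

// cast_before_pow = FirstParentByType(pow, "Cast"), cast_after_sub = p_cast1,
// sub_before_pow  = FirstParentByType(pow, "Sub"); cast_checks_pass stands
// for the opset/EP/output-edge validation done inline in the real code.
bool PowInputMatches(const Node* cast_before_pow, const Node* cast_after_sub,
                     const Node* sub_before_pow, const Node* sub_node,
                     const Node* sub_node_dup, bool cast_checks_pass) {
  if (cast_before_pow && cast_before_pow != cast_after_sub) {
    return cast_checks_pass;  // a separate Cast sits between Sub and Pow
  }
  if (!cast_before_pow) {
    // No Cast at all: Pow must consume one of the matched Sub nodes.
    return sub_before_pow &&
           (sub_before_pow == sub_node || sub_before_pow == sub_node_dup);
  }
  return true;  // Pow is fed by the already-validated p_cast1
}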

onnxruntime/test/optimizer/graph_transform_test.cc

@@ -4453,6 +4453,70 @@ TEST_F(GraphTransformationTests, LayerNormWithSubDupFusionTest) {
   }
 }
+
+TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    auto* data_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
+    auto* pow_initializer = builder.MakeInitializer<float>({}, {2.0f});
+    auto* add_initializer = builder.MakeInitializer<float>({}, {1e-5f});
+    auto* weight_initializer = builder.MakeInitializer<MLFloat16>({3}, std::vector<MLFloat16>(3, MLFloat16(1.0f)));
+    auto* bias_initializer = builder.MakeInitializer<MLFloat16>({3}, std::vector<MLFloat16>(3, MLFloat16(0.0f)));
+    auto* reduce_mean_out_1 = builder.MakeIntermediate();
+    auto* sub_out = builder.MakeIntermediate();
+    auto* cast_out_1 = builder.MakeIntermediate();
+    auto* pow_out = builder.MakeIntermediate();
+    auto* reduce_mean_out_2 = builder.MakeIntermediate();
+    auto* add_out_1 = builder.MakeIntermediate();
+    auto* sqrt_out = builder.MakeIntermediate();
+    auto* div_out = builder.MakeIntermediate();
+    auto* cast_out_2 = builder.MakeIntermediate();
+    auto* mul_out = builder.MakeIntermediate();
+    auto* add_out_2 = builder.MakeOutput();
+    builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1});
+    builder.AddNode("Sub", {data_arg, reduce_mean_out_1}, {sub_out});
+    builder.AddNode("Cast", {sub_out}, {cast_out_1})
+        .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
+    builder.AddNode("Pow", {cast_out_1, pow_initializer}, {pow_out});
+    builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1});
+    builder.AddNode("Add", {reduce_mean_out_2, add_initializer}, {add_out_1});
+    builder.AddNode("Sqrt", {add_out_1}, {sqrt_out});
+    builder.AddNode("Div", {cast_out_1, sqrt_out}, {div_out});
+    builder.AddNode("Cast", {div_out}, {cast_out_2})
+        .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
+    builder.AddNode("Mul", {cast_out_2, weight_initializer}, {mul_out});
+    builder.AddNode("Add", {mul_out, bias_initializer}, {add_out_2});
+  };
+
+  auto pre_graph_checker = [&](Graph& graph) {
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1);
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [&](Graph& graph) {
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == 1);
+    return Status::OK();
+  };
+
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
 TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_t5.onnx";
   std::shared_ptr<Model> p_model;
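If fusion succeeds, everything built in LayerNormWithCastFusionTest_5 above collapses into the single node the post-graph checker asserts. Built directly with ModelTestBuilder, the fused form would look roughly like this (a sketch, not code from the PR; the axis and epsilon values mirror the ReduceMean axes and Add initializer used in the test):

    // Hypothetical direct construction of the fused graph the test expects:
    // one LayerNormalization over the last axis with the test's epsilon.
    builder.AddNode("LayerNormalization",
                    {data_arg, weight_initializer, bias_initializer}, {add_out_2})
        .AddAttribute("axis", static_cast<int64_t>(-1))
        .AddAttribute("epsilon", 1e-5f);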