mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
New Pattern Support for LayerNormFusion (#14118)
Latest torch exporter changed the LayerNorm exporting code to add two more Cast nodes (to make it logically correct in compute), but our current LayerNormFusion doesn't support the new pattern. The PR is to add support of this.
This commit is contained in:
parent
f864b54393
commit
15c1157ef2
2 changed files with 91 additions and 17 deletions
|
|
@ -52,9 +52,9 @@ due to restriction in older opsets. Therefore, Layer Normalization will also han
|
|||
| |
|
||||
| v
|
||||
X --> ReduceMean --> Sub --> Cast --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
|
||||
| ^
|
||||
| |
|
||||
+--------------------------------------------------------+
|
||||
| ^
|
||||
| |
|
||||
+------------------------------------------------+
|
||||
+---------------------+ Cast
|
||||
| | |
|
||||
| v v
|
||||
|
|
@ -134,7 +134,6 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
Node& sub_node = *graph.GetNode(p_sub_node->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13, 14}) ||
|
||||
sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) ||
|
||||
!IsSupportedDataType(sub_node)) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -184,9 +183,24 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
// Apex O2 pattern specific match ends...
|
||||
|
||||
// Find the "Div" node after "Sub".
|
||||
// Find the "Div" node after "Sub". It's possible that there is "Cast" node after "Sub" node.
|
||||
const Node* p_cast1 = nullptr;
|
||||
if (!p_sub_node_dup && sub_node.GetOutputEdgesCount() == 1) {
|
||||
Node& cast_node = *graph.GetNode(sub_node.OutputNodesBegin()->Index());
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) &&
|
||||
cast_node.GetExecutionProviderType() == reduce_mean_node.GetExecutionProviderType() &&
|
||||
optimizer_utils::CheckOutputEdges(graph, cast_node, 2u) && IsSupportedDataType(cast_node)) {
|
||||
p_cast1 = &cast_node;
|
||||
nodes_to_remove.push_back(cast_node);
|
||||
}
|
||||
}
|
||||
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 && !p_cast1 ? 2u : 1u)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const Node* p_div = nullptr;
|
||||
p_div = graph_utils::FirstChildByType(sub_node, "Div");
|
||||
p_div = graph_utils::FirstChildByType(p_cast1 ? *p_cast1 : sub_node, "Div");
|
||||
|
||||
// Find the sub_dup node if exist
|
||||
if (p_sub_node_dup != nullptr) {
|
||||
|
|
@ -269,23 +283,19 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
nodes_to_remove.push_back(pow_node);
|
||||
|
||||
// check if Cast node exists: either between sub and pow, or as second input to pow
|
||||
const Node* p_cast_node = graph_utils::FirstParentByType(pow_node, "Cast");
|
||||
if (p_cast_node != nullptr) {
|
||||
Node& cast_node = *graph.GetNode(p_cast_node->Index());
|
||||
const Node* p_cast2 = graph_utils::FirstParentByType(pow_node, "Cast");
|
||||
if (p_cast2 != nullptr && p_cast2 != p_cast1) {
|
||||
Node& cast_node = *graph.GetNode(p_cast2->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) ||
|
||||
cast_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, cast_node, 1)) {
|
||||
continue;
|
||||
}
|
||||
nodes_to_remove.push_back(cast_node);
|
||||
|
||||
// Traceback from the last node in vector to find sub --> pow or sub --> cast
|
||||
const Node* p_sub2_node = graph_utils::FirstParentByType(nodes_to_remove.back(), "Sub");
|
||||
if (p_sub2_node != nullptr) {
|
||||
// Cast is between Sub and Pow
|
||||
if ((p_sub2_node != p_sub_node && p_sub2_node != p_sub_node_dup) || !IsSupportedDataType(cast_node)) {
|
||||
continue;
|
||||
}
|
||||
} else if (!p_cast2) {
|
||||
const Node* p_sub2_node = graph_utils::FirstParentByType(pow_node, "Sub");
|
||||
if (!p_sub2_node || (p_sub2_node != p_sub_node && p_sub2_node != p_sub_node_dup)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4453,6 +4453,70 @@ TEST_F(GraphTransformationTests, LayerNormWithSubDupFusionTest) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* data_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
|
||||
auto* pow_initializer = builder.MakeInitializer<float>({}, {2.0f});
|
||||
auto* add_initializer = builder.MakeInitializer<float>({}, {1e-5f});
|
||||
auto* weight_initializer = builder.MakeInitializer<MLFloat16>({3}, std::vector<MLFloat16>(3, MLFloat16(1.0f)));
|
||||
auto* bias_initializer = builder.MakeInitializer<MLFloat16>({3}, std::vector<MLFloat16>(3, MLFloat16(0.0f)));
|
||||
auto* reduce_mean_out_1 = builder.MakeIntermediate();
|
||||
auto* sub_out = builder.MakeIntermediate();
|
||||
auto* cast_out_1 = builder.MakeIntermediate();
|
||||
auto* pow_out = builder.MakeIntermediate();
|
||||
auto* reduce_mean_out_2 = builder.MakeIntermediate();
|
||||
auto* add_out_1 = builder.MakeIntermediate();
|
||||
auto* sqrt_out = builder.MakeIntermediate();
|
||||
auto* div_out = builder.MakeIntermediate();
|
||||
auto* cast_out_2 = builder.MakeIntermediate();
|
||||
auto* mul_out = builder.MakeIntermediate();
|
||||
auto* add_out_2 = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1});
|
||||
builder.AddNode("Sub", {data_arg, reduce_mean_out_1}, {sub_out});
|
||||
builder.AddNode("Cast", {sub_out}, {cast_out_1})
|
||||
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
|
||||
builder.AddNode("Pow", {cast_out_1, pow_initializer}, {pow_out});
|
||||
builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1});
|
||||
builder.AddNode("Add", {reduce_mean_out_2, add_initializer}, {add_out_1});
|
||||
builder.AddNode("Sqrt", {add_out_1}, {sqrt_out});
|
||||
builder.AddNode("Div", {cast_out_1, sqrt_out}, {div_out});
|
||||
builder.AddNode("Cast", {div_out}, {cast_out_2})
|
||||
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
builder.AddNode("Mul", {cast_out_2, weight_initializer}, {mul_out});
|
||||
builder.AddNode("Add", {mul_out, bias_initializer}, {add_out_2});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
|
||||
1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_t5.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
Loading…
Reference in a new issue