mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Match More EmbedLayerNormalization Patterns for Bert Model Graph Fusion (#4354)
match more embed patterns for bert base cased
This commit is contained in:
parent
755675541a
commit
37b624b688
4 changed files with 220 additions and 138 deletions
|
|
@ -103,24 +103,24 @@ static void AddNodes(std::vector<NodeIndex>& node_indices,
|
|||
It is because they are matched as part of other subgraph.
|
||||
*/
|
||||
|
||||
static bool MatchPositionSubgraph(
|
||||
static bool MatchInputToConcatSubgraph(
|
||||
Graph& graph,
|
||||
const Node& expand_node,
|
||||
const Node& cur_node,
|
||||
const NodeArg* input_ids,
|
||||
const int index,
|
||||
const logging::Logger& logger,
|
||||
std::vector<NodeIndex>& subgraph_node_indices,
|
||||
const NodeIndex expected_gather_node_1_index) {
|
||||
subgraph_node_indices.clear();
|
||||
|
||||
std::vector<graph_utils::EdgeEndToMatch> expand_parent_path1{
|
||||
{0, 1, "Concat", {4, 11}, kOnnxDomain},
|
||||
{0, index, "Concat", {4, 11}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain},
|
||||
};
|
||||
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (!graph_utils::FindPath(expand_node, true, expand_parent_path1, edges, logger)) {
|
||||
if (!graph_utils::FindPath(cur_node, true, expand_parent_path1, edges, logger)) {
|
||||
DEBUG_LOG("Failed to find path 1 of position shape.");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -184,60 +184,80 @@ static bool MatchPositionSubgraph(
|
|||
}
|
||||
|
||||
/** Match subgraph like the following:
|
||||
(input_ids)
|
||||
/ \
|
||||
Shape Shape
|
||||
| |
|
||||
^Gather (indice=0)^ Gather (indice=1)--+
|
||||
^|^ ^|^ |
|
||||
^Unsqueeze^ ^Unsqueeze^ Unsqueeze
|
||||
^\^ ^/^ |
|
||||
^\^ ^/^ ConstantOfShape
|
||||
^\^ ^/^ |
|
||||
^Concat^ NonZero
|
||||
| |
|
||||
| Transpose
|
||||
| |
|
||||
| Squeeze
|
||||
| |
|
||||
| Cast
|
||||
| |
|
||||
| Unsqueeze
|
||||
+--|----------------------------+
|
||||
| |
|
||||
Expand
|
||||
|
|
||||
Gather
|
||||
|
||||
Note that position gather node is the node in the bottom of above sub-graph.
|
||||
Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found.
|
||||
*/
|
||||
static bool MatchPositionEmbeddingSubgraph1(
|
||||
*
|
||||
* Shape -> ^Gather (indice=0)^ -> ^Unsqueeze^
|
||||
* / | +-----------------------+
|
||||
* / v | |
|
||||
* [input_ids] ^Concat^ -> *Reshape* -> *Equal* -> *Where* -> Expand -> Gather
|
||||
* \ | | ("position")
|
||||
* Shape -> ^Gather (indice=1)^ -> ^Unsqueeze^ |
|
||||
* | |
|
||||
* +-------------- # one of the below subgraph patterns # ---------------+
|
||||
* # Unsqueeze -> ConstantOfShape -> NonZero -> Transpose -> Squeeze -> (Cast) -> Unsqueeze #
|
||||
* # or #
|
||||
* # (Cast (to=7)) -> Range (start=0, delta=1) -> Unsqueeze #
|
||||
*
|
||||
* Note that position gather node is the node in the bottom of above sub-graph.
|
||||
* Paths in ^^ are alternative paths to be matched if path input_ids -> Shape -> Expand -> Gather is not found.
|
||||
* Path in ** is an alternative path to check.
|
||||
*/
|
||||
static bool MatchPositionEmbeddingSubgraphsFromGather(
|
||||
Graph& graph,
|
||||
const Node& position_gather_node,
|
||||
const NodeArg* input_ids,
|
||||
const logging::Logger& logger,
|
||||
std::vector<NodeIndex>& subgraph_node_indices) {
|
||||
subgraph_node_indices.clear();
|
||||
|
||||
std::vector<const Node::EdgeEnd*> pg_edges;
|
||||
// Look for Path 1:
|
||||
// Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand --> Gather
|
||||
if (!graph_utils::FindPath(position_gather_node, true,
|
||||
{{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Cast", {9}, kOnnxDomain},
|
||||
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Transpose", {1}, kOnnxDomain},
|
||||
{0, 0, "NonZero", {9}, kOnnxDomain},
|
||||
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}},
|
||||
pg_edges, logger)) {
|
||||
// Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze
|
||||
// --> Cast --> Unsqueeze --> Expand --> Gather
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path_1{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Cast", {9}, kOnnxDomain},
|
||||
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Transpose", {1}, kOnnxDomain},
|
||||
{0, 0, "NonZero", {9}, kOnnxDomain},
|
||||
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
// Look for Path 2 (Path 1 with no cast):
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path_2{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Transpose", {1}, kOnnxDomain},
|
||||
{0, 0, "NonZero", {9}, kOnnxDomain},
|
||||
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
// Path 3 Pattern:
|
||||
// Shape -> Gather -> Cast (to=7) -> Range (start=0, delta=1) -> Unsqueeze -> Expand
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path_3{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Range", {1, 11}, kOnnxDomain},
|
||||
{0, 1, "Cast", {9}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
// Path 4 pattern (Path 3 with no "Cast"):
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path_4{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Range", {1, 11}, kOnnxDomain},
|
||||
{0, 1, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
// Match one of the four path patterns.
|
||||
if (!graph_utils::FindPath(position_gather_node, true, parent_path_1, pg_edges, logger) &&
|
||||
!graph_utils::FindPath(position_gather_node, true, parent_path_2, pg_edges, logger) &&
|
||||
!graph_utils::FindPath(position_gather_node, true, parent_path_3, pg_edges, logger) &&
|
||||
!graph_utils::FindPath(position_gather_node, true, parent_path_4, pg_edges, logger)) {
|
||||
return false;
|
||||
}
|
||||
const size_t gather_index = 8;
|
||||
const size_t gather_index = pg_edges.size() - 2;
|
||||
// All nodes in the matched path must have only 1 output edge, except the Gather node, which is allowed 1 or 2 output edges
|
||||
for (size_t i = 0; i < pg_edges.size(); i++) {
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 1)) {
|
||||
|
|
@ -251,6 +271,18 @@ static bool MatchPositionEmbeddingSubgraph1(
|
|||
|
||||
Node& expand_node = *graph.GetNode(pg_edges[0]->GetNode().Index());
|
||||
Node& gather_node = *graph.GetNode(pg_edges[gather_index]->GetNode().Index());
|
||||
if (pg_edges[2]->GetNode().OpType() == "Range") {
|
||||
// Check if the values in "start" and "delta" attributes in Range are expected.
|
||||
Node& range_node = *graph.GetNode(pg_edges[2]->GetNode().Index());
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[0]), int64_t(0), true)) {
|
||||
DEBUG_LOG("The first input of Range should be a constant with value 0.");
|
||||
return false;
|
||||
}
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[2]), int64_t(1), true)) {
|
||||
DEBUG_LOG("The third input of Range should be a constant with value 1.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (gather_node.GetOutputEdgesCount() == 1) {
|
||||
// Check if the second input of the Gather node in the path has a constant input of 1
|
||||
|
|
@ -279,7 +311,35 @@ static bool MatchPositionEmbeddingSubgraph1(
|
|||
|
||||
subgraph_node_indices.push_back(shape_node_index);
|
||||
} else { // gather_output_edges_count == 2
|
||||
if (!MatchPositionSubgraph(graph, expand_node, input_ids, logger, subgraph_node_indices, gather_node.Index())) {
|
||||
// Match optional Reshape -> Equal -> Where -> Expand
|
||||
// | |
|
||||
// --------------------
|
||||
std::vector<const Node::EdgeEnd*> pg_edges_2;
|
||||
std::vector<graph_utils::EdgeEndToMatch> path_to_match_1{
|
||||
{0, 1, "Where", {9}, kOnnxDomain},
|
||||
{0, 0, "Equal", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Reshape", {5}, kOnnxDomain}};
|
||||
if (graph_utils::FindPath(expand_node, true, path_to_match_1, pg_edges_2, logger)) {
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[0]->GetNode(), 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[1]->GetNode(), 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[2]->GetNode(), 2)) {
|
||||
DEBUG_LOG("Optional position subgraph nodes number of outputs unexpected.");
|
||||
return false;
|
||||
}
|
||||
Node& where_node = *graph.GetNode(pg_edges_2[0]->GetNode().Index());
|
||||
Node& reshape_node = *graph.GetNode(pg_edges_2[2]->GetNode().Index());
|
||||
if (where_node.MutableInputDefs()[2] != reshape_node.MutableOutputDefs()[0]) {
|
||||
DEBUG_LOG("Optional position subgraph nodes Where node is expected to be the parent of Reshape.");
|
||||
return false;
|
||||
}
|
||||
// Match [input_ids] -> Shape -> Gather -> Unsqueeze -> Concat, searching upward from the Reshape node.
|
||||
if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, subgraph_node_indices, gather_node.Index())) {
|
||||
DEBUG_LOG("Failed to match position subgraph.");
|
||||
return false;
|
||||
}
|
||||
AddNodes(subgraph_node_indices, pg_edges_2);
|
||||
} else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, subgraph_node_indices, gather_node.Index())) {
|
||||
// Match [input_ids] -> Shape -> Gather -> Unsqueeze -> Concat, searching upward from the Expand node.
|
||||
DEBUG_LOG("Failed to match position subgraph.");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -290,90 +350,6 @@ static bool MatchPositionEmbeddingSubgraph1(
|
|||
return true;
|
||||
}
|
||||
|
||||
/** Match subgraph like the following:
|
||||
(input_ids)
|
||||
/ \
|
||||
Shape Shape
|
||||
| |
|
||||
Gather (indice=0) Gather (indice=1)--+
|
||||
| | |
|
||||
Unsqueeze Unsqueeze Cast(to=7) (Cast is optional)
|
||||
\ / |
|
||||
\ / Range(start=0, delta=1)
|
||||
\ / |
|
||||
Concat Unsqueeze
|
||||
| |
|
||||
+--|----------------------------+
|
||||
| |
|
||||
Expand
|
||||
|
|
||||
Gather
|
||||
|
||||
Note that position gather node is the node in the bottom of above sub-graph.
|
||||
*/
|
||||
|
||||
// Matches the Range-based position embedding pattern that feeds the position Gather node:
//   Gather (indice=1) -> [Cast] -> Range (start=0, delta=1) -> Unsqueeze -> Expand -> Gather
//
// graph:                  graph that owns the nodes being matched.
// position_gather_node:   the position Gather node; the search walks its input edges upward.
// input_ids:              forwarded to MatchPositionSubgraph for the shape-side match.
// logger:                 used by FindPath and DEBUG_LOG.
// subgraph_node_indices:  cleared on entry; on success it receives the matched node indices.
//
// Returns true only when the whole pattern is matched.
static bool MatchPositionEmbeddingSubgraph2(
    Graph& graph,
    const Node& position_gather_node,
    const NodeArg* input_ids,
    const logging::Logger& logger,
    std::vector<NodeIndex>& subgraph_node_indices) {
  subgraph_node_indices.clear();

  // Match Gather <-- Expand <-- Unsqueeze <-- Range <-- Cast <-- Gather
  // Since Range is from opset 11, we only match opset 11 here.
  std::vector<graph_utils::EdgeEndToMatch> position_embedding_path_symbolic{
      {0, 1, "Expand", {8}, kOnnxDomain},
      {0, 0, "Unsqueeze", {11}, kOnnxDomain},
      {0, 0, "Range", {11}, kOnnxDomain},
      {0, 1, "Cast", {9}, kOnnxDomain},
      {0, 0, "Gather", {11}, kOnnxDomain}};
  std::vector<const Node::EdgeEnd*> edges;

  if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) {
    // Cast node might be removed by other optimizer. Here we check a pattern without Cast node.
    std::vector<graph_utils::EdgeEndToMatch> position_embedding_path_no_cast{
        {0, 1, "Expand", {8}, kOnnxDomain},
        {0, 0, "Unsqueeze", {11}, kOnnxDomain},
        {0, 0, "Range", {11}, kOnnxDomain},
        {0, 1, "Gather", {11}, kOnnxDomain}};
    if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_no_cast, edges, logger)) {
      DEBUG_LOG("Failed to find path 1.");
      return false;
    }
  }

  // The last matched edge is the Gather node in either variant (with or without Cast).
  // It is the only node on the path expected to have two output edges; all others must have one.
  const size_t last_edge = edges.size() - 1;
  for (size_t i = 0; i < edges.size(); i++) {
    if (!optimizer_utils::CheckOutputEdges(graph, edges[i]->GetNode(), (i == last_edge ? 2u : 1u))) {
      DEBUG_LOG("Output edge count not expected for nodes in path 1.");
      return false;
    }
  }

  // In both path variants, index 0 is Expand and index 2 is Range.
  const Node& expand_node = *graph.GetNode(edges[0]->GetNode().Index());
  const Node& range_node = *graph.GetNode(edges[2]->GetNode().Index());
  const Node& gather_node_1 = *graph.GetNode(edges[last_edge]->GetNode().Index());

  // Range inputs 0 and 2 are checked against the constants 0 and 1 (start and delta per the comment above).
  if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[0]), int64_t(0), true)) {
    DEBUG_LOG("The first input of Range should be a constant with value 0.");
    return false;
  }
  if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[2]), int64_t(1), true)) {
    DEBUG_LOG("The third input of Range should be a constant with value 1.");
    return false;
  }

  // Match the shape-side subgraph feeding Expand; the Gather found above is passed as the expected Gather index.
  if (!MatchPositionSubgraph(graph, expand_node, input_ids, logger, subgraph_node_indices, gather_node_1.Index())) {
    DEBUG_LOG("Failed to match position subgraph.");
    return false;
  }

  AddNodes(subgraph_node_indices, edges);

  return true;
}
|
||||
|
||||
static bool MatchPositionEmbeddingSubgraph(
|
||||
Graph& graph,
|
||||
const Node& add_node,
|
||||
|
|
@ -422,10 +398,8 @@ static bool MatchPositionEmbeddingSubgraph(
|
|||
}
|
||||
}
|
||||
} else {
|
||||
if (!MatchPositionEmbeddingSubgraph1(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) {
|
||||
if (!MatchPositionEmbeddingSubgraph2(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) {
|
||||
return false;
|
||||
}
|
||||
if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2424,6 +2424,35 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) {
|
|||
}
|
||||
}
|
||||
|
||||
// Runs EmbedLayerNormFusion on the "format 6" test model and checks the resulting op counts:
// the embedding-related ops must be gone and exactly one EmbedLayerNormalization must remain.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
  auto model_path = MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr, *logger_));
  Graph& main_graph = model->MainGraph();

  // Apply only the fusion under test, registered at Level2.
  onnxruntime::GraphTransformerManager transformer_mgr{5};
  transformer_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
  auto status = transformer_mgr.ApplyTransformers(main_graph, TransformerLevel::Level2, *logger_);
  ASSERT_TRUE(status.IsOK());

  std::map<std::string, int> op_counts = CountOpsInGraph(main_graph);

  // Ops expected to be absent after fusion.
  EXPECT_EQ(op_counts["Shape"], 0);
  EXPECT_EQ(op_counts["Expand"], 0);
  EXPECT_EQ(op_counts["Gather"], 0);
  EXPECT_EQ(op_counts["Unsqueeze"], 0);
  EXPECT_EQ(op_counts["Reshape"], 0);
  EXPECT_EQ(op_counts["Equal"], 0);
  EXPECT_EQ(op_counts["Where"], 0);
  EXPECT_EQ(op_counts["LayerNormalization"], 0);
  EXPECT_EQ(op_counts["SkipLayerNormalization"], 0);
  EXPECT_EQ(op_counts["ReduceSum"], 0);

  // Ops expected to remain, with their exact counts.
  EXPECT_EQ(op_counts["MatMul"], 1);
  EXPECT_EQ(op_counts["Add"], 2);
  EXPECT_EQ(op_counts["Cast"], 3);
  EXPECT_EQ(op_counts["Attention"], 1);
  EXPECT_EQ(op_counts["EmbedLayerNormalization"], 1);
}
|
||||
|
||||
TEST_F(GraphTransformationTests, DynamicQuantizeMatMulTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/dynamic_quantize_matmul.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -143,7 +143,86 @@ def GenerateModel5(model_name):
|
|||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
def GenerateModel6(model_name):
    """Build the "format 6" embedding + LayerNormalization + Attention test model and save it.

    The position ids come from Range -> Unsqueeze -> Expand. The shape input of
    Expand is produced by Concat -> Reshape -> Equal -> Where, with Reshape's
    output also used as the third input of Where.

    Args:
        model_name: path passed to onnx.save for the generated model.
    """
    nodes = [  # LayerNorm subgraph
        helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
        helper.make_node("Gather", ["shape1_out", "indices_0"], ["gather0_out"], "gather0"),
        helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
        helper.make_node("Shape", ["input_ids"], ["shape2_out"], "shape2"),
        helper.make_node("Gather", ["shape2_out", "indices_1"], ["gather1_out"], "gather1"),
        helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
        helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out"], ["concat_out"], "concat", axis=0),

        # Reshape -> Equal -> Where chain between Concat and the shape input of Expand.
        helper.make_node("Reshape", ["concat_out", "reshape_init"], ["reshape_out"], "reshape"),
        helper.make_node("Equal", ["reshape_out", "equal_init"], ["equal_out"], "equal"),
        helper.make_node("Where", ["equal_out", "where_init", "reshape_out"], ["where_out"], "where"),

        # gather1_out is consumed by both unsqueeze1 (above) and this Range node.
        helper.make_node("Range", ["start_0", "gather1_out", "delta_1"], ["range_out"],
                         "range"),
        helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),

        helper.make_node("Expand", ["unsqueeze2_out", "where_out"], ["expand_out"], "expand"),
        helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather"),
        helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather"),
        helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["word_add_pos_out"], "word_add_pos"),
        helper.make_node("Gather", ["seg_embed", "segment_ids"], ["seg_gather_out"], "seg_gather"),
        helper.make_node("Add", ["word_add_pos_out", "seg_gather_out"], ["add3_out"], "add3"),
        # The keyword was previously misspelled "epsion", which created an attribute with
        # that name instead of the intended "epsilon" attribute.
        helper.make_node("LayerNormalization", ["add3_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"],
                         "layernorm",
                         axis=-1,
                         epsilon=0.000009999999747378752),
        helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
        helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
        helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
                         "att",
                         domain="com.microsoft",
                         num_heads=2),
        helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"),
        helper.make_node("Add", ["matmul_out", "add_bias"], ["add_out"], "add"),
        helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2")
    ]

    # hidden_size=4, num_heads=2, max_seq_length=3
    initializers = [  # initializers
        helper.make_tensor('indices_0', TensorProto.INT64, [], [0]),
        helper.make_tensor('indices_1', TensorProto.INT64, [], [1]),
        helper.make_tensor('start_0', TensorProto.INT64, [], [0]),
        helper.make_tensor('delta_1', TensorProto.INT64, [], [1]),
        helper.make_tensor('word_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
        helper.make_tensor('pos_embed', TensorProto.FLOAT, [4, 4],
                           [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
        helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
        helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]),
        helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
        helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4],
                           [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
        helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
        helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4],
                           [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
        helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
        helper.make_tensor('reshape_init', TensorProto.INT64, [1], [-1]),
        helper.make_tensor('equal_init', TensorProto.INT64, [2], [-1, -1]),
        helper.make_tensor('where_init', TensorProto.INT64, [2], [1, 1]),
    ]

    graph = helper.make_graph(
        nodes,
        "EmbedLayerNorm_format6",  #name
        [  # inputs
            helper.make_tensor_value_info('input_ids', TensorProto.INT64, ['batch', 3]),
            helper.make_tensor_value_info('segment_ids', TensorProto.INT64, ['batch', 3]),
            helper.make_tensor_value_info('input_mask', TensorProto.INT64, ['batch', 3]),
        ],
        [  # outputs
            helper.make_tensor_value_info('add2_out', TensorProto.FLOAT, ['batch', 3, 4]),
        ],
        initializers)

    model = helper.make_model(graph)
    onnx.save(model, model_name)
|
||||
|
||||
GenerateModel3('embed_layer_norm_format3.onnx', True)
|
||||
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
|
||||
GenerateModel5('embed_layer_norm_format5.onnx')
|
||||
GenerateModel6('embed_layer_norm_format6.onnx')
|
||||
Loading…
Reference in a new issue