mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Support EmbedLayerNorm fusion for DistilBert (#4928)
* checkin embedlayernorm fusion for distilbert * move function from optimizer_utils * review comments
This commit is contained in:
parent
00fe718264
commit
792ed44537
4 changed files with 475 additions and 228 deletions
|
|
@ -77,6 +77,16 @@ static void AddNodes(std::vector<NodeIndex>& node_indices,
|
|||
}
|
||||
}
|
||||
|
||||
static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Node::NodeConstIterator end, const std::vector<std::string>& expected_types) {
|
||||
for (const std::string& expected_type : expected_types) {
|
||||
if (start == end || (*start).OpType().compare(expected_type) != 0) {
|
||||
return false;
|
||||
}
|
||||
++start;
|
||||
}
|
||||
return start == end;
|
||||
}
|
||||
|
||||
/** Match subgraph like the following:
|
||||
(input_ids)
|
||||
/ \
|
||||
|
|
@ -184,7 +194,7 @@ static bool MatchInputToConcatSubgraph(
|
|||
}
|
||||
|
||||
/** Match subgraph like the following:
|
||||
*
|
||||
*
|
||||
* Shape -> ^Gather (indice=0)^ -> ^Unsqueeze^
|
||||
* / | +-----------------------+
|
||||
* / v | |
|
||||
|
|
@ -196,7 +206,7 @@ static bool MatchInputToConcatSubgraph(
|
|||
* # Unsqueeze -> ConstantOfShape -> NonZero -> Transpose -> Squeeze -> (Cast) -> Unsqueeze #
|
||||
* # or #
|
||||
* # (Cast (to=7)) -> Range (start=0, delta=1) -> Unsqueeze #
|
||||
*
|
||||
*
|
||||
* Note that position gather node is the node in the bottom of above sub-graph.
|
||||
* Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found.
|
||||
* Path in ** is an alternative path to check.
|
||||
|
|
@ -213,43 +223,43 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
|
|||
// Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze
|
||||
// --> Cast --> Unsqueeze --> Expand --> Gather
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path_1{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Cast", {9}, kOnnxDomain},
|
||||
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Transpose", {1}, kOnnxDomain},
|
||||
{0, 0, "NonZero", {9}, kOnnxDomain},
|
||||
{0, 1, "Expand", {8, 13}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Cast", {9, 13}, kOnnxDomain},
|
||||
{0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Transpose", {1, 13}, kOnnxDomain},
|
||||
{0, 0, "NonZero", {9, 13}, kOnnxDomain},
|
||||
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
|
||||
// Look for Path 2 (Path 1 with no cast):
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path_2{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Transpose", {1}, kOnnxDomain},
|
||||
{0, 0, "NonZero", {9}, kOnnxDomain},
|
||||
{0, 1, "Expand", {8, 13}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Transpose", {1, 13}, kOnnxDomain},
|
||||
{0, 0, "NonZero", {9, 13}, kOnnxDomain},
|
||||
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
|
||||
// Path 3 Pattern:
|
||||
// Shape -> Gather -> Cast (to=7) -> Range (start=0, delta=1) -> Unsqueeze -> Expand
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path_3{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 1, "Expand", {8, 13}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Range", {1, 11}, kOnnxDomain},
|
||||
{0, 1, "Cast", {9}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
{0, 1, "Cast", {9, 13}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
|
||||
// Path 4 pattern (Path 3 with no "Cast"):
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path_4{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 1, "Expand", {8, 13}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Range", {1, 11}, kOnnxDomain},
|
||||
{0, 1, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
{0, 1, "Gather", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
|
||||
// Match one of the three path patterns.
|
||||
if (!graph_utils::FindPath(position_gather_node, true, parent_path_1, pg_edges, logger) &&
|
||||
!graph_utils::FindPath(position_gather_node, true, parent_path_2, pg_edges, logger) &&
|
||||
|
|
@ -317,8 +327,8 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
|
|||
std::vector<const Node::EdgeEnd*> pg_edges_2;
|
||||
std::vector<graph_utils::EdgeEndToMatch> path_to_match_1{
|
||||
{0, 1, "Where", {9}, kOnnxDomain},
|
||||
{0, 0, "Equal", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Reshape", {5}, kOnnxDomain}};
|
||||
{0, 0, "Equal", {1, 7, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Reshape", {5, 13}, kOnnxDomain}};
|
||||
if (graph_utils::FindPath(expand_node, true, path_to_match_1, pg_edges_2, logger)) {
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[0]->GetNode(), 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[1]->GetNode(), 1) ||
|
||||
|
|
@ -361,7 +371,7 @@ static bool MatchPositionEmbeddingSubgraph(
|
|||
// Constant folding removes Shape and Expand nodes when input has static shape.
|
||||
// In that case just look for Gather --> Add.
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (!graph_utils::FindPath(add_node, true, {{0, 1, "Gather", {1, 11}, kOnnxDomain}}, edges, logger)) {
|
||||
if (!graph_utils::FindPath(add_node, true, {{0, 1, "Gather", {1, 11, 13}, kOnnxDomain}}, edges, logger)) {
|
||||
return false;
|
||||
}
|
||||
Node& position_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
|
|
@ -461,6 +471,339 @@ static NodeArg* ExtractEmbedding(Graph& graph,
|
|||
return &node_arg;
|
||||
}
|
||||
|
||||
static void CreateEmbedLayernormNode(Graph& graph,
|
||||
NodeArg* input_ids,
|
||||
NodeArg* segment_ids,
|
||||
NodeArg* word_embedding,
|
||||
NodeArg* position_embedding,
|
||||
NodeArg* segment_embedding,
|
||||
NodeArg* mask,
|
||||
Node& layer_norm_node,
|
||||
Node& reduce_sum_node) {
|
||||
// Cast input_ids, segment_ids, and mask to int32 if needed.
|
||||
input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType());
|
||||
if (segment_ids != nullptr && segment_embedding != nullptr) {
|
||||
segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType());
|
||||
}
|
||||
mask = CastToInt32(graph, mask, layer_norm_node.GetExecutionProviderType());
|
||||
|
||||
NodeArg place_holder("", nullptr);
|
||||
if (segment_ids == nullptr && segment_embedding == nullptr) {
|
||||
segment_ids = &place_holder;
|
||||
segment_embedding = &place_holder;
|
||||
}
|
||||
|
||||
const std::vector<NodeArg*> embed_layer_norm_input_defs{
|
||||
input_ids,
|
||||
segment_ids,
|
||||
word_embedding,
|
||||
position_embedding,
|
||||
segment_embedding,
|
||||
layer_norm_node.MutableInputDefs()[1],
|
||||
layer_norm_node.MutableInputDefs()[2],
|
||||
mask};
|
||||
|
||||
Node& embed_layer_norm_node = graph.AddNode(graph.GenerateNodeName("EmbedLayerNormalization"),
|
||||
"EmbedLayerNormalization",
|
||||
"fused EmbedLayerNorm subgraphs ",
|
||||
embed_layer_norm_input_defs,
|
||||
{layer_norm_node.MutableOutputDefs()[0], reduce_sum_node.MutableOutputDefs()[0]},
|
||||
{}, kMSDomain);
|
||||
|
||||
// Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value
|
||||
// will be used.
|
||||
NodeAttributes ln_attrs = layer_norm_node.GetAttributes();
|
||||
NodeAttributes::const_iterator epsilon = ln_attrs.find("epsilon");
|
||||
if (epsilon != ln_attrs.end()) {
|
||||
embed_layer_norm_node.AddAttribute("epsilon", epsilon->second);
|
||||
} else {
|
||||
embed_layer_norm_node.AddAttribute("epsilon", contrib::kDefaultEmbedLayerNormEpsilon);
|
||||
}
|
||||
|
||||
// Assign provider to this new node. Provider should be same as the provider for old node.
|
||||
embed_layer_norm_node.SetExecutionProviderType(layer_norm_node.GetExecutionProviderType());
|
||||
}
|
||||
|
||||
static bool FuseSubGraph(Graph& graph,
|
||||
Node& layer_norm_add_node,
|
||||
Node& layer_norm_node,
|
||||
Node& reduce_sum_node,
|
||||
bool& modified,
|
||||
const logging::Logger& logger) {
|
||||
// Trace back to find the Gather for segment embedding.
|
||||
std::vector<graph_utils::EdgeEndToMatch> segment_embedding_path{
|
||||
{0, 1, "Gather", {1, 11, 13}, kOnnxDomain}};
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) {
|
||||
return false;
|
||||
}
|
||||
Node& segment_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, segment_gather_node, 1)) {
|
||||
return false;
|
||||
}
|
||||
// The first input of segment_gather_node must be 2d.
|
||||
NodeArg* segment_embedding = segment_gather_node.MutableInputDefs()[0];
|
||||
auto sg_shape = segment_embedding->Shape();
|
||||
if (sg_shape == nullptr || sg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(sg_shape->dim()[1]) ||
|
||||
sg_shape->dim()[1].dim_value() <= 0) {
|
||||
return false;
|
||||
}
|
||||
auto hidden_size = sg_shape->dim()[1].dim_value();
|
||||
|
||||
// Trace back to find Gather --> Add --> LayerNormalization
|
||||
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
|
||||
{0, 0, "Add", {7, 13}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
|
||||
if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) {
|
||||
return false;
|
||||
}
|
||||
Node& add_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
Node& word_gather_node = *graph.GetNode(edges[1]->GetNode().Index());
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, word_gather_node, 1)) {
|
||||
return false;
|
||||
}
|
||||
// The first input of word_gather_node must be 2d.
|
||||
NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0];
|
||||
auto wg_shape = word_embedding->Shape();
|
||||
if (wg_shape == nullptr || wg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(wg_shape->dim()[1]) ||
|
||||
wg_shape->dim()[1].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Word embedding shape not expected.");
|
||||
return false;
|
||||
}
|
||||
|
||||
NodeArg* input_ids = word_gather_node.MutableInputDefs()[1];
|
||||
NodeArg* position_embedding = nullptr;
|
||||
std::vector<NodeIndex> nodes_to_remove;
|
||||
|
||||
// ORT constant folding might be applied to position embedding subgraph when input has static shape.
|
||||
// Here we handle such special case that the input of add node is constant initializer.
|
||||
auto add_input_name = add_node.MutableInputDefs()[1]->Name();
|
||||
if (graph_utils::IsConstantInitializer(graph, add_input_name)) {
|
||||
// Check that input has static shape.
|
||||
auto input_shape = input_ids->Shape();
|
||||
if (input_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(input_shape->dim()[0]) ||
|
||||
!utils::HasDimValue(input_shape->dim()[1])) {
|
||||
DEBUG_LOG("Input is expected to have dim value in all dimensions.");
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t batch_size = input_shape->dim()[0].dim_value();
|
||||
int64_t sequence_length = input_shape->dim()[1].dim_value();
|
||||
if (batch_size <= 0 || sequence_length <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* position_embed_tensor;
|
||||
if (!graph.GetInitializedTensor(add_input_name, position_embed_tensor)) {
|
||||
DEBUG_LOG("Failed to get initializer tensor.");
|
||||
return false;
|
||||
}
|
||||
// Tensor shape shall be [batch_size, sequence_length, hidden_size].
|
||||
if (position_embed_tensor->dims_size() != 3 ||
|
||||
position_embed_tensor->dims(0) != batch_size ||
|
||||
position_embed_tensor->dims(1) != sequence_length ||
|
||||
position_embed_tensor->dims(2) != hidden_size) {
|
||||
DEBUG_LOG("Position embedding shape not matched.");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Tensor data type should be float or float16.
|
||||
const auto data_type = position_embed_tensor->data_type();
|
||||
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
DEBUG_LOG("Position embedding data type shall be float or float16.");
|
||||
return false;
|
||||
}
|
||||
|
||||
// The tensor has same data for all batches, and we extract only one batch data as position embedding.
|
||||
position_embedding = ExtractEmbedding(graph, batch_size, sequence_length, hidden_size, position_embed_tensor, modified);
|
||||
} else {
|
||||
if (!MatchPositionEmbeddingSubgraph(graph, add_node, input_ids, logger, nodes_to_remove, position_embedding)) {
|
||||
DEBUG_LOG("Failed to match position embedding subgraph.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (position_embedding == nullptr) {
|
||||
DEBUG_LOG("Failed to get position embedding weights.");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto pg_shape = position_embedding->Shape();
|
||||
if (pg_shape == nullptr || pg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(pg_shape->dim()[1]) ||
|
||||
pg_shape->dim()[1].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Position embedding shape is not expected.");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get input "input_ids" from node.
|
||||
if (!CheckInput(input_ids, logger)) {
|
||||
DEBUG_LOG("Input id is not valid. ");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get input "segment_ids" from node.
|
||||
NodeArg* segment_ids = segment_gather_node.MutableInputDefs()[1];
|
||||
if (!CheckInput(segment_ids, logger)) {
|
||||
DEBUG_LOG("Segment id is not valid. ");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get input "mask" from "ReduceSum" node.
|
||||
NodeArg* mask = reduce_sum_node.MutableInputDefs()[0];
|
||||
if (!CheckInput(mask, logger)) {
|
||||
DEBUG_LOG("Mask is not valid. ");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
|
||||
utils::GetTensorShapeFromTensorShapeProto(*(segment_ids->Shape()))) {
|
||||
DEBUG_LOG("Input_ids and segment id should have the same shape. ");
|
||||
return false;
|
||||
}
|
||||
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
|
||||
utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) {
|
||||
DEBUG_LOG("Input_ids and mask should have the same shape. ");
|
||||
return false;
|
||||
}
|
||||
|
||||
NodeArg* gamma = layer_norm_node.MutableInputDefs()[1];
|
||||
NodeArg* beta = layer_norm_node.MutableInputDefs()[2];
|
||||
if (gamma->Shape() == nullptr || gamma->Shape()->dim()[0].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Gamma should be of shape (hidden_size). ");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (beta->Shape() == nullptr || beta->Shape()->dim()[0].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Beta should be of shape (hidden_size). ");
|
||||
return false;
|
||||
}
|
||||
|
||||
CreateEmbedLayernormNode(graph, input_ids, segment_ids, word_embedding, position_embedding, segment_embedding,
|
||||
mask, layer_norm_node, reduce_sum_node);
|
||||
|
||||
nodes_to_remove.push_back(word_gather_node.Index());
|
||||
nodes_to_remove.push_back(segment_gather_node.Index());
|
||||
nodes_to_remove.push_back(add_node.Index());
|
||||
nodes_to_remove.push_back(reduce_sum_node.Index());
|
||||
nodes_to_remove.push_back(layer_norm_add_node.Index());
|
||||
nodes_to_remove.push_back(layer_norm_node.Index());
|
||||
|
||||
for (const auto& index : nodes_to_remove) {
|
||||
Node* node = graph.GetNode(index);
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *node);
|
||||
graph.RemoveNode(node->Index());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// DistilBert's pattern does not have segment embedding
|
||||
static bool FuseSubGraphDistilBert(Graph& graph,
|
||||
Node& layer_norm_add_node,
|
||||
Node& layer_norm_node,
|
||||
Node& reduce_sum_node,
|
||||
const logging::Logger& logger) {
|
||||
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) {
|
||||
return false;
|
||||
}
|
||||
Node& word_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, word_gather_node, 1)) {
|
||||
return false;
|
||||
}
|
||||
// The first input of word_gather_node must be 2d.
|
||||
NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0];
|
||||
auto wg_shape = word_embedding->Shape();
|
||||
if (wg_shape == nullptr || wg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(wg_shape->dim()[1])) {
|
||||
DEBUG_LOG("Word embedding shape not expected.");
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t hidden_size = wg_shape->dim()[1].dim_value();
|
||||
|
||||
Node& add_node = layer_norm_add_node;
|
||||
|
||||
NodeArg* input_ids = word_gather_node.MutableInputDefs()[1];
|
||||
NodeArg* position_embedding = nullptr;
|
||||
std::vector<NodeIndex> nodes_to_remove;
|
||||
|
||||
// ORT constant folding might be applied to position embedding subgraph when input has static shape.
|
||||
// Here we handle such special case that the input of add node is constant initializer.
|
||||
auto add_input_name = add_node.MutableInputDefs()[1]->Name();
|
||||
if (!MatchPositionEmbeddingSubgraph(graph, add_node, input_ids, logger, nodes_to_remove, position_embedding)) {
|
||||
DEBUG_LOG("Failed to match position embedding subgraph.");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (position_embedding == nullptr) {
|
||||
DEBUG_LOG("Failed to get position embedding weights.");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto pg_shape = position_embedding->Shape();
|
||||
if (pg_shape == nullptr || pg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(pg_shape->dim()[1]) ||
|
||||
pg_shape->dim()[1].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Position embedding shape is not expected.");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get input "input_ids" from node.
|
||||
if (!CheckInput(input_ids, logger)) {
|
||||
DEBUG_LOG("Input id is not valid. ");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get input "mask" from "ReduceSum" node.
|
||||
NodeArg* mask = reduce_sum_node.MutableInputDefs()[0];
|
||||
if (!CheckInput(mask, logger)) {
|
||||
DEBUG_LOG("Mask is not valid. ");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
|
||||
utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) {
|
||||
DEBUG_LOG("Input_ids and mask should have the same shape. ");
|
||||
return false;
|
||||
}
|
||||
|
||||
NodeArg* gamma = layer_norm_node.MutableInputDefs()[1];
|
||||
NodeArg* beta = layer_norm_node.MutableInputDefs()[2];
|
||||
if (gamma->Shape() == nullptr || gamma->Shape()->dim()[0].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Gamma should be of shape (hidden_size). ");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (beta->Shape() == nullptr || beta->Shape()->dim()[0].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Beta should be of shape (hidden_size). ");
|
||||
return false;
|
||||
}
|
||||
|
||||
CreateEmbedLayernormNode(graph, input_ids, nullptr, word_embedding, position_embedding, nullptr,
|
||||
mask, layer_norm_node, reduce_sum_node);
|
||||
|
||||
nodes_to_remove.push_back(word_gather_node.Index());
|
||||
nodes_to_remove.push_back(add_node.Index());
|
||||
nodes_to_remove.push_back(reduce_sum_node.Index());
|
||||
nodes_to_remove.push_back(layer_norm_node.Index());
|
||||
|
||||
for (const auto& index : nodes_to_remove) {
|
||||
Node* node = graph.GetNode(index);
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *node);
|
||||
graph.RemoveNode(node->Index());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
/**
|
||||
Embed Layer Normalization will fuse embeddings and mask processing into one node :
|
||||
The embeddings before conversion:
|
||||
|
|
@ -512,208 +855,17 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
}
|
||||
Node& layer_norm_add_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
|
||||
// Trace back to find the Gather for segment embedding.
|
||||
std::vector<graph_utils::EdgeEndToMatch> segment_embedding_path{
|
||||
{0, 1, "Gather", {1, 11, 13}, kOnnxDomain}};
|
||||
if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) {
|
||||
continue;
|
||||
}
|
||||
Node& segment_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, segment_gather_node, 1)) {
|
||||
continue;
|
||||
}
|
||||
// The first input of segment_gather_node must be 2d.
|
||||
NodeArg* segment_embedding = segment_gather_node.MutableInputDefs()[0];
|
||||
auto sg_shape = segment_embedding->Shape();
|
||||
if (sg_shape == nullptr || sg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(sg_shape->dim()[1]) ||
|
||||
sg_shape->dim()[1].dim_value() <= 0) {
|
||||
continue;
|
||||
}
|
||||
auto hidden_size = sg_shape->dim()[1].dim_value();
|
||||
|
||||
// Trace back to find Gather --> Add --> LayerNormalization
|
||||
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
|
||||
{0, 0, "Add", {7, 13}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
|
||||
if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) {
|
||||
continue;
|
||||
}
|
||||
Node& add_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
Node& word_gather_node = *graph.GetNode(edges[1]->GetNode().Index());
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, word_gather_node, 1)) {
|
||||
continue;
|
||||
}
|
||||
// The first input of word_gather_node must be 2d.
|
||||
NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0];
|
||||
auto wg_shape = word_embedding->Shape();
|
||||
if (wg_shape == nullptr || wg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(wg_shape->dim()[1]) ||
|
||||
wg_shape->dim()[1].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Word embedding shape not expected.");
|
||||
continue;
|
||||
}
|
||||
|
||||
NodeArg* input_ids = word_gather_node.MutableInputDefs()[1];
|
||||
NodeArg* position_embedding = nullptr;
|
||||
std::vector<NodeIndex> nodes_to_remove;
|
||||
|
||||
// ORT constant folding might be applied to position embedding subgraph when input has static shape.
|
||||
// Here we handle such special case that the input of add node is constant initializer.
|
||||
auto add_input_name = add_node.MutableInputDefs()[1]->Name();
|
||||
if (graph_utils::IsConstantInitializer(graph, add_input_name)) {
|
||||
// Check that input has static shape.
|
||||
auto input_shape = input_ids->Shape();
|
||||
if (input_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(input_shape->dim()[0]) ||
|
||||
!utils::HasDimValue(input_shape->dim()[1])) {
|
||||
DEBUG_LOG("Input is expected to have dim value in all dimensions.");
|
||||
continue;
|
||||
if (IsNeighborNodeExpectedTypes(layer_norm_add_node.InputEdgesBegin(), layer_norm_add_node.InputNodesEnd(), {"Gather", "Gather"})) {
|
||||
//DistilBert
|
||||
if (FuseSubGraphDistilBert(graph, layer_norm_add_node, layer_norm_node, reduce_sum_node, logger)) {
|
||||
modified = true;
|
||||
}
|
||||
|
||||
int64_t batch_size = input_shape->dim()[0].dim_value();
|
||||
int64_t sequence_length = input_shape->dim()[1].dim_value();
|
||||
if (batch_size <= 0 || sequence_length <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* position_embed_tensor;
|
||||
if (!graph.GetInitializedTensor(add_input_name, position_embed_tensor)) {
|
||||
DEBUG_LOG("Failed to get initializer tensor.");
|
||||
continue;
|
||||
}
|
||||
// Tensor shape shall be [batch_size, sequence_length, hidden_size].
|
||||
if (position_embed_tensor->dims_size() != 3 ||
|
||||
position_embed_tensor->dims(0) != batch_size ||
|
||||
position_embed_tensor->dims(1) != sequence_length ||
|
||||
position_embed_tensor->dims(2) != hidden_size) {
|
||||
DEBUG_LOG("Position embedding shape not matched.");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Tensor data type should be float or float16.
|
||||
const auto data_type = position_embed_tensor->data_type();
|
||||
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
DEBUG_LOG("Position embedding data type shall be float or float16.");
|
||||
continue;
|
||||
}
|
||||
|
||||
// The tensor has same data for all batches, and we extract only one batch data as position embedding.
|
||||
position_embedding = ExtractEmbedding(graph, batch_size, sequence_length, hidden_size, position_embed_tensor, modified);
|
||||
} else {
|
||||
if (!MatchPositionEmbeddingSubgraph(graph, add_node, input_ids, logger, nodes_to_remove, position_embedding)) {
|
||||
DEBUG_LOG("Failed to match position embedding subgraph.");
|
||||
continue;
|
||||
if (FuseSubGraph(graph, layer_norm_add_node, layer_norm_node, reduce_sum_node, modified, logger)) {
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (position_embedding == nullptr) {
|
||||
DEBUG_LOG("Failed to get position embedding weights.");
|
||||
continue;
|
||||
}
|
||||
|
||||
auto pg_shape = position_embedding->Shape();
|
||||
if (pg_shape == nullptr || pg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(pg_shape->dim()[1]) ||
|
||||
pg_shape->dim()[1].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Position embedding shape is not expected.");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get input "input_ids" from node.
|
||||
if (!CheckInput(input_ids, logger)) {
|
||||
DEBUG_LOG("Input id is not valid. ");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get input "segment_ids" from node.
|
||||
NodeArg* segment_ids = segment_gather_node.MutableInputDefs()[1];
|
||||
if (!CheckInput(segment_ids, logger)) {
|
||||
DEBUG_LOG("Segment id is not valid. ");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get input "mask" from "ReduceSum" node.
|
||||
NodeArg* mask = reduce_sum_node.MutableInputDefs()[0];
|
||||
if (!CheckInput(mask, logger)) {
|
||||
DEBUG_LOG("Mask is not valid. ");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
|
||||
utils::GetTensorShapeFromTensorShapeProto(*(segment_ids->Shape()))) {
|
||||
DEBUG_LOG("Input_ids and segment id should have the same shape. ");
|
||||
continue;
|
||||
}
|
||||
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
|
||||
utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) {
|
||||
DEBUG_LOG("Input_ids and mask should have the same shape. ");
|
||||
continue;
|
||||
}
|
||||
|
||||
NodeArg* gamma = layer_norm_node.MutableInputDefs()[1];
|
||||
NodeArg* beta = layer_norm_node.MutableInputDefs()[2];
|
||||
if (gamma->Shape() == nullptr || gamma->Shape()->dim()[0].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Gamma should be of shape (hidden_size). ");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (beta->Shape() == nullptr || beta->Shape()->dim()[0].dim_value() != hidden_size) {
|
||||
DEBUG_LOG("Beta should be of shape (hidden_size). ");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Cast input_ids, segment_ids, and mask to int32 if needed.
|
||||
input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType());
|
||||
segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType());
|
||||
mask = CastToInt32(graph, mask, layer_norm_node.GetExecutionProviderType());
|
||||
|
||||
const std::vector<NodeArg*> embed_layer_norm_input_defs{
|
||||
input_ids,
|
||||
segment_ids,
|
||||
word_embedding,
|
||||
position_embedding,
|
||||
segment_embedding,
|
||||
layer_norm_node.MutableInputDefs()[1],
|
||||
layer_norm_node.MutableInputDefs()[2],
|
||||
mask};
|
||||
Node& embed_layer_norm_node = graph.AddNode(graph.GenerateNodeName("EmbedLayerNormalization"),
|
||||
"EmbedLayerNormalization",
|
||||
"fused EmbedLayerNorm subgraphs ",
|
||||
embed_layer_norm_input_defs,
|
||||
{layer_norm_node.MutableOutputDefs()[0], reduce_sum_node.MutableOutputDefs()[0]},
|
||||
{}, kMSDomain);
|
||||
|
||||
// Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value
|
||||
// will be used.
|
||||
NodeAttributes ln_attrs = layer_norm_node.GetAttributes();
|
||||
NodeAttributes::const_iterator epsilon = ln_attrs.find("epsilon");
|
||||
if (epsilon != ln_attrs.end()) {
|
||||
embed_layer_norm_node.AddAttribute("epsilon", epsilon->second);
|
||||
} else {
|
||||
embed_layer_norm_node.AddAttribute("epsilon", contrib::kDefaultEmbedLayerNormEpsilon);
|
||||
}
|
||||
|
||||
// Assign provider to this new node. Provider should be same as the provider for old node.
|
||||
embed_layer_norm_node.SetExecutionProviderType(layer_norm_node.GetExecutionProviderType());
|
||||
|
||||
nodes_to_remove.push_back(word_gather_node.Index());
|
||||
nodes_to_remove.push_back(segment_gather_node.Index());
|
||||
nodes_to_remove.push_back(add_node.Index());
|
||||
nodes_to_remove.push_back(reduce_sum_node.Index());
|
||||
nodes_to_remove.push_back(layer_norm_add_node.Index());
|
||||
nodes_to_remove.push_back(layer_norm_node.Index());
|
||||
|
||||
for (const auto& index : nodes_to_remove) {
|
||||
Node* node = graph.GetNode(index);
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *node);
|
||||
graph.RemoveNode(node->Index());
|
||||
}
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2611,6 +2611,28 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
|
|||
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
|
||||
}
|
||||
|
||||
//DistilBert
|
||||
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
|
||||
auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
|
||||
ASSERT_TRUE(ret1.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
|
||||
EXPECT_EQ(op_to_count["Attention"], 1);
|
||||
EXPECT_EQ(op_to_count["Cast"], 2);
|
||||
EXPECT_EQ(op_to_count["Shape"], 0);
|
||||
EXPECT_EQ(op_to_count["Gather"], 0);
|
||||
EXPECT_EQ(op_to_count["Unsqueeze"], 0);
|
||||
EXPECT_EQ(op_to_count["ReduceSum"], 0);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -278,9 +278,82 @@ def GenerateModel6(model_name):
|
|||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
def GenerateModel7(model_name):
|
||||
batch_size = 2
|
||||
hidden_size = 4
|
||||
attention_heads = 2
|
||||
sequence_length = 3
|
||||
|
||||
nodes = [
|
||||
helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0),
|
||||
|
||||
helper.make_node("Shape", ["input_ids"], ["shape0_out"], "shape0"),
|
||||
helper.make_node("Gather", ["shape0_out", "indices_1"], ["gather0_out"], "gather0"),
|
||||
helper.make_node("Range", ["start", "gather0_out", "delta"], ["range0_out"], "range0"),
|
||||
helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
|
||||
helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
|
||||
helper.make_node("Expand", ["unsqueeze0_out", "shape1_out"], ["expand_out"], "expand"),
|
||||
helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather", axis=0),
|
||||
|
||||
helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["add1_out"], "add1"),
|
||||
helper.make_node("LayerNormalization", ["add1_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"],
|
||||
"layernorm",
|
||||
axis=-1,
|
||||
epsion=0.000009999999747378752),
|
||||
|
||||
helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
|
||||
|
||||
helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
|
||||
helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
|
||||
"att",
|
||||
domain="com.microsoft",
|
||||
num_heads=attention_heads),
|
||||
helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"),
|
||||
helper.make_node("Add", ["matmul_out", "add_bias"], ["add2_out"], "add2"),
|
||||
helper.make_node("Add", ["add2_out", "layernorm_out"], ["add3_out"], "add3")
|
||||
]
|
||||
|
||||
qkv_weights = [
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
|
||||
3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
|
||||
]
|
||||
|
||||
initializers = [ # initializers
|
||||
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('pos_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('indices_0', TensorProto.INT64, [], [0]),
|
||||
helper.make_tensor('indices_1', TensorProto.INT64, [], [1]),
|
||||
helper.make_tensor('start', TensorProto.INT64, [], [0]),
|
||||
helper.make_tensor('delta', TensorProto.INT64, [], [1]),
|
||||
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights),
|
||||
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size],
|
||||
[0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size],
|
||||
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"EmbedLayerNorm_format7", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]),
|
||||
helper.make_tensor_value_info('input_mask', TensorProto.INT64, [batch_size, sequence_length]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]),
|
||||
],
|
||||
initializers)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
GenerateModel3('embed_layer_norm_format3.onnx', True)
|
||||
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
|
||||
GenerateModel5('embed_layer_norm_format5.onnx')
|
||||
GenerateModel6('embed_layer_norm_format6.onnx')
|
||||
GenerateModel7('embed_layer_norm_format7.onnx') #distilbert
|
||||
GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')
|
||||
|
|
|
|||
Loading…
Reference in a new issue