mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Attention with past and no unidirectional mask (#5557)
* Update fusions to support shared node, and mask of all ones
This commit is contained in:
parent
0a9b83a313
commit
1f304fbee7
7 changed files with 208 additions and 38 deletions
|
|
@ -98,10 +98,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
|
||||
int past_sequence_length = 0;
|
||||
if (past != nullptr) { // past is optional
|
||||
if (!is_unidirectional_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is only allowed for unidirectional");
|
||||
}
|
||||
|
||||
const auto& past_dims = past->Shape().GetDims();
|
||||
if (past_dims.size() != 5) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is expected to have 5 dimension, got ",
|
||||
|
|
|
|||
|
|
@ -208,7 +208,7 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
Node& node = *p_node;
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if ((node.GetOutputEdgesCount() >= 4 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert
|
||||
if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1}, kOnnxDomain) &&
|
||||
graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
|
||||
// Get hidden size from layer norm bias tensor shape.
|
||||
|
|
@ -236,13 +236,14 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
reshape_count++;
|
||||
}
|
||||
}
|
||||
|
||||
if (add_count == 1 && matmul_count == 3 && shape_count == node.GetOutputEdgesCount() - 4) { // BERT or DistilBert
|
||||
if (AttentionFusion::FuseSubGraph(node, *add_node, graph, hidden_size, mask_int32_map, logger)) {
|
||||
fused_count++;
|
||||
modified = true;
|
||||
}
|
||||
} else if (reshape_count == 1 && shape_count == 3) { // GPT
|
||||
if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, logger)) {
|
||||
} else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && (reshape_count + shape_count) == node.GetOutputEdgesCount()) { // GPT
|
||||
if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, logger)) {
|
||||
fused_count++;
|
||||
modified = true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include "onnx/defs/shape_inference.h"
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
|
||||
#pragma once
|
||||
|
||||
|
|
@ -57,11 +59,14 @@ bool CheckSliceParameters(const Graph& graph, const Node& slice, const std::vect
|
|||
+----> Shape --> Gather (indices=0) --> Unsqueeze (axes=0) -----------+ |
|
||||
| |
|
||||
+----> Shape --> Gather (indices=1) --> Unsqueeze (axes=0) --------------+
|
||||
|
||||
The 3 Shape nodes are merged into one node if use_shared_node is true.
|
||||
*/
|
||||
bool MatchGemmSubgraph(Graph& graph,
|
||||
Node& node_after_gemm_reshape,
|
||||
int dst_arg_index,
|
||||
MatchGemmResult& result,
|
||||
bool use_shared_node,
|
||||
const logging::Logger& logger) {
|
||||
DEBUG_LOG("Start MatchGemmSubgraph");
|
||||
// GPT Attention fusion supports opset version 9 or later.
|
||||
|
|
@ -96,7 +101,7 @@ bool MatchGemmSubgraph(Graph& graph,
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, shape_before_slice, 1) ||
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, shape_before_slice, use_shared_node ? 3 : 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, slice, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, squeeze, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, unsqueeze, 1) ||
|
||||
|
|
@ -173,14 +178,21 @@ bool MatchGemmSubgraph(Graph& graph,
|
|||
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze_after_gather, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, gather, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, shape, 1)) { //TODO: deal with shared Shape node which has output edges > 1
|
||||
!optimizer_utils::CheckOutputEdges(graph, shape, 1) && !use_shared_node) {
|
||||
DEBUG_LOG("Output edge count not expected for nodes in gemm gather path");
|
||||
return false;
|
||||
}
|
||||
|
||||
result.node_indices.push_back(unsqueeze_after_gather.Index());
|
||||
result.node_indices.push_back(gather.Index());
|
||||
result.node_indices.push_back(shape.Index());
|
||||
|
||||
if (use_shared_node) {
|
||||
if (shape.Index() != shape_before_slice.Index()) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
result.node_indices.push_back(shape.Index());
|
||||
}
|
||||
|
||||
if (shape.InputDefs()[0]->Name() != subgraph_input->Name()) {
|
||||
return false;
|
||||
|
|
@ -252,9 +264,83 @@ bool ValidateGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidde
|
|||
|
||||
struct MatchUnidirMaskResult {
|
||||
const Node* div_node; // the root node (Div) of the subgraph
|
||||
bool is_unidirectional; // whether the mask is unidirectional.
|
||||
std::vector<NodeIndex> node_indices; // id of all nodes in the subgraph for removing later.
|
||||
};
|
||||
|
||||
// Checks a square mask stored row-major with logical shape 1x1xWxW.
//
// Returns true when the mask is unidirectional (lower triangular: 1 where
// row >= column, 0 elsewhere) or when every element is 1.
//
// mask_data         flattened mask values; must hold exactly w * w elements.
// w                 width W of the square mask; a negative value is rejected.
// is_undirectional  set to true only when the mask is lower triangular, and to
//                   false in every other case (including when the function
//                   returns false), so the caller never reads a stale value.
template <class T>
bool ValidateUnidirMask(const std::vector<T>& mask_data, int64_t w, bool& is_undirectional) {
  is_undirectional = false;

  if (w < 0) {
    return false;
  }

  // Compare the element count with w * w without computing the product, which
  // could overflow for a large w.
  const size_t width = static_cast<size_t>(w);
  const size_t count = mask_data.size();
  const bool size_matches = (width == 0) ? (count == 0)
                                         : (count % width == 0 && count / width == width);
  if (!size_matches) {
    return false;
  }

  const T one = static_cast<T>(1);
  const T zero = static_cast<T>(0);

  bool all_ones = true;
  bool lower_triangular = true;

  size_t offset = 0;
  for (size_t row = 0; row < width; ++row) {
    for (size_t col = 0; col < width; ++col, ++offset) {
      const T& value = mask_data[offset];
      all_ones = all_ones && (value == one);
      lower_triangular = lower_triangular && (value == ((row >= col) ? one : zero));
    }

    // Neither pattern can be recovered once both have been ruled out.
    if (!all_ones && !lower_triangular) {
      return false;
    }
  }

  is_undirectional = lower_triangular;
  return lower_triangular || all_ones;
}
|
||||
|
||||
// Validates the mask input of the unidirectional-mask subgraph.
//
// The mask must be a constant initializer with a fully known 1x1xWxW shape,
// stored inline (not as external data), with element type uint8 or float.
// Its values are then checked by the templated ValidateUnidirMask overload.
//
// Returns true when all checks pass; is_unidirectional then carries the
// result reported by the templated overload. Returns false otherwise.
//
// NOTE(review): uint8 masks are read through ParseData<int32_t>; confirm this
// yields one value per element for raw_data-backed uint8 tensors.
bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidirectional, const logging::Logger& logger) {
  const bool is_constant = graph_utils::IsInitializer(graph, mask.Name(), true);
  if (!is_constant) {
    DEBUG_LOG("unidir mask is not constant");
    return false;
  }

  // Check that the mask shape is 1x1xWxW
  const auto* mask_shape = mask.Shape();
  const bool shape_ok =
      mask_shape != nullptr &&
      static_cast<size_t>(mask_shape->dim_size()) == 4 &&
      utils::HasDimValue(mask_shape->dim(0)) &&
      mask_shape->dim(0).dim_value() == static_cast<int64_t>(1) &&
      utils::HasDimValue(mask_shape->dim(1)) &&
      mask_shape->dim(1).dim_value() == static_cast<int64_t>(1) &&
      utils::HasDimValue(mask_shape->dim(2)) &&
      utils::HasDimValue(mask_shape->dim(3)) &&
      mask_shape->dim(2).dim_value() == mask_shape->dim(3).dim_value();
  if (!shape_ok) {
    DEBUG_LOG("unidir mask shape not expected");
    return false;
  }

  const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
  const bool has_tensor = graph.GetInitializedTensor(mask.Name(), initializer);
  if (!has_tensor || initializer == nullptr) {
    return false;
  }

  if (initializer->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
    DEBUG_LOG("This optimizer does not support external data for unidirectional mask right now");
    return false;
  }

  // Both supported element types share the same value check and failure message.
  const auto element_type = initializer->data_type();
  bool values_ok = false;
  if (element_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
    std::vector<int32_t> values = ONNX_NAMESPACE::ParseData<int32_t>(initializer);
    values_ok = ValidateUnidirMask(values, mask_shape->dim(2).dim_value(), is_unidirectional);
  } else if (element_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
    std::vector<float_t> values = ONNX_NAMESPACE::ParseData<float_t>(initializer);
    values_ok = ValidateUnidirMask(values, mask_shape->dim(2).dim_value(), is_unidirectional);
  } else {
    DEBUG_LOG("Expect mask data type is uint8 or float");
    return false;
  }

  if (!values_ok) {
    DEBUG_LOG("Mask is neither unidirectional nor all ones");
    return false;
  }

  return true;
}
|
||||
|
||||
/** Match Unidirectional Mask subgraph.
|
||||
In the below graph, ':' is followed by variable name in code. * means the input on the left side.
|
||||
|
||||
|
|
@ -274,8 +360,10 @@ struct MatchUnidirMaskResult {
|
|||
+----> Shape --> Slice ---------> Squeeze-------+ |
|
||||
| :shape2 :slice2 :squeeze2 v condition
|
||||
+----------------------------------------------------------------------------------------->Where( ,*,-10000)--->[Add]
|
||||
|
||||
When use_shared_node is true, shape1 and shape2 are the same node, and unsqueeze2 and unsqueeze3 are also the same node.
|
||||
*/
|
||||
bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnidirMaskResult& result, const logging::Logger& logger) {
|
||||
bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnidirMaskResult& result, bool use_shared_node, const logging::Logger& logger) {
|
||||
DEBUG_LOG("Start MatchUnidirMaskSubgraph");
|
||||
std::vector<graph_utils::EdgeEndToMatch> root_path{
|
||||
{0, 0, "Where", {9}, kOnnxDomain},
|
||||
|
|
@ -325,10 +413,9 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
|
|||
!optimizer_utils::CheckOutputEdges(graph, mask_slice, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, unsqueeze1, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, sub, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, squeeze1, 3) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, squeeze1, use_shared_node ? 2 : 3) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, slice1, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, shape1, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, mask_slice, 1)) {
|
||||
!optimizer_utils::CheckOutputEdges(graph, shape1, use_shared_node ? 2 : 1)) {
|
||||
DEBUG_LOG("Output edge count not expected for nodes in path 1 of unidirectional mask");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -348,6 +435,11 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!ValidateUnidirMask(graph, *(mask_slice.InputDefs()[0]), result.is_unidirectional, logger)) {
|
||||
DEBUG_LOG("ValidateUnidirMask returns false for mask_slice");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!CheckSliceParameters(graph, slice1, {1, 2, 3}, {-1, INT_MAX, 0}, logger)) {
|
||||
DEBUG_LOG("CheckSliceParameters returns false for slice1");
|
||||
return false;
|
||||
|
|
@ -364,7 +456,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
|
|||
}
|
||||
|
||||
const Node& unsqueeze2 = edges[0]->GetNode();
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze2, 1)) {
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze2, use_shared_node ? 2 : 1)) {
|
||||
DEBUG_LOG("Output edge count not expected for unsqueeze2 of unidirectional mask");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -376,7 +468,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
|
|||
}
|
||||
|
||||
const Node& unsqueeze3 = edges[0]->GetNode();
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze3, 1)) {
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze3, use_shared_node ? 2 : 1)) {
|
||||
DEBUG_LOG("Output edge count not expected for unsqueeze3 of unidirectional mask");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -401,7 +493,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
|
|||
const Node& shape2 = edges[2]->GetNode();
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, squeeze2, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, slice2, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, shape2, 1)) {
|
||||
!optimizer_utils::CheckOutputEdges(graph, shape2, use_shared_node ? 2 : 1)) {
|
||||
DEBUG_LOG("Output edge count not expected for squeeze_2/slices2/shape2 of unidirectional mask");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -411,6 +503,10 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
|
|||
return false;
|
||||
}
|
||||
|
||||
if (use_shared_node && (shape1.Index() != shape2.Index() || unsqueeze2.Index() != unsqueeze3.Index())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
result.div_node = &div_node;
|
||||
result.node_indices = {
|
||||
where_node.Index(),
|
||||
|
|
@ -423,10 +519,13 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
|
|||
slice1.Index(),
|
||||
shape1.Index(),
|
||||
unsqueeze2.Index(),
|
||||
unsqueeze3.Index(),
|
||||
squeeze2.Index(),
|
||||
slice2.Index(),
|
||||
shape2.Index()};
|
||||
slice2.Index()};
|
||||
|
||||
if (!use_shared_node) {
|
||||
result.node_indices.push_back(unsqueeze3.Index());
|
||||
result.node_indices.push_back(shape2.Index());
|
||||
}
|
||||
|
||||
DEBUG_LOG("Pass MatchUnidirMaskSubgraph");
|
||||
return true;
|
||||
|
|
@ -739,14 +838,14 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& layer_norm, const No
|
|||
}
|
||||
std::vector<int64_t> shape_value;
|
||||
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[1]), shape_value, true) ||
|
||||
shape_value.size() != 1 ||
|
||||
shape_value[0] != 1) {
|
||||
shape_value.size() != 1 ||
|
||||
shape_value[0] != 1) {
|
||||
return false;
|
||||
}
|
||||
shape_value.clear();
|
||||
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[2]), shape_value, true) ||
|
||||
shape_value.size() != 1 ||
|
||||
shape_value[0] != 1) {
|
||||
shape_value.size() != 1 ||
|
||||
shape_value[0] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -894,8 +993,8 @@ bool CheckDistilBertReshapeShape(const Graph& graph, const Node& reshape, int64_
|
|||
|
||||
// lazy check: record the Unsqueeze node first and then check it in the mask path
|
||||
std::vector<graph_utils::EdgeEndToMatch> shape_path{
|
||||
{0, 1, "Concat", {4, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}};
|
||||
{0, 1, "Concat", {4, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}};
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (!graph_utils::FindPath(reshape, true, shape_path, edges, logger)) {
|
||||
DEBUG_LOG("Failed to find shape path");
|
||||
|
|
@ -1138,12 +1237,12 @@ NodeArg* GetOrCreateMaskInt32(
|
|||
| (0,2,1,3) (0,2,3,1) (perm=0,2,1,3)
|
||||
| \ / | [Past]?
|
||||
\ / | |
|
||||
| \ p_Concat? <------|---------------------{Past_Subgraphj}?
|
||||
| \ k_Concat? <------|---------------------{Past_Subgraphj}?
|
||||
| \ / | |
|
||||
| qk_MatMul | |
|
||||
| | [B=h] | |
|
||||
| | / | /
|
||||
| qk_Div p_Concat? <------------------
|
||||
| qk_Div v_Concat? <------------------
|
||||
| | |
|
||||
| {Unidir_Mask_Subgraph} | [Mask]?
|
||||
| | / |
|
||||
|
|
@ -1178,7 +1277,7 @@ After Fusion:
|
|||
--------> Add
|
||||
TODO: replace Gemm_Subgraph by MatMul + Add
|
||||
*/
|
||||
bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map<std::string, NodeArg*>& mask_int32_map, const logging::Logger& logger) {
|
||||
bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map<std::string, NodeArg*>& mask_int32_map, bool use_shared_node, const logging::Logger& logger) {
|
||||
DEBUG_LOG("Start FuseGptAttention");
|
||||
const Node* parent_node = graph_utils::GetInputNode(layer_norm, 0);
|
||||
if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13}, kOnnxDomain)) {
|
||||
|
|
@ -1191,7 +1290,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
|
|||
}
|
||||
|
||||
MatchGemmResult gemm1_result;
|
||||
if (!MatchGemmSubgraph(graph, *graph.GetNode(add_after_gemm->Index()), 1, gemm1_result, logger) ||
|
||||
if (!MatchGemmSubgraph(graph, *graph.GetNode(add_after_gemm->Index()), 1, gemm1_result, use_shared_node, logger) ||
|
||||
!ValidateGemmInitializer(graph, *gemm1_result.gemm, hidden_size, false, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1233,7 +1332,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
|
|||
const Node& v_split = edges[2]->GetNode();
|
||||
|
||||
MatchGemmResult gemm0_result;
|
||||
if (!MatchGemmSubgraph(graph, *graph.GetNode(v_split.Index()), 0, gemm0_result, logger) ||
|
||||
if (!MatchGemmSubgraph(graph, *graph.GetNode(v_split.Index()), 0, gemm0_result, use_shared_node, logger) ||
|
||||
!ValidateGemmInitializer(graph, *gemm0_result.gemm, hidden_size, true, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1263,7 +1362,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
|
|||
}
|
||||
|
||||
MatchUnidirMaskResult unidir_mask_result;
|
||||
if (!MatchUnidirMaskSubgraph(graph, *(mask_nodes.has_input_mask ? mask_nodes.add : mask_nodes.softmax), unidir_mask_result, logger)) {
|
||||
if (!MatchUnidirMaskSubgraph(graph, *(mask_nodes.has_input_mask ? mask_nodes.add : mask_nodes.softmax), unidir_mask_result, use_shared_node, logger)) {
|
||||
DEBUG_LOG("MatchUnidirMaskSubgraph returns NULL");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1365,7 +1464,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
|
|||
nullptr,
|
||||
kMSDomain);
|
||||
attention_node.AddAttribute("num_heads", num_heads);
|
||||
attention_node.AddAttribute("unidirectional", (int64_t)1);
|
||||
attention_node.AddAttribute("unidirectional", static_cast<int64_t>(unidir_mask_result.is_unidirectional));
|
||||
|
||||
// Assign provider to this new node.
|
||||
attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType());
|
||||
|
|
|
|||
|
|
@ -23,16 +23,17 @@ class FusionGptAttention(Fusion):
|
|||
self.utils = FusionUtils(model)
|
||||
self.casted_attention_mask = {} # map from name of attention mask to the name that casted to int32
|
||||
|
||||
def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask=''):
|
||||
def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask, is_unidirectional):
|
||||
attention_node_name = self.model.create_node_name('GptAttention')
|
||||
attention_node = helper.make_node('Attention',
|
||||
inputs=[input, gemm.input[1], gemm.input[2], mask, past],
|
||||
outputs=[attention_node_name + "_output", present],
|
||||
name=attention_node_name)
|
||||
attention_node.domain = "com.microsoft"
|
||||
attention_node.attribute.extend(
|
||||
[helper.make_attribute("num_heads", self.num_heads),
|
||||
helper.make_attribute("unidirectional", 1)])
|
||||
attention_node.attribute.extend([
|
||||
helper.make_attribute("num_heads", self.num_heads),
|
||||
helper.make_attribute("unidirectional", 1 if is_unidirectional else 0)
|
||||
])
|
||||
|
||||
matmul_node = helper.make_node('MatMul',
|
||||
inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
|
||||
|
|
@ -115,6 +116,8 @@ class FusionGptAttention(Fusion):
|
|||
logger.debug("Add and LayerNormalization shall have one same input")
|
||||
return
|
||||
|
||||
is_unidirectional = True
|
||||
slice_mask = None
|
||||
input_mask_nodes = None
|
||||
qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'], [0, 0, 0, 0, 0])
|
||||
if qk_nodes is not None:
|
||||
|
|
@ -127,6 +130,7 @@ class FusionGptAttention(Fusion):
|
|||
logger.debug("fuse_attention: failed to match unidirectional mask path")
|
||||
return
|
||||
div_mask = mask_nodes[-1]
|
||||
slice_mask = mask_nodes[3]
|
||||
|
||||
if div_qk != div_mask:
|
||||
logger.debug("fuse_attention: skip since div_qk != div_mask")
|
||||
|
|
@ -162,11 +166,24 @@ class FusionGptAttention(Fusion):
|
|||
logger.debug("fuse_attention: failed to match mask path")
|
||||
return
|
||||
div_mask = mask_nodes[-1]
|
||||
slice_mask = mask_nodes[2]
|
||||
|
||||
if div_qk != div_mask:
|
||||
logger.debug("fuse_attention: skip since div_qk != div_mask")
|
||||
return
|
||||
|
||||
# Validate that the mask data is either lower triangular (unidirectional) or all ones
|
||||
mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0]))
|
||||
if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1)
|
||||
and mask_data.shape[2] == mask_data.shape[3]):
|
||||
logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
|
||||
return
|
||||
if np.allclose(mask_data, np.ones_like(mask_data)):
|
||||
is_unidirectional = False
|
||||
elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
|
||||
logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
|
||||
return
|
||||
|
||||
q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
|
||||
if q_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match q path")
|
||||
|
|
@ -219,7 +236,7 @@ class FusionGptAttention(Fusion):
|
|||
self.casted_attention_mask[input_name] = attention_mask_input_name
|
||||
|
||||
self.create_attention_node(gemm, gemm_qkv, past, present, layernorm_before_attention.output[0],
|
||||
reshape_qkv.output[0], attention_mask_input_name)
|
||||
reshape_qkv.output[0], attention_mask_input_name, is_unidirectional)
|
||||
|
||||
# we rely on prune_graph() to clean old subgraph nodes:
|
||||
# qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
|
||||
|
|
|
|||
|
|
@ -1834,6 +1834,63 @@ TEST_F(GraphTransformationTests, AttentionFusionFloat32Test) {
|
|||
ValidateAttention(graph);
|
||||
}
|
||||
|
||||
// Test GPT-2 Attention Fusion with past and unidirectional mask
|
||||
TEST_F(GraphTransformationTests, AttentionFusionWithPastAndUnidirMaskTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/attention_past_unidir.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<AttentionFusion>(), TransformerLevel::Level2);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
EXPECT_EQ(op_to_count["Transpose"], 0);
|
||||
EXPECT_EQ(op_to_count["Softmax"], 0);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1);
|
||||
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (auto node_index : node_topology_list) {
|
||||
Node* p_node = graph.GetNode(node_index);
|
||||
if (p_node->OpType().compare("Attention") == 0) {
|
||||
EXPECT_EQ(p_node->GetAttributes().at("unidirectional").i(), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test Attention Fusion with past but no unidirectional mask
|
||||
TEST_F(GraphTransformationTests, AttentionFusionWithPastAndNoUnidirMaskTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/attention_past_no_unidir.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<AttentionFusion>(), TransformerLevel::Level2);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
EXPECT_EQ(op_to_count["Transpose"], 0);
|
||||
EXPECT_EQ(op_to_count["Softmax"], 0);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1);
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (auto node_index : node_topology_list) {
|
||||
Node* p_node = graph.GetNode(node_index);
|
||||
if (p_node->OpType().compare("Attention") == 0) {
|
||||
EXPECT_EQ(p_node->GetAttributes().at("unidirectional").i(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test GPT-2 Attention Fusion with float32 mask
|
||||
TEST_F(GraphTransformationTests, AttentionFusionGPTWithPastAndMaskTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/gpt2_past_mask_one_layer.onnx";
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/attention_past_no_unidir.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/attention_past_no_unidir.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/attention_past_unidir.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/attention_past_unidir.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue