Attention with past and no unidirectional mask (#5557)

* Update fusions to support shared node, and mask of all ones
This commit is contained in:
Tianlei Wu 2020-10-21 20:12:02 -07:00 committed by GitHub
parent 0a9b83a313
commit 1f304fbee7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 208 additions and 38 deletions

View file

@ -98,10 +98,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
int past_sequence_length = 0;
if (past != nullptr) { // past is optional
if (!is_unidirectional_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is only allowed for unidirectional");
}
const auto& past_dims = past->Shape().GetDims();
if (past_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is expected to have 5 dimension, got ",

View file

@ -208,7 +208,7 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& node = *p_node;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if ((node.GetOutputEdgesCount() >= 4 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert
if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert
graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1}, kOnnxDomain) &&
graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
// Get hidden size from layer norm bias tensor shape.
@ -236,13 +236,14 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
reshape_count++;
}
}
if (add_count == 1 && matmul_count == 3 && shape_count == node.GetOutputEdgesCount() - 4) { // BERT or DistilBert
if (AttentionFusion::FuseSubGraph(node, *add_node, graph, hidden_size, mask_int32_map, logger)) {
fused_count++;
modified = true;
}
} else if (reshape_count == 1 && shape_count == 3) { // GPT
if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, logger)) {
} else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && (reshape_count + shape_count) == node.GetOutputEdgesCount()) { // GPT
if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, logger)) {
fused_count++;
modified = true;
}

View file

@ -1,5 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnx/defs/shape_inference.h"
#include "onnx/defs/tensor_proto_util.h"
#pragma once
@ -57,11 +59,14 @@ bool CheckSliceParameters(const Graph& graph, const Node& slice, const std::vect
+----> Shape --> Gather (indices=0) --> Unsqueeze (axes=0) -----------+ |
| |
+----> Shape --> Gather (indices=1) --> Unsqueeze (axes=0) --------------+
The 3 Shape nodes are merged into one node if use_shared_node is true.
*/
bool MatchGemmSubgraph(Graph& graph,
Node& node_after_gemm_reshape,
int dst_arg_index,
MatchGemmResult& result,
bool use_shared_node,
const logging::Logger& logger) {
DEBUG_LOG("Start MatchGemmSubgraph");
// GPT Attention fusion supports opset version 9 or later.
@ -96,7 +101,7 @@ bool MatchGemmSubgraph(Graph& graph,
return false;
}
if (!optimizer_utils::CheckOutputEdges(graph, shape_before_slice, 1) ||
if (!optimizer_utils::CheckOutputEdges(graph, shape_before_slice, use_shared_node ? 3 : 1) ||
!optimizer_utils::CheckOutputEdges(graph, slice, 1) ||
!optimizer_utils::CheckOutputEdges(graph, squeeze, 1) ||
!optimizer_utils::CheckOutputEdges(graph, unsqueeze, 1) ||
@ -173,14 +178,21 @@ bool MatchGemmSubgraph(Graph& graph,
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze_after_gather, 1) ||
!optimizer_utils::CheckOutputEdges(graph, gather, 1) ||
!optimizer_utils::CheckOutputEdges(graph, shape, 1)) { //TODO: deal with shared Shape node which has output edges > 1
!optimizer_utils::CheckOutputEdges(graph, shape, 1) && !use_shared_node) {
DEBUG_LOG("Output edge count not expected for nodes in gemm gather path");
return false;
}
result.node_indices.push_back(unsqueeze_after_gather.Index());
result.node_indices.push_back(gather.Index());
result.node_indices.push_back(shape.Index());
if (use_shared_node) {
if (shape.Index() != shape_before_slice.Index()) {
return false;
}
} else {
result.node_indices.push_back(shape.Index());
}
if (shape.InputDefs()[0]->Name() != subgraph_input->Name()) {
return false;
@ -252,9 +264,83 @@ bool ValidateGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidde
struct MatchUnidirMaskResult {
const Node* div_node; // the root node (Div) of the subgraph
bool is_unidirectional; // whether the mask is unidirectional.
std::vector<NodeIndex> node_indices; // id of all nodes in the subgraph for removing later.
};
// Return true when mask is unidirectionl (lower trigular) or all elements are 1.
template <class T>
bool ValidateUnidirMask(std::vector<T> mask_data, int64_t w, bool& is_undirectional) {
// The mask data has shape 1x1xWxW
if (mask_data.size() == static_cast<size_t>(w * w)) {
bool is_one = true;
is_undirectional = true;
const T* p = mask_data.data();
for (int i = 0; i < w; i++) {
for (int j = 0; j < w; j++) {
if (*p != static_cast<T>(1)) {
is_one = false;
}
if (*p != ((i >= j) ? static_cast<T>(1) : static_cast<T>(0))) {
is_undirectional = false;
}
p++;
}
}
if (is_undirectional || is_one)
return true;
}
return false;
}
bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidirectional, const logging::Logger& logger) {
if (!graph_utils::IsInitializer(graph, mask.Name(), true)) {
DEBUG_LOG("unidir mask is not constant");
return false;
}
// Check that the mask shape is 1x1xWxW
auto shape = mask.Shape();
if (shape == nullptr || static_cast<size_t>(shape->dim_size()) != 4 || !utils::HasDimValue(shape->dim(0)) || static_cast<int64_t>(1) != shape->dim(0).dim_value() || !utils::HasDimValue(shape->dim(1)) || static_cast<int64_t>(1) != shape->dim(1).dim_value() || !utils::HasDimValue(shape->dim(2)) || !utils::HasDimValue(shape->dim(3)) || shape->dim(2).dim_value() != shape->dim(3).dim_value()) {
DEBUG_LOG("unidir mask shape not expected");
return false;
}
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
if (!graph.GetInitializedTensor(mask.Name(), tensor_proto) || tensor_proto == nullptr) {
return false;
}
if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
DEBUG_LOG("This optimizer does not support external data for unidirectional mask right now");
return false;
}
if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
std::vector<int32_t> int32_data = ONNX_NAMESPACE::ParseData<int32_t>(tensor_proto);
if (!ValidateUnidirMask(int32_data, shape->dim(2).dim_value(), is_unidirectional)) {
DEBUG_LOG("Mask is neither unidirectional nor all ones");
return false;
}
} else if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
std::vector<float_t> float_data = ONNX_NAMESPACE::ParseData<float_t>(tensor_proto);
if (!ValidateUnidirMask(float_data, shape->dim(2).dim_value(), is_unidirectional)) {
DEBUG_LOG("Mask is neither unidirectional nor all ones");
return false;
}
} else {
DEBUG_LOG("Expect mask data type is uint8 or float");
return false;
}
return true;
}
/** Match Unidirectional Mask subgraph.
In the below graph, ':' is followed by variable name in code. * means the input on the left side.
@ -274,8 +360,10 @@ struct MatchUnidirMaskResult {
+----> Shape --> Slice ---------> Squeeze-------+ |
| :shape2 :slice2 :squeeze2 v condition
+----------------------------------------------------------------------------------------->Where( ,*,-10000)--->[Add]
When use_shared_node is true, shape1 and shape2 is one node, and also unsqueeze2 and unsqueeze3 is same.
*/
bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnidirMaskResult& result, const logging::Logger& logger) {
bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnidirMaskResult& result, bool use_shared_node, const logging::Logger& logger) {
DEBUG_LOG("Start MatchUnidirMaskSubgraph");
std::vector<graph_utils::EdgeEndToMatch> root_path{
{0, 0, "Where", {9}, kOnnxDomain},
@ -325,10 +413,9 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
!optimizer_utils::CheckOutputEdges(graph, mask_slice, 1) ||
!optimizer_utils::CheckOutputEdges(graph, unsqueeze1, 1) ||
!optimizer_utils::CheckOutputEdges(graph, sub, 1) ||
!optimizer_utils::CheckOutputEdges(graph, squeeze1, 3) ||
!optimizer_utils::CheckOutputEdges(graph, squeeze1, use_shared_node ? 2 : 3) ||
!optimizer_utils::CheckOutputEdges(graph, slice1, 1) ||
!optimizer_utils::CheckOutputEdges(graph, shape1, 1) ||
!optimizer_utils::CheckOutputEdges(graph, mask_slice, 1)) {
!optimizer_utils::CheckOutputEdges(graph, shape1, use_shared_node ? 2 : 1)) {
DEBUG_LOG("Output edge count not expected for nodes in path 1 of unidirectional mask");
return false;
}
@ -348,6 +435,11 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
return false;
}
if (!ValidateUnidirMask(graph, *(mask_slice.InputDefs()[0]), result.is_unidirectional, logger)) {
DEBUG_LOG("ValidateUnidirMask returns false for mask_slice");
return false;
}
if (!CheckSliceParameters(graph, slice1, {1, 2, 3}, {-1, INT_MAX, 0}, logger)) {
DEBUG_LOG("CheckSliceParameters returns false for slice1");
return false;
@ -364,7 +456,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
}
const Node& unsqueeze2 = edges[0]->GetNode();
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze2, 1)) {
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze2, use_shared_node ? 2 : 1)) {
DEBUG_LOG("Output edge count not expected for unsqueeze2 of unidirectional mask");
return false;
}
@ -376,7 +468,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
}
const Node& unsqueeze3 = edges[0]->GetNode();
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze3, 1)) {
if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze3, use_shared_node ? 2 : 1)) {
DEBUG_LOG("Output edge count not expected for unsqueeze3 of unidirectional mask");
return false;
}
@ -401,7 +493,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
const Node& shape2 = edges[2]->GetNode();
if (!optimizer_utils::CheckOutputEdges(graph, squeeze2, 1) ||
!optimizer_utils::CheckOutputEdges(graph, slice2, 1) ||
!optimizer_utils::CheckOutputEdges(graph, shape2, 1)) {
!optimizer_utils::CheckOutputEdges(graph, shape2, use_shared_node ? 2 : 1)) {
DEBUG_LOG("Output edge count not expected for squeeze_2/slices2/shape2 of unidirectional mask");
return false;
}
@ -411,6 +503,10 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
return false;
}
if (use_shared_node && (shape1.Index() != shape2.Index() || unsqueeze2.Index() != unsqueeze3.Index())) {
return false;
}
result.div_node = &div_node;
result.node_indices = {
where_node.Index(),
@ -423,10 +519,13 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
slice1.Index(),
shape1.Index(),
unsqueeze2.Index(),
unsqueeze3.Index(),
squeeze2.Index(),
slice2.Index(),
shape2.Index()};
slice2.Index()};
if (!use_shared_node) {
result.node_indices.push_back(unsqueeze3.Index());
result.node_indices.push_back(shape2.Index());
}
DEBUG_LOG("Pass MatchUnidirMaskSubgraph");
return true;
@ -739,14 +838,14 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& layer_norm, const No
}
std::vector<int64_t> shape_value;
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[1]), shape_value, true) ||
shape_value.size() != 1 ||
shape_value[0] != 1) {
shape_value.size() != 1 ||
shape_value[0] != 1) {
return false;
}
shape_value.clear();
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[2]), shape_value, true) ||
shape_value.size() != 1 ||
shape_value[0] != 1) {
shape_value.size() != 1 ||
shape_value[0] != 1) {
return false;
}
@ -894,8 +993,8 @@ bool CheckDistilBertReshapeShape(const Graph& graph, const Node& reshape, int64_
// lazy check: record unqueeze first and then check in the mask path
std::vector<graph_utils::EdgeEndToMatch> shape_path{
{0, 1, "Concat", {4, 11, 13}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}};
{0, 1, "Concat", {4, 11, 13}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}};
std::vector<const Node::EdgeEnd*> edges;
if (!graph_utils::FindPath(reshape, true, shape_path, edges, logger)) {
DEBUG_LOG("Failed to find shape path");
@ -1138,12 +1237,12 @@ NodeArg* GetOrCreateMaskInt32(
| (0,2,1,3) (0,2,3,1) (perm=0,2,1,3)
| \ / | [Past]?
\ / | |
| \ p_Concat? <------|---------------------{Past_Subgraphj}?
| \ k_Concat? <------|---------------------{Past_Subgraphj}?
| \ / | |
| qk_MatMul | |
| | [B=h] | |
| | / | /
| qk_Div p_Concat? <------------------
| qk_Div v_Concat? <------------------
| | |
| {Unidir_Mask_Subgraph} | [Mask]?
| | / |
@ -1178,7 +1277,7 @@ After Fusion:
--------> Add
TODO: replace Gemm_Subgraph by MatMul + Add
*/
bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map<std::string, NodeArg*>& mask_int32_map, const logging::Logger& logger) {
bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map<std::string, NodeArg*>& mask_int32_map, bool use_shared_node, const logging::Logger& logger) {
DEBUG_LOG("Start FuseGptAttention");
const Node* parent_node = graph_utils::GetInputNode(layer_norm, 0);
if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13}, kOnnxDomain)) {
@ -1191,7 +1290,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
}
MatchGemmResult gemm1_result;
if (!MatchGemmSubgraph(graph, *graph.GetNode(add_after_gemm->Index()), 1, gemm1_result, logger) ||
if (!MatchGemmSubgraph(graph, *graph.GetNode(add_after_gemm->Index()), 1, gemm1_result, use_shared_node, logger) ||
!ValidateGemmInitializer(graph, *gemm1_result.gemm, hidden_size, false, logger)) {
return false;
}
@ -1233,7 +1332,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
const Node& v_split = edges[2]->GetNode();
MatchGemmResult gemm0_result;
if (!MatchGemmSubgraph(graph, *graph.GetNode(v_split.Index()), 0, gemm0_result, logger) ||
if (!MatchGemmSubgraph(graph, *graph.GetNode(v_split.Index()), 0, gemm0_result, use_shared_node, logger) ||
!ValidateGemmInitializer(graph, *gemm0_result.gemm, hidden_size, true, logger)) {
return false;
}
@ -1263,7 +1362,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
}
MatchUnidirMaskResult unidir_mask_result;
if (!MatchUnidirMaskSubgraph(graph, *(mask_nodes.has_input_mask ? mask_nodes.add : mask_nodes.softmax), unidir_mask_result, logger)) {
if (!MatchUnidirMaskSubgraph(graph, *(mask_nodes.has_input_mask ? mask_nodes.add : mask_nodes.softmax), unidir_mask_result, use_shared_node, logger)) {
DEBUG_LOG("MatchUnidirMaskSubgraph returns NULL");
return false;
}
@ -1365,7 +1464,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
nullptr,
kMSDomain);
attention_node.AddAttribute("num_heads", num_heads);
attention_node.AddAttribute("unidirectional", (int64_t)1);
attention_node.AddAttribute("unidirectional", static_cast<int64_t>(unidir_mask_result.is_unidirectional));
// Assign provider to this new node.
attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType());

View file

@ -23,16 +23,17 @@ class FusionGptAttention(Fusion):
self.utils = FusionUtils(model)
self.casted_attention_mask = {} # map from name of attention mask to the name that casted to int32
def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask=''):
def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask, is_unidirectional):
attention_node_name = self.model.create_node_name('GptAttention')
attention_node = helper.make_node('Attention',
inputs=[input, gemm.input[1], gemm.input[2], mask, past],
outputs=[attention_node_name + "_output", present],
name=attention_node_name)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend(
[helper.make_attribute("num_heads", self.num_heads),
helper.make_attribute("unidirectional", 1)])
attention_node.attribute.extend([
helper.make_attribute("num_heads", self.num_heads),
helper.make_attribute("unidirectional", 1 if is_unidirectional else 0)
])
matmul_node = helper.make_node('MatMul',
inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
@ -115,6 +116,8 @@ class FusionGptAttention(Fusion):
logger.debug("Add and LayerNormalization shall have one same input")
return
is_unidirectional = True
slice_mask = None
input_mask_nodes = None
qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'], [0, 0, 0, 0, 0])
if qk_nodes is not None:
@ -127,6 +130,7 @@ class FusionGptAttention(Fusion):
logger.debug("fuse_attention: failed to match unidirectional mask path")
return
div_mask = mask_nodes[-1]
slice_mask = mask_nodes[3]
if div_qk != div_mask:
logger.debug("fuse_attention: skip since div_qk != div_mask")
@ -162,11 +166,24 @@ class FusionGptAttention(Fusion):
logger.debug("fuse_attention: failed to match mask path")
return
div_mask = mask_nodes[-1]
slice_mask = mask_nodes[2]
if div_qk != div_mask:
logger.debug("fuse_attention: skip since div_qk != div_mask")
return
# Validate that the mask data is either lower triangular (unidirectional) or all ones
mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0]))
if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1)
and mask_data.shape[2] == mask_data.shape[3]):
logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
return
if np.allclose(mask_data, np.ones_like(mask_data)):
is_unidirectional = False
elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
return
q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
if q_nodes is None:
logger.debug("fuse_attention: failed to match q path")
@ -219,7 +236,7 @@ class FusionGptAttention(Fusion):
self.casted_attention_mask[input_name] = attention_mask_input_name
self.create_attention_node(gemm, gemm_qkv, past, present, layernorm_before_attention.output[0],
reshape_qkv.output[0], attention_mask_input_name)
reshape_qkv.output[0], attention_mask_input_name, is_unidirectional)
# we rely on prune_graph() to clean old subgraph nodes:
# qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]

View file

@ -1834,6 +1834,63 @@ TEST_F(GraphTransformationTests, AttentionFusionFloat32Test) {
ValidateAttention(graph);
}
// Test GPT-2 Attention Fusion with past and unidirectional mask
TEST_F(GraphTransformationTests, AttentionFusionWithPastAndUnidirMaskTest) {
auto model_uri = MODEL_FOLDER "fusion/attention_past_unidir.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<AttentionFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Transpose"], 0);
EXPECT_EQ(op_to_count["Softmax"], 0);
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1);
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
Node* p_node = graph.GetNode(node_index);
if (p_node->OpType().compare("Attention") == 0) {
EXPECT_EQ(p_node->GetAttributes().at("unidirectional").i(), 1);
}
}
}
// Test Attention Fusion with past but no unidirectional mask
TEST_F(GraphTransformationTests, AttentionFusionWithPastAndNoUnidirMaskTest) {
auto model_uri = MODEL_FOLDER "fusion/attention_past_no_unidir.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<AttentionFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Transpose"], 0);
EXPECT_EQ(op_to_count["Softmax"], 0);
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1);
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
Node* p_node = graph.GetNode(node_index);
if (p_node->OpType().compare("Attention") == 0) {
EXPECT_EQ(p_node->GetAttributes().at("unidirectional").i(), 0);
}
}
}
// Test GPT-2 Attention Fusion with float32 mask
TEST_F(GraphTransformationTests, AttentionFusionGPTWithPastAndMaskTest) {
auto model_uri = MODEL_FOLDER "fusion/gpt2_past_mask_one_layer.onnx";