mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Add Fusion for GPT Attention with both past state and attention mask (#4437)
Add Fusion for GPT Attention with past state and attention mask
This commit is contained in:
parent
7baf374939
commit
eabf6dc9ee
4 changed files with 1204 additions and 206 deletions
|
|
@ -5,12 +5,9 @@
|
|||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/attention_fusion.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/attention_fusion_helper.h"
|
||||
#include <cmath>
|
||||
|
||||
#define DEBUG_LOG(x) LOGS(logger, VERBOSE) << x
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
static bool ValidateMatMulInitializer(const Graph& graph, const Node& matmul, int64_t hidden_size) {
|
||||
|
|
@ -31,7 +28,7 @@ static bool ValidateAddBiasInitializer(const Graph& graph, const Node& add, int6
|
|||
return optimizer_utils::ValidateShape(input_b, {hidden_size});
|
||||
}
|
||||
|
||||
// Merge 1-D weights (q, k and v) by concanating them one by one.
|
||||
// Merge 1-D weights (q, k and v) by concatenating them one by one.
|
||||
template <typename T>
|
||||
void MergeWeights(const T* q, const T* k, const T* v, std::vector<T>& result, int64_t element_count) {
|
||||
for (int64_t i = 0; i < element_count; i++) {
|
||||
|
|
@ -50,7 +47,7 @@ void MergeWeights(const T* q, const T* k, const T* v, std::vector<T>& result, in
|
|||
}
|
||||
}
|
||||
|
||||
// Merge 2-D weights (q, k and v) by concanating them row by row.
|
||||
// Merge 2-D weights (q, k and v) by concatenating them row by row.
|
||||
template <typename T>
|
||||
void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight, std::vector<T>& result, int64_t hidden_size) {
|
||||
const T* q = q_weight;
|
||||
|
|
@ -146,36 +143,6 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size,
|
|||
return graph_utils::AddInitializer(graph, initializer);
|
||||
}
|
||||
|
||||
// Add a Cast to convert Mask from int64 to int32.
|
||||
static NodeArg& CastMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type) {
|
||||
const TensorShapeProto* mask_shape = mask_input->Shape();
|
||||
TypeProto mask_int32;
|
||||
mask_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
|
||||
auto dim0 = mask_int32.mutable_tensor_type()->mutable_shape()->add_dim();
|
||||
*dim0 = mask_shape->dim(0);
|
||||
auto dim1 = mask_int32.mutable_tensor_type()->mutable_shape()->add_dim();
|
||||
*dim1 = mask_shape->dim(1);
|
||||
auto& cast32 = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("Mask_Int32"), &mask_int32);
|
||||
|
||||
Node& node = graph.AddNode(graph.GenerateNodeName("MaskCast"),
|
||||
"Cast",
|
||||
"Cast mask from int64 to int32",
|
||||
{mask_input},
|
||||
{&cast32},
|
||||
nullptr,
|
||||
kOnnxDomain);
|
||||
|
||||
// Add attribute: "to" = 6
|
||||
ONNX_NAMESPACE::AttributeProto to;
|
||||
to.set_name("to");
|
||||
to.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
|
||||
to.set_i(static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
|
||||
node.AddAttribute("to", to);
|
||||
|
||||
node.SetExecutionProviderType(provider_type);
|
||||
return cast32;
|
||||
}
|
||||
|
||||
static NodeArg& AddMaskReduceSum(Graph& graph, NodeArg* reduce_sum_input, TypeProto& output_type, ProviderType provider_type) {
|
||||
NodeArg& reduce_sum_output = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("MaskIndex_Int32"), &output_type);
|
||||
|
||||
|
|
@ -229,7 +196,7 @@ static NodeArg* ProcessMask(Graph& graph, NodeArg* mask_input, ProviderType prov
|
|||
NodeArg* reduce_sum_input = mask_input;
|
||||
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
|
||||
data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
NodeArg& cast_int32 = CastMaskToInt32(graph, mask_input, provider_type);
|
||||
NodeArg& cast_int32 = AttentionFusionHelper::CastMaskToInt32(graph, mask_input, provider_type);
|
||||
reduce_sum_input = &cast_int32;
|
||||
}
|
||||
|
||||
|
|
@ -272,6 +239,9 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// A map from mask input arg name to mask index output.
|
||||
std::map<std::string, NodeArg*> mask_index_map;
|
||||
|
||||
// A map from mask input arg name to the one casted to int32
|
||||
std::map<std::string, NodeArg*> mask_int32_map;
|
||||
|
||||
int fused_count = 0;
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto* p_node = graph.GetNode(node_index);
|
||||
|
|
@ -296,23 +266,31 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
const Node* add_node = nullptr;
|
||||
int add_count = 0;
|
||||
int matmul_count = 0;
|
||||
int shape_count = 0;
|
||||
int reshape_count = 0;
|
||||
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
|
||||
if ((*it).OpType().compare("Add") == 0) {
|
||||
add_count++;
|
||||
add_node = &(*it);
|
||||
} else if ((*it).OpType().compare("MatMul") == 0) {
|
||||
matmul_count++;
|
||||
} else if ((*it).OpType().compare("Shape") == 0) {
|
||||
shape_count++;
|
||||
} else if ((*it).OpType().compare("Reshape") == 0) {
|
||||
reshape_count++;
|
||||
}
|
||||
}
|
||||
|
||||
if (add_count != 1 || matmul_count != 3) {
|
||||
DEBUG_LOG("Attention subgraph expects 1 Add and 3 MatMul as children of LayerNormalization.");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (AttentionFusion::FuseSubGraph(node, *add_node, graph, hidden_size, mask_index_map, logger)) {
|
||||
fused_count++;
|
||||
modified = true;
|
||||
if (add_count == 1 && matmul_count == 3) { // BERT
|
||||
if (AttentionFusion::FuseSubGraph(node, *add_node, graph, hidden_size, mask_index_map, logger)) {
|
||||
fused_count++;
|
||||
modified = true;
|
||||
}
|
||||
} else if (reshape_count == 1 && shape_count == 3) { // GPT
|
||||
if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, logger)) {
|
||||
fused_count++;
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -345,9 +323,9 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
| (0,2,1,3) (0,2,3,1) (perm=0,2,1,3) |
|
||||
| \ / | mask_Unsqueeze(axes=2)
|
||||
| qk_MatMul | |
|
||||
| | [B=2] | ([A=1] mask_Cast(to=1))
|
||||
| | [B=2] | ([A=1.0] mask_Cast(to=1))
|
||||
| | / | \ /
|
||||
| qk_Div | mask_Sub [A=1000]
|
||||
| qk_Div | mask_Sub [B=-10000.0]
|
||||
| \ | \ /
|
||||
| mask_Add <-------- /---------------------mask_Mul
|
||||
| | /
|
||||
|
|
@ -413,46 +391,16 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
return false;
|
||||
}
|
||||
|
||||
// Internal nodes of attention subgraph only allow edges within the subgraph, and no graph output is allowed.
|
||||
// No constraints for four nodes: reshape node is last node of Attention; and add, matmul and v_root are not in attention subgraph.
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, transpose, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, qkv_matmul, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, v_transpose, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, v_reshape, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, v_add, 1) ||
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, v_add, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, v_matmul, 1)) {
|
||||
DEBUG_LOG("Output edge count not expected for nodes in path v");
|
||||
DEBUG_LOG("Output edge count not expected for Add or MatMul in path v");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64_t> perm;
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(transpose, "perm", perm) && perm.size() == 4 && perm[0] == 0 && perm[1] == 2 && perm[2] == 1 && perm[3] == 3)) {
|
||||
DEBUG_LOG("Failed in match Transpose attribute perm. Expected: 0, 2, 1, 3");
|
||||
return false;
|
||||
}
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(v_transpose, "perm", perm) && perm.size() == 4 && perm[0] == 0 && perm[1] == 2 && perm[2] == 1 && perm[3] == 3)) {
|
||||
DEBUG_LOG("Failed in match v_transpose attribute perm. Expected: 0, 2, 1, 3");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64_t> v_reshape_shape;
|
||||
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(v_reshape.InputDefs()[1]), v_reshape_shape) ||
|
||||
v_reshape_shape.size() != 4 ||
|
||||
v_reshape_shape[2] <= 0 ||
|
||||
v_reshape_shape[3] <= 0 ||
|
||||
hidden_size != v_reshape_shape[2] * v_reshape_shape[3]) {
|
||||
DEBUG_LOG("v_reshape initializer value is not expected");
|
||||
return false;
|
||||
}
|
||||
|
||||
const int64_t num_attention_head = v_reshape_shape[2];
|
||||
const int64_t attention_head_size = v_reshape_shape[3];
|
||||
|
||||
std::vector<int64_t> reshape_shape;
|
||||
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(reshape.InputDefs()[1]), reshape_shape) ||
|
||||
reshape_shape.size() != 3 ||
|
||||
reshape_shape[2] != hidden_size) {
|
||||
DEBUG_LOG("reshape initializer value is not expected");
|
||||
int64_t num_heads = 0; // will be updated in CheckNodesInPathV
|
||||
int64_t head_size = 0; // will be updated in CheckNodesInPathV
|
||||
if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose, qkv_matmul, v_transpose, v_reshape, num_heads, head_size, hidden_size, logger)) {
|
||||
DEBUG_LOG("CheckNodesInPathV return false");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -465,86 +413,11 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
return false;
|
||||
}
|
||||
|
||||
// path 2 to find mask. Unsqueeze -> Unsqueeze -> (Cast) -> Sub -> Mul -> Add -> Softmax
|
||||
// The "Cast" node in parentheses is optional.
|
||||
std::vector<graph_utils::EdgeEndToMatch> mask_path{
|
||||
{0, 0, "Softmax", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Add", {7}, kOnnxDomain},
|
||||
{0, 1, "Mul", {7}, kOnnxDomain},
|
||||
{0, 0, "Sub", {7}, kOnnxDomain}};
|
||||
|
||||
if (!graph_utils::FindPath(qkv_matmul, true, mask_path, edges, logger)) {
|
||||
DEBUG_LOG("Failed to find path for mask");
|
||||
return false;
|
||||
}
|
||||
|
||||
const Node& softmax = edges[0]->GetNode();
|
||||
const Node& mask_add = edges[1]->GetNode();
|
||||
const Node& mask_mul = edges[2]->GetNode();
|
||||
const Node& mask_sub = edges[3]->GetNode();
|
||||
|
||||
// Match optional mask cast node
|
||||
Node* p_mask_cast = nullptr;
|
||||
Node* p_mask_unsqueeze_2 = nullptr;
|
||||
Node* p_mask_unsqueeze_1 = nullptr;
|
||||
std::vector<graph_utils::EdgeEndToMatch> mask_path_format_1{
|
||||
{0, 1, "Cast", {9}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}};
|
||||
|
||||
std::vector<graph_utils::EdgeEndToMatch> mask_path_format_2{
|
||||
{0, 1, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}};
|
||||
|
||||
if (graph_utils::FindPath(mask_sub, true, mask_path_format_1, edges, logger)) {
|
||||
p_mask_cast = const_cast<Node*>(&edges[0]->GetNode());
|
||||
p_mask_unsqueeze_2 = const_cast<Node*>(&edges[1]->GetNode());
|
||||
p_mask_unsqueeze_1 = const_cast<Node*>(&edges[2]->GetNode());
|
||||
} else if (graph_utils::FindPath(mask_sub, true, mask_path_format_2, edges, logger)) {
|
||||
p_mask_unsqueeze_2 = const_cast<Node*>(&edges[0]->GetNode());
|
||||
p_mask_unsqueeze_1 = const_cast<Node*>(&edges[1]->GetNode());
|
||||
} else {
|
||||
DEBUG_LOG("Failed to find path for mask");
|
||||
return false;
|
||||
}
|
||||
|
||||
const Node& mask_unsqueeze_2 = *p_mask_unsqueeze_2;
|
||||
const Node& mask_unsqueeze_1 = *p_mask_unsqueeze_1;
|
||||
|
||||
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, softmax, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, mask_add, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, mask_sub, 1) ||
|
||||
(p_mask_cast != nullptr && !optimizer_utils::CheckOutputEdges(graph, *p_mask_cast, 1)) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, mask_unsqueeze_2, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, mask_unsqueeze_1, 1)) {
|
||||
DEBUG_LOG("Output edge count not expected for mask nodes");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!optimizer_utils::IsAttributeWithExpectedValue(softmax, "axis", 3)) {
|
||||
DEBUG_LOG("Softmax attribute axis is expected to be 3");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64_t> axes;
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(mask_unsqueeze_1, "axes", axes) && axes.size() == 1 && axes[0] == 1)) {
|
||||
DEBUG_LOG("mask_unsqueeze_1 axes not matched. Expect: 1");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(mask_unsqueeze_2, "axes", axes) && axes.size() == 1 && axes[0] == 2)) {
|
||||
DEBUG_LOG("mask_unsqueeze_2 axes not matched. Expect: 2");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mask_sub.InputDefs()[0]), float(1), false)) {
|
||||
DEBUG_LOG("mask_sub const input not matched");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mask_mul.InputDefs()[1]), float(-10000), false)) {
|
||||
DEBUG_LOG("mask_mul const input not matched");
|
||||
// Find mask nodes: Unsqueeze -> Unsqueeze -> (Cast) -> Sub -> Mul -> Add -> Softmax --> [MatMul]
|
||||
// The "Cast" node in parentheses is optional.
|
||||
AttentionFusionHelper::AttentionMaskNodes mask_nodes;
|
||||
if (!AttentionFusionHelper::MatchInputMaskSubgraph(graph, qkv_matmul, mask_nodes, logger)) {
|
||||
DEBUG_LOG("Failed in match input mask subgraph");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -558,7 +431,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
{0, 0, "MatMul", {1, 9}, kOnnxDomain},
|
||||
{0, 0, "LayerNormalization", {1}, kOnnxDomain}};
|
||||
|
||||
if (!graph_utils::FindPath(mask_add, true, q_path, edges, logger)) {
|
||||
if (!graph_utils::FindPath(*(mask_nodes.add), true, q_path, edges, logger)) {
|
||||
DEBUG_LOG("Failed to find path for q");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -575,23 +448,8 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64_t> q_reshape_shape;
|
||||
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(q_reshape.InputDefs()[1]), q_reshape_shape) ||
|
||||
q_reshape_shape.size() != 4 ||
|
||||
q_reshape_shape[2] != num_attention_head ||
|
||||
q_reshape_shape[3] != attention_head_size) {
|
||||
DEBUG_LOG("q_reshape const not matched");
|
||||
return false;
|
||||
}
|
||||
|
||||
float expected_value = std::sqrt(static_cast<float>(attention_head_size));
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(qk_div.InputDefs()[1]), expected_value, false)) {
|
||||
DEBUG_LOG("qk_div const not matched.");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(q_transpose, "perm", perm) && perm.size() == 4 && perm[0] == 0 && perm[1] == 2 && perm[2] == 1 && perm[3] == 3)) {
|
||||
DEBUG_LOG("q_transpose perm attribute not matched");
|
||||
if (!AttentionFusionHelper::CheckNodesInPathQ(graph, qk_div, q_reshape, q_transpose, num_heads, head_size, logger)) {
|
||||
DEBUG_LOG("CheckNodesInPathQ returns false");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -624,8 +482,8 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(k_transpose, "perm", perm) && perm.size() == 4 && perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1)) {
|
||||
DEBUG_LOG("k_transpose perm attribute not matched");
|
||||
if (!AttentionFusionHelper::CheckNodesInPathK(graph, k_reshape, k_transpose, num_heads, head_size, logger)) {
|
||||
DEBUG_LOG("CheckNodesInPathK returns false");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -635,15 +493,6 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64_t> k_reshape_shape;
|
||||
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(k_reshape.InputDefs()[1]), k_reshape_shape) ||
|
||||
k_reshape_shape.size() != 4 ||
|
||||
k_reshape_shape[2] != num_attention_head ||
|
||||
k_reshape_shape[3] != attention_head_size) {
|
||||
DEBUG_LOG("k_reshape const not matched");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Load q, k and v weights
|
||||
const ONNX_NAMESPACE::TensorProto* q_weight_tensor = nullptr;
|
||||
const ONNX_NAMESPACE::TensorProto* k_weight_tensor = nullptr;
|
||||
|
|
@ -662,7 +511,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
}
|
||||
|
||||
// Now everything is ready, we will start fusing subgraph.
|
||||
NodeArg* mask_input = graph.GetNode(mask_unsqueeze_1.Index())->MutableInputDefs()[0];
|
||||
NodeArg* mask_input = graph.GetNode(mask_nodes.unsqueeze_1->Index())->MutableInputDefs()[0];
|
||||
NodeArg* mask_index = GetOrCreateMaskIndex(graph, mask_input, mask_index_map, layer_norm.GetExecutionProviderType(), logger);
|
||||
if (nullptr == mask_index) {
|
||||
DEBUG_LOG("Failed to create mask index");
|
||||
|
|
@ -684,7 +533,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
output_defs,
|
||||
nullptr,
|
||||
kMSDomain);
|
||||
attention_node.AddAttribute("num_heads", num_attention_head);
|
||||
attention_node.AddAttribute("num_heads", num_heads);
|
||||
|
||||
// Assign provider to this new node.
|
||||
attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType());
|
||||
|
|
@ -698,8 +547,6 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
v_reshape.Index(),
|
||||
v_add.Index(),
|
||||
v_matmul.Index(),
|
||||
softmax.Index(),
|
||||
mask_add.Index(),
|
||||
qk_div.Index(),
|
||||
qk_matmul.Index(),
|
||||
q_transpose.Index(),
|
||||
|
|
@ -711,16 +558,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
|
|||
k_add.Index(),
|
||||
k_matmul.Index()};
|
||||
|
||||
// When the last Attention node is fused. Original mask processing nodes can be removed safely.
|
||||
if (optimizer_utils::CheckOutputEdges(graph, mask_mul, 1)) {
|
||||
nodes_to_remove.push_back(mask_mul.Index());
|
||||
nodes_to_remove.push_back(mask_sub.Index());
|
||||
if (p_mask_cast != nullptr) {
|
||||
nodes_to_remove.push_back((*p_mask_cast).Index());
|
||||
}
|
||||
nodes_to_remove.push_back(mask_unsqueeze_2.Index());
|
||||
nodes_to_remove.push_back(mask_unsqueeze_1.Index());
|
||||
}
|
||||
AttentionFusionHelper::SetMaskNodesToRemove(graph, mask_nodes, nodes_to_remove);
|
||||
|
||||
for (const auto& node_index : nodes_to_remove) {
|
||||
Node* node = graph.GetNode(node_index);
|
||||
|
|
|
|||
1142
onnxruntime/core/optimizer/attention_fusion_helper.h
Normal file
1142
onnxruntime/core/optimizer/attention_fusion_helper.h
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1718,6 +1718,24 @@ TEST_F(GraphTransformationTests, AttentionFusionFloat32Test) {
|
|||
ValidateAttention(graph);
|
||||
}
|
||||
|
||||
// Test GPT-2 Attention Fusion with float32 mask
|
||||
TEST_F(GraphTransformationTests, AttentionFusionGPTWithPastAndMaskTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/gpt2_past_mask_one_layer.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<AttentionFusion>(), TransformerLevel::Level2);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
EXPECT_EQ(op_to_count["Transpose"], 0);
|
||||
EXPECT_EQ(op_to_count["Softmax"], 0);
|
||||
EXPECT_EQ(op_to_count["Attention"], 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, GeluFusionTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/gelu.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/gpt2_past_mask_one_layer.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/gpt2_past_mask_one_layer.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue