Add Fusion for GPT Attention with both past state and attention mask (#4437)

Add Fusion for GPT Attention with past state and attention mask
This commit is contained in:
Tianlei Wu 2020-07-06 19:37:37 -07:00 committed by GitHub
parent 7baf374939
commit eabf6dc9ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 1204 additions and 206 deletions

View file

@ -5,12 +5,9 @@
#include "core/optimizer/initializer.h"
#include "core/optimizer/attention_fusion.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/attention_fusion_helper.h"
#include <cmath>
#define DEBUG_LOG(x) LOGS(logger, VERBOSE) << x
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
static bool ValidateMatMulInitializer(const Graph& graph, const Node& matmul, int64_t hidden_size) {
@ -31,7 +28,7 @@ static bool ValidateAddBiasInitializer(const Graph& graph, const Node& add, int6
return optimizer_utils::ValidateShape(input_b, {hidden_size});
}
// Merge 1-D weights (q, k and v) by concanating them one by one.
// Merge 1-D weights (q, k and v) by concatenating them one by one.
template <typename T>
void MergeWeights(const T* q, const T* k, const T* v, std::vector<T>& result, int64_t element_count) {
for (int64_t i = 0; i < element_count; i++) {
@ -50,7 +47,7 @@ void MergeWeights(const T* q, const T* k, const T* v, std::vector<T>& result, in
}
}
// Merge 2-D weights (q, k and v) by concanating them row by row.
// Merge 2-D weights (q, k and v) by concatenating them row by row.
template <typename T>
void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight, std::vector<T>& result, int64_t hidden_size) {
const T* q = q_weight;
@ -146,36 +143,6 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size,
return graph_utils::AddInitializer(graph, initializer);
}
// Add a Cast node that converts the attention mask to int32.
//
// A new NodeArg (name generated from "Mask_Int32") is created for the casted mask and a
// "Cast" node (name generated from "MaskCast") is added to the graph with `mask_input` as
// its only input. The new node is assigned to `provider_type`.
//
// @param graph          Graph that receives the new NodeArg and Cast node.
// @param mask_input     Mask to convert; must not be null.
// @param provider_type  Execution provider assigned to the new Cast node.
// @return Reference to the NodeArg holding the int32 mask.
//
// The first two dimensions of the input mask are copied to the output type when they are
// available. NodeArg::Shape() returns nullptr when no shape is recorded, so in that case
// (or when the mask has fewer than two dimensions) the output shape is left unset instead
// of dereferencing a null pointer or reading a dimension that does not exist.
static NodeArg& CastMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type) {
  TypeProto mask_int32;
  auto* mask_tensor_type = mask_int32.mutable_tensor_type();
  mask_tensor_type->set_elem_type(TensorProto_DataType_INT32);

  const TensorShapeProto* mask_shape = mask_input->Shape();
  if (mask_shape != nullptr && mask_shape->dim_size() >= 2) {
    // mutable_shape() is only touched here: calling it without adding dims would record
    // an explicit rank-0 shape rather than an unknown one.
    auto* cast_shape = mask_tensor_type->mutable_shape();
    *cast_shape->add_dim() = mask_shape->dim(0);
    *cast_shape->add_dim() = mask_shape->dim(1);
  }

  auto& cast32 = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("Mask_Int32"), &mask_int32);
  Node& node = graph.AddNode(graph.GenerateNodeName("MaskCast"),
                             "Cast",
                             "Cast mask from int64 to int32",
                             {mask_input},
                             {&cast32},
                             nullptr,
                             kOnnxDomain);

  // Add attribute: "to" = 6
  ONNX_NAMESPACE::AttributeProto to;
  to.set_name("to");
  to.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
  to.set_i(static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
  node.AddAttribute("to", to);

  node.SetExecutionProviderType(provider_type);
  return cast32;
}
static NodeArg& AddMaskReduceSum(Graph& graph, NodeArg* reduce_sum_input, TypeProto& output_type, ProviderType provider_type) {
NodeArg& reduce_sum_output = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("MaskIndex_Int32"), &output_type);
@ -229,7 +196,7 @@ static NodeArg* ProcessMask(Graph& graph, NodeArg* mask_input, ProviderType prov
NodeArg* reduce_sum_input = mask_input;
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
NodeArg& cast_int32 = CastMaskToInt32(graph, mask_input, provider_type);
NodeArg& cast_int32 = AttentionFusionHelper::CastMaskToInt32(graph, mask_input, provider_type);
reduce_sum_input = &cast_int32;
}
@ -272,6 +239,9 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// A map from mask input arg name to mask index output.
std::map<std::string, NodeArg*> mask_index_map;
// A map from mask input arg name to the one casted to int32
std::map<std::string, NodeArg*> mask_int32_map;
int fused_count = 0;
for (auto node_index : node_topology_list) {
auto* p_node = graph.GetNode(node_index);
@ -296,23 +266,31 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const Node* add_node = nullptr;
int add_count = 0;
int matmul_count = 0;
int shape_count = 0;
int reshape_count = 0;
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
if ((*it).OpType().compare("Add") == 0) {
add_count++;
add_node = &(*it);
} else if ((*it).OpType().compare("MatMul") == 0) {
matmul_count++;
} else if ((*it).OpType().compare("Shape") == 0) {
shape_count++;
} else if ((*it).OpType().compare("Reshape") == 0) {
reshape_count++;
}
}
if (add_count != 1 || matmul_count != 3) {
DEBUG_LOG("Attention subgraph expects 1 Add and 3 MatMul as children of LayerNormalization.");
continue;
}
if (AttentionFusion::FuseSubGraph(node, *add_node, graph, hidden_size, mask_index_map, logger)) {
fused_count++;
modified = true;
if (add_count == 1 && matmul_count == 3) { // BERT
if (AttentionFusion::FuseSubGraph(node, *add_node, graph, hidden_size, mask_index_map, logger)) {
fused_count++;
modified = true;
}
} else if (reshape_count == 1 && shape_count == 3) { // GPT
if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, logger)) {
fused_count++;
modified = true;
}
}
}
}
@ -345,9 +323,9 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
| (0,2,1,3) (0,2,3,1) (perm=0,2,1,3) |
| \ / | mask_Unsqueeze(axes=2)
| qk_MatMul | |
| | [B=2] | ([A=1] mask_Cast(to=1))
| | [B=2] | ([A=1.0] mask_Cast(to=1))
| | / | \ /
| qk_Div | mask_Sub [A=1000]
| qk_Div | mask_Sub [B=-10000.0]
| \ | \ /
| mask_Add <-------- /---------------------mask_Mul
| | /
@ -413,46 +391,16 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
return false;
}
// Internal nodes of attention subgraph only allow edges within the subgraph, and no graph output is allowed.
// No constraints for four nodes: reshape node is last node of Attention; and add, matmul and v_root are not in attention subgraph.
if (!optimizer_utils::CheckOutputEdges(graph, transpose, 1) ||
!optimizer_utils::CheckOutputEdges(graph, qkv_matmul, 1) ||
!optimizer_utils::CheckOutputEdges(graph, v_transpose, 1) ||
!optimizer_utils::CheckOutputEdges(graph, v_reshape, 1) ||
!optimizer_utils::CheckOutputEdges(graph, v_add, 1) ||
if (!optimizer_utils::CheckOutputEdges(graph, v_add, 1) ||
!optimizer_utils::CheckOutputEdges(graph, v_matmul, 1)) {
DEBUG_LOG("Output edge count not expected for nodes in path v");
DEBUG_LOG("Output edge count not expected for Add or MatMul in path v");
return false;
}
std::vector<int64_t> perm;
if (!(graph_utils::GetRepeatedNodeAttributeValues(transpose, "perm", perm) && perm.size() == 4 && perm[0] == 0 && perm[1] == 2 && perm[2] == 1 && perm[3] == 3)) {
DEBUG_LOG("Failed in match Transpose attribute perm. Expected: 0, 2, 1, 3");
return false;
}
if (!(graph_utils::GetRepeatedNodeAttributeValues(v_transpose, "perm", perm) && perm.size() == 4 && perm[0] == 0 && perm[1] == 2 && perm[2] == 1 && perm[3] == 3)) {
DEBUG_LOG("Failed in match v_transpose attribute perm. Expected: 0, 2, 1, 3");
return false;
}
std::vector<int64_t> v_reshape_shape;
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(v_reshape.InputDefs()[1]), v_reshape_shape) ||
v_reshape_shape.size() != 4 ||
v_reshape_shape[2] <= 0 ||
v_reshape_shape[3] <= 0 ||
hidden_size != v_reshape_shape[2] * v_reshape_shape[3]) {
DEBUG_LOG("v_reshape initializer value is not expected");
return false;
}
const int64_t num_attention_head = v_reshape_shape[2];
const int64_t attention_head_size = v_reshape_shape[3];
std::vector<int64_t> reshape_shape;
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(reshape.InputDefs()[1]), reshape_shape) ||
reshape_shape.size() != 3 ||
reshape_shape[2] != hidden_size) {
DEBUG_LOG("reshape initializer value is not expected");
int64_t num_heads = 0; // will be updated in CheckNodesInPathV
int64_t head_size = 0; // will be updated in CheckNodesInPathV
if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose, qkv_matmul, v_transpose, v_reshape, num_heads, head_size, hidden_size, logger)) {
DEBUG_LOG("CheckNodesInPathV return false");
return false;
}
@ -465,86 +413,11 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
return false;
}
// path 2 to find mask. Unsqueeze -> Unsqueeze -> (Cast) -> Sub -> Mul -> Add -> Softmax
// The "Cast" node in parentheses is optional.
std::vector<graph_utils::EdgeEndToMatch> mask_path{
{0, 0, "Softmax", {1, 11}, kOnnxDomain},
{0, 0, "Add", {7}, kOnnxDomain},
{0, 1, "Mul", {7}, kOnnxDomain},
{0, 0, "Sub", {7}, kOnnxDomain}};
if (!graph_utils::FindPath(qkv_matmul, true, mask_path, edges, logger)) {
DEBUG_LOG("Failed to find path for mask");
return false;
}
const Node& softmax = edges[0]->GetNode();
const Node& mask_add = edges[1]->GetNode();
const Node& mask_mul = edges[2]->GetNode();
const Node& mask_sub = edges[3]->GetNode();
// Match optional mask cast node
Node* p_mask_cast = nullptr;
Node* p_mask_unsqueeze_2 = nullptr;
Node* p_mask_unsqueeze_1 = nullptr;
std::vector<graph_utils::EdgeEndToMatch> mask_path_format_1{
{0, 1, "Cast", {9}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}};
std::vector<graph_utils::EdgeEndToMatch> mask_path_format_2{
{0, 1, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}};
if (graph_utils::FindPath(mask_sub, true, mask_path_format_1, edges, logger)) {
p_mask_cast = const_cast<Node*>(&edges[0]->GetNode());
p_mask_unsqueeze_2 = const_cast<Node*>(&edges[1]->GetNode());
p_mask_unsqueeze_1 = const_cast<Node*>(&edges[2]->GetNode());
} else if (graph_utils::FindPath(mask_sub, true, mask_path_format_2, edges, logger)) {
p_mask_unsqueeze_2 = const_cast<Node*>(&edges[0]->GetNode());
p_mask_unsqueeze_1 = const_cast<Node*>(&edges[1]->GetNode());
} else {
DEBUG_LOG("Failed to find path for mask");
return false;
}
const Node& mask_unsqueeze_2 = *p_mask_unsqueeze_2;
const Node& mask_unsqueeze_1 = *p_mask_unsqueeze_1;
if (!optimizer_utils::CheckOutputEdges(graph, softmax, 1) ||
!optimizer_utils::CheckOutputEdges(graph, mask_add, 1) ||
!optimizer_utils::CheckOutputEdges(graph, mask_sub, 1) ||
(p_mask_cast != nullptr && !optimizer_utils::CheckOutputEdges(graph, *p_mask_cast, 1)) ||
!optimizer_utils::CheckOutputEdges(graph, mask_unsqueeze_2, 1) ||
!optimizer_utils::CheckOutputEdges(graph, mask_unsqueeze_1, 1)) {
DEBUG_LOG("Output edge count not expected for mask nodes");
return false;
}
if (!optimizer_utils::IsAttributeWithExpectedValue(softmax, "axis", 3)) {
DEBUG_LOG("Softmax attribute axis is expected to be 3");
return false;
}
std::vector<int64_t> axes;
if (!(graph_utils::GetRepeatedNodeAttributeValues(mask_unsqueeze_1, "axes", axes) && axes.size() == 1 && axes[0] == 1)) {
DEBUG_LOG("mask_unsqueeze_1 axes not matched. Expect: 1");
return false;
}
if (!(graph_utils::GetRepeatedNodeAttributeValues(mask_unsqueeze_2, "axes", axes) && axes.size() == 1 && axes[0] == 2)) {
DEBUG_LOG("mask_unsqueeze_2 axes not matched. Expect: 2");
return false;
}
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mask_sub.InputDefs()[0]), float(1), false)) {
DEBUG_LOG("mask_sub const input not matched");
return false;
}
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mask_mul.InputDefs()[1]), float(-10000), false)) {
DEBUG_LOG("mask_mul const input not matched");
// Find mask nodes: Unsqueeze -> Unsqueeze -> (Cast) -> Sub -> Mul -> Add -> Softmax --> [MatMul]
// The "Cast" node in parentheses is optional.
AttentionFusionHelper::AttentionMaskNodes mask_nodes;
if (!AttentionFusionHelper::MatchInputMaskSubgraph(graph, qkv_matmul, mask_nodes, logger)) {
DEBUG_LOG("Failed in match input mask subgraph");
return false;
}
@ -558,7 +431,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
{0, 0, "MatMul", {1, 9}, kOnnxDomain},
{0, 0, "LayerNormalization", {1}, kOnnxDomain}};
if (!graph_utils::FindPath(mask_add, true, q_path, edges, logger)) {
if (!graph_utils::FindPath(*(mask_nodes.add), true, q_path, edges, logger)) {
DEBUG_LOG("Failed to find path for q");
return false;
}
@ -575,23 +448,8 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
return false;
}
std::vector<int64_t> q_reshape_shape;
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(q_reshape.InputDefs()[1]), q_reshape_shape) ||
q_reshape_shape.size() != 4 ||
q_reshape_shape[2] != num_attention_head ||
q_reshape_shape[3] != attention_head_size) {
DEBUG_LOG("q_reshape const not matched");
return false;
}
float expected_value = std::sqrt(static_cast<float>(attention_head_size));
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(qk_div.InputDefs()[1]), expected_value, false)) {
DEBUG_LOG("qk_div const not matched.");
return false;
}
if (!(graph_utils::GetRepeatedNodeAttributeValues(q_transpose, "perm", perm) && perm.size() == 4 && perm[0] == 0 && perm[1] == 2 && perm[2] == 1 && perm[3] == 3)) {
DEBUG_LOG("q_transpose perm attribute not matched");
if (!AttentionFusionHelper::CheckNodesInPathQ(graph, qk_div, q_reshape, q_transpose, num_heads, head_size, logger)) {
DEBUG_LOG("CheckNodesInPathQ returns false");
return false;
}
@ -624,8 +482,8 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
return false;
}
if (!(graph_utils::GetRepeatedNodeAttributeValues(k_transpose, "perm", perm) && perm.size() == 4 && perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1)) {
DEBUG_LOG("k_transpose perm attribute not matched");
if (!AttentionFusionHelper::CheckNodesInPathK(graph, k_reshape, k_transpose, num_heads, head_size, logger)) {
DEBUG_LOG("CheckNodesInPathK returns false");
return false;
}
@ -635,15 +493,6 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
return false;
}
std::vector<int64_t> k_reshape_shape;
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(k_reshape.InputDefs()[1]), k_reshape_shape) ||
k_reshape_shape.size() != 4 ||
k_reshape_shape[2] != num_attention_head ||
k_reshape_shape[3] != attention_head_size) {
DEBUG_LOG("k_reshape const not matched");
return false;
}
// Load q, k and v weights
const ONNX_NAMESPACE::TensorProto* q_weight_tensor = nullptr;
const ONNX_NAMESPACE::TensorProto* k_weight_tensor = nullptr;
@ -662,7 +511,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
}
// Now everything is ready, we will start fusing subgraph.
NodeArg* mask_input = graph.GetNode(mask_unsqueeze_1.Index())->MutableInputDefs()[0];
NodeArg* mask_input = graph.GetNode(mask_nodes.unsqueeze_1->Index())->MutableInputDefs()[0];
NodeArg* mask_index = GetOrCreateMaskIndex(graph, mask_input, mask_index_map, layer_norm.GetExecutionProviderType(), logger);
if (nullptr == mask_index) {
DEBUG_LOG("Failed to create mask index");
@ -684,7 +533,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
output_defs,
nullptr,
kMSDomain);
attention_node.AddAttribute("num_heads", num_attention_head);
attention_node.AddAttribute("num_heads", num_heads);
// Assign provider to this new node.
attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType());
@ -698,8 +547,6 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
v_reshape.Index(),
v_add.Index(),
v_matmul.Index(),
softmax.Index(),
mask_add.Index(),
qk_div.Index(),
qk_matmul.Index(),
q_transpose.Index(),
@ -711,16 +558,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
k_add.Index(),
k_matmul.Index()};
// When the last Attention node is fused. Original mask processing nodes can be removed safely.
if (optimizer_utils::CheckOutputEdges(graph, mask_mul, 1)) {
nodes_to_remove.push_back(mask_mul.Index());
nodes_to_remove.push_back(mask_sub.Index());
if (p_mask_cast != nullptr) {
nodes_to_remove.push_back((*p_mask_cast).Index());
}
nodes_to_remove.push_back(mask_unsqueeze_2.Index());
nodes_to_remove.push_back(mask_unsqueeze_1.Index());
}
AttentionFusionHelper::SetMaskNodesToRemove(graph, mask_nodes, nodes_to_remove);
for (const auto& node_index : nodes_to_remove) {
Node* node = graph.GetNode(node_index);

File diff suppressed because it is too large Load diff

View file

@ -1718,6 +1718,24 @@ TEST_F(GraphTransformationTests, AttentionFusionFloat32Test) {
ValidateAttention(graph);
}
// GPT-2 attention fusion with past state and attention mask (float32 mask per the original note).
// Loads a one-layer GPT-2 model, runs only the AttentionFusion transformer at Level2 and
// checks that the Transpose and Softmax nodes are gone and exactly one Attention node remains.
TEST_F(GraphTransformationTests, AttentionFusionGPTWithPastAndMaskTest) {
  auto model_uri = MODEL_FOLDER "fusion/gpt2_past_mask_one_layer.onnx";

  // Load the model and grab its main graph.
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  // Apply the fusion through a transformer manager constructed with 5.
  onnxruntime::GraphTransformerManager transformer_mgr{5};
  transformer_mgr.Register(onnxruntime::make_unique<AttentionFusion>(), TransformerLevel::Level2);
  auto status = transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
  ASSERT_TRUE(status.IsOK());

  // The fused graph must contain a single Attention node and none of the ops it replaced.
  std::map<std::string, int> op_counts = CountOpsInGraph(graph);
  EXPECT_EQ(op_counts["Transpose"], 0);
  EXPECT_EQ(op_counts["Softmax"], 0);
  EXPECT_EQ(op_counts["Attention"], 1);
}
TEST_F(GraphTransformationTests, GeluFusionTest) {
auto model_uri = MODEL_FOLDER "fusion/gelu.onnx";
std::shared_ptr<Model> p_model;