Add Bert/GPT2 fusion change for new attribute mask_filter_value in ORT optimizer (#14333)

### Description
<!-- Describe your changes. -->

The changes correspond to specify the mask_filter_value in attention
attribute. However, the ORT optimizer cannot fuse
SkipLayerNorm/Attention/EmbedLayerNorm with the most recent
transformers. So this PR may only address this issue with some older
version of onnx models(e.g the one used in the unittest)

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2023-01-19 12:52:09 -08:00 committed by GitHub
parent ae0e090c7b
commit d2c3d8eb38
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 13 deletions

View file

@ -268,6 +268,7 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
int64_t hidden_size,
int64_t num_heads,
int64_t head_size,
const float mask_filter_value,
const logging::Logger& logger) {
InlinedVector<std::reference_wrapper<const Node>> pivot_nodes;
if (edges.size() == 2) {
@ -386,6 +387,7 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
nullptr,
kMSDomain);
attention_node.AddAttribute("num_heads", num_heads);
attention_node.AddAttribute("mask_filter_value", mask_filter_value);
// Assign provider to this new node.
attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType());
@ -437,7 +439,7 @@ static bool FuseSubGraphQK(Node& layer_norm,
std::vector<NodeIndex> nodes_to_remove;
if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
num_heads, head_size, logger)) {
num_heads, head_size, mask_nodes.mask_filter_value, logger)) {
return false;
}
@ -528,7 +530,7 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm,
std::vector<NodeIndex> nodes_to_remove;
if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
num_heads, head_size, logger)) {
num_heads, head_size, mask_nodes.mask_filter_value, logger)) {
return false;
}
@ -580,7 +582,7 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm,
| qk_MatMul | |
| | [B=2] | ([A=1.0] mask_Cast(to=1))
| | / | \ /
| qk_Div | mask_Sub [B=-10000.0]
| qk_Div | mask_Sub [B=-10000.0 or value of mask_filter_value]
| \ | \ /
| mask_Add <-------- /---------------------mask_Mul
| | /

View file

@ -265,7 +265,8 @@ bool ValidateGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidde
struct MatchUnidirMaskResult {
const Node* div_node = nullptr; // the root node (Div) of the subgraph
bool is_unidirectional = false; // whether the mask is unidirectional.
bool is_unidirectional = false; // whether the mask is unidirectional.
float mask_filter_value = -10000.0f; // the value to filter out the mask.
std::vector<NodeIndex> node_indices; // id of all nodes in the subgraph for removing later.
};
@ -375,7 +376,7 @@ bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidir
| (*, -2, -1, 0) (axes=0) | Cast(9)
+----> Shape --> Slice ---------> Squeeze-------+ |
| :shape2 :slice2 :squeeze2 v condition
+----------------------------------------------------------------------------------------->Where( ,*,-10000)--->[Add]
+----------------------------------------------------------------------------------------->Where( ,*,-10000 or value of mask_filter_value)--->[Add]
When use_shared_node is true, shape1 and shape2 is one node, and also unsqueeze2 and unsqueeze3 is same.
*/
@ -394,8 +395,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
const Node& where_node = edges[0]->GetNode();
const Node& div_node = edges[1]->GetNode();
constexpr float expected_value = -10000.0f;
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(where_node.InputDefs()[2]), expected_value, true)) {
if (!optimizer_utils::GetScalarInitializerValue<float>(graph, *(where_node.InputDefs()[2]), result.mask_filter_value, true)) {
return false;
}
@ -550,6 +550,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
struct AttentionMaskNodes {
const Node* softmax;
bool has_input_mask; // When it is false, the following nodes will be NULL.
float mask_filter_value = -10000.0f;
const Node* add;
const Node* mul;
@ -566,6 +567,7 @@ struct AttentionMaskNodesDistilBert {
const Node* reshape;
const Node* equal;
const Node* shape;
float mask_filter_value = -10000.0f;
};
void SetMaskNodesToRemove(const Graph& graph, AttentionMaskNodes& mask_nodes, std::vector<NodeIndex>& nodes_to_remove) {
@ -598,10 +600,10 @@ void SetMaskNodesToRemove(const Graph&, AttentionMaskNodesDistilBert& mask_nodes
}
/** Match Input Mask subgraph:
{UnidirMask Subgraph}
|
(optional) v
[Attention_mask] --> Unsqueeze (axes=1) --> Unsqueeze (axes=2) --> Cast ---->Sub(1,*) --> Mul(*, -10000.0) --> Add( ,*)--->SoftMax -->[MatMul]
{ UnidirMask Subgraph}
|
(optional) v
[Attention_mask] --> Unsqueeze (axes=1) --> Unsqueeze (axes=2) --> Cast ---->Sub(1,*) --> Mul(*, -10000.0 or value of mask_filter_value) --> Add( ,*)--->SoftMax -->[MatMul]
When is_input_mask_optional is true, this function also matches the following subgraph:
{UnidirMask Subgraph [Where]} --> Softmax --> [MatMul]
@ -710,7 +712,7 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& qkv_matmul, Attentio
return false;
}
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mask_mul.InputDefs()[1]), float(-10000), false)) {
if (!optimizer_utils::GetScalarInitializerValue(graph, *(mask_mul.InputDefs()[1]), result.mask_filter_value, false)) {
DEBUG_LOG("mask_mul const input not matched");
return false;
}
@ -1461,7 +1463,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
opt_k_transpose = k_concat;
InlinedVector<int64_t> perm;
if (!(graph_utils::GetRepeatedNodeAttributeValues(*opt_k_transpose, "perm", perm)
if (!(graph_utils::GetRepeatedNodeAttributeValues(*opt_k_transpose, "perm", perm)
&& perm.size() == 4 && perm[0] == 0 && perm[1] == 1 && perm[2] == 3 && perm[3] == 2)) {
DEBUG_LOG("opt_k_transpose perm attribute not matched");
return false;
@ -1541,6 +1543,11 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
kMSDomain);
attention_node.AddAttribute("num_heads", num_heads);
attention_node.AddAttribute("unidirectional", static_cast<int64_t>(unidir_mask_result.is_unidirectional));
if (mask_nodes.mask_filter_value != -10000.0f or unidir_mask_result.mask_filter_value != -10000.0f) {
float mask_filter_value = mask_nodes.mask_filter_value != -10000.0f ?
mask_nodes.mask_filter_value : unidir_mask_result.mask_filter_value;
attention_node.AddAttribute("mask_filter_value", mask_filter_value);
}
// Assign provider to this new node.
attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType());

View file

@ -363,6 +363,34 @@ bool IsScalar(const NodeArg& input_arg) {
return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1);
}
template <typename T>
bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, T& value,
bool is_constant) {
if (!IsScalar(input_arg)) {
return false;
}
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
if (is_constant) {
tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
} else if (!graph.GetInitializedTensor(input_arg.Name(), tensor_proto)) {
return false;
}
if (tensor_proto == nullptr) {
return false;
}
Initializer init_const{*tensor_proto, graph.ModelPath()};
const T* val = init_const.data<T>();
value = *val;
return true;
}
template bool GetScalarInitializerValue<float>(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, float& value,
bool is_constant);
#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
} // namespace optimizer_utils

View file

@ -103,6 +103,10 @@ bool IsSupportedDataType(const Node& node, const T& supported_data_types) {
bool IsOperationDeterministic(const std::string& domain, const std::string& op);
template <typename T>
bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, T& value,
bool is_constant);
#endif // !#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)