mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Add Bert/GPT2 fusion change for new attribute mask_filter_value in ORT optimizer (#14333)
### Description <!-- Describe your changes. --> The changes correspond to specify the mask_filter_value in attention attribute. However, the ORT optimizer cannot fuse SkipLayerNorm/Attention/EmbedLayerNorm with the most recent transformers. So this PR may only address this issue with some older version of onnx models(e.g the one used in the unittest) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
parent
ae0e090c7b
commit
d2c3d8eb38
4 changed files with 54 additions and 13 deletions
|
|
@ -268,6 +268,7 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
|
|||
int64_t hidden_size,
|
||||
int64_t num_heads,
|
||||
int64_t head_size,
|
||||
const float mask_filter_value,
|
||||
const logging::Logger& logger) {
|
||||
InlinedVector<std::reference_wrapper<const Node>> pivot_nodes;
|
||||
if (edges.size() == 2) {
|
||||
|
|
@ -386,6 +387,7 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
|
|||
nullptr,
|
||||
kMSDomain);
|
||||
attention_node.AddAttribute("num_heads", num_heads);
|
||||
attention_node.AddAttribute("mask_filter_value", mask_filter_value);
|
||||
|
||||
// Assign provider to this new node.
|
||||
attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType());
|
||||
|
|
@ -437,7 +439,7 @@ static bool FuseSubGraphQK(Node& layer_norm,
|
|||
|
||||
std::vector<NodeIndex> nodes_to_remove;
|
||||
if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
|
||||
num_heads, head_size, logger)) {
|
||||
num_heads, head_size, mask_nodes.mask_filter_value, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -528,7 +530,7 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm,
|
|||
|
||||
std::vector<NodeIndex> nodes_to_remove;
|
||||
if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
|
||||
num_heads, head_size, logger)) {
|
||||
num_heads, head_size, mask_nodes.mask_filter_value, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -580,7 +582,7 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm,
|
|||
| qk_MatMul | |
|
||||
| | [B=2] | ([A=1.0] mask_Cast(to=1))
|
||||
| | / | \ /
|
||||
| qk_Div | mask_Sub [B=-10000.0]
|
||||
| qk_Div | mask_Sub [B=-10000.0 or value of mask_filter_value]
|
||||
| \ | \ /
|
||||
| mask_Add <-------- /---------------------mask_Mul
|
||||
| | /
|
||||
|
|
|
|||
|
|
@ -265,7 +265,8 @@ bool ValidateGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidde
|
|||
|
||||
struct MatchUnidirMaskResult {
|
||||
const Node* div_node = nullptr; // the root node (Div) of the subgraph
|
||||
bool is_unidirectional = false; // whether the mask is unidirectional.
|
||||
bool is_unidirectional = false; // whether the mask is unidirectional.
|
||||
float mask_filter_value = -10000.0f; // the value to filter out the mask.
|
||||
std::vector<NodeIndex> node_indices; // id of all nodes in the subgraph for removing later.
|
||||
};
|
||||
|
||||
|
|
@ -375,7 +376,7 @@ bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidir
|
|||
| (*, -2, -1, 0) (axes=0) | Cast(9)
|
||||
+----> Shape --> Slice ---------> Squeeze-------+ |
|
||||
| :shape2 :slice2 :squeeze2 v condition
|
||||
+----------------------------------------------------------------------------------------->Where( ,*,-10000)--->[Add]
|
||||
+----------------------------------------------------------------------------------------->Where( ,*,-10000 or value of mask_filter_value)--->[Add]
|
||||
|
||||
When use_shared_node is true, shape1 and shape2 is one node, and also unsqueeze2 and unsqueeze3 is same.
|
||||
*/
|
||||
|
|
@ -394,8 +395,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
|
|||
const Node& where_node = edges[0]->GetNode();
|
||||
const Node& div_node = edges[1]->GetNode();
|
||||
|
||||
constexpr float expected_value = -10000.0f;
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(where_node.InputDefs()[2]), expected_value, true)) {
|
||||
if (!optimizer_utils::GetScalarInitializerValue<float>(graph, *(where_node.InputDefs()[2]), result.mask_filter_value, true)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -550,6 +550,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid
|
|||
struct AttentionMaskNodes {
|
||||
const Node* softmax;
|
||||
bool has_input_mask; // When it is false, the following nodes will be NULL.
|
||||
float mask_filter_value = -10000.0f;
|
||||
|
||||
const Node* add;
|
||||
const Node* mul;
|
||||
|
|
@ -566,6 +567,7 @@ struct AttentionMaskNodesDistilBert {
|
|||
const Node* reshape;
|
||||
const Node* equal;
|
||||
const Node* shape;
|
||||
float mask_filter_value = -10000.0f;
|
||||
};
|
||||
|
||||
void SetMaskNodesToRemove(const Graph& graph, AttentionMaskNodes& mask_nodes, std::vector<NodeIndex>& nodes_to_remove) {
|
||||
|
|
@ -598,10 +600,10 @@ void SetMaskNodesToRemove(const Graph&, AttentionMaskNodesDistilBert& mask_nodes
|
|||
}
|
||||
|
||||
/** Match Input Mask subgraph:
|
||||
{UnidirMask Subgraph}
|
||||
|
|
||||
(optional) v
|
||||
[Attention_mask] --> Unsqueeze (axes=1) --> Unsqueeze (axes=2) --> Cast ---->Sub(1,*) --> Mul(*, -10000.0) --> Add( ,*)--->SoftMax -->[MatMul]
|
||||
{ UnidirMask Subgraph}
|
||||
|
|
||||
(optional) v
|
||||
[Attention_mask] --> Unsqueeze (axes=1) --> Unsqueeze (axes=2) --> Cast ---->Sub(1,*) --> Mul(*, -10000.0 or value of mask_filter_value) --> Add( ,*)--->SoftMax -->[MatMul]
|
||||
|
||||
When is_input_mask_optional is true, this function also matches the following subgraph:
|
||||
{UnidirMask Subgraph [Where]} --> Softmax --> [MatMul]
|
||||
|
|
@ -710,7 +712,7 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& qkv_matmul, Attentio
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mask_mul.InputDefs()[1]), float(-10000), false)) {
|
||||
if (!optimizer_utils::GetScalarInitializerValue(graph, *(mask_mul.InputDefs()[1]), result.mask_filter_value, false)) {
|
||||
DEBUG_LOG("mask_mul const input not matched");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1461,7 +1463,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
|
|||
opt_k_transpose = k_concat;
|
||||
InlinedVector<int64_t> perm;
|
||||
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(*opt_k_transpose, "perm", perm)
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(*opt_k_transpose, "perm", perm)
|
||||
&& perm.size() == 4 && perm[0] == 0 && perm[1] == 1 && perm[2] == 3 && perm[3] == 2)) {
|
||||
DEBUG_LOG("opt_k_transpose perm attribute not matched");
|
||||
return false;
|
||||
|
|
@ -1541,6 +1543,11 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
|
|||
kMSDomain);
|
||||
attention_node.AddAttribute("num_heads", num_heads);
|
||||
attention_node.AddAttribute("unidirectional", static_cast<int64_t>(unidir_mask_result.is_unidirectional));
|
||||
if (mask_nodes.mask_filter_value != -10000.0f or unidir_mask_result.mask_filter_value != -10000.0f) {
|
||||
float mask_filter_value = mask_nodes.mask_filter_value != -10000.0f ?
|
||||
mask_nodes.mask_filter_value : unidir_mask_result.mask_filter_value;
|
||||
attention_node.AddAttribute("mask_filter_value", mask_filter_value);
|
||||
}
|
||||
|
||||
// Assign provider to this new node.
|
||||
attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType());
|
||||
|
|
|
|||
|
|
@ -363,6 +363,34 @@ bool IsScalar(const NodeArg& input_arg) {
|
|||
return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, T& value,
|
||||
bool is_constant) {
|
||||
if (!IsScalar(input_arg)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
|
||||
if (is_constant) {
|
||||
tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
} else if (!graph.GetInitializedTensor(input_arg.Name(), tensor_proto)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor_proto == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
const T* val = init_const.data<T>();
|
||||
value = *val;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template bool GetScalarInitializerValue<float>(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, float& value,
|
||||
bool is_constant);
|
||||
|
||||
#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
} // namespace optimizer_utils
|
||||
|
|
|
|||
|
|
@ -103,6 +103,10 @@ bool IsSupportedDataType(const Node& node, const T& supported_data_types) {
|
|||
|
||||
bool IsOperationDeterministic(const std::string& domain, const std::string& op);
|
||||
|
||||
template <typename T>
|
||||
bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, T& value,
|
||||
bool is_constant);
|
||||
|
||||
#endif // !#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
|||
Loading…
Reference in a new issue