Support mask_filter_value attribute in DecoderMaskedMultiheadAttention (#15158)

This commit is contained in:
Hariharan Seshadri 2023-03-23 11:00:09 -07:00 committed by GitHub
parent 910fc09de2
commit 7033346605
3 changed files with 8 additions and 1 deletions

View file

@ -1126,6 +1126,8 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
<dl>
<dt><tt>mask_filter_value</tt> : float</dt>
<dd>The value to be filled in the attention mask. Default value is -10000.0f</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
<dt><tt>past_present_share_buffer</tt> : int</dt>

View file

@ -473,6 +473,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr("mask_filter_value",
"The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Input(0,
"input",
"Input tensor with shape (batch_size, 1, input_hidden_size)",
@ -571,7 +575,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(1,
"key",
"Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size), "
"or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size)",
"or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Input(2,

View file

@ -1106,6 +1106,7 @@ def update_decoder_subgraph_use_decoder_masked_multihead_attention(
"past_present_share_buffer",
"num_heads",
"scale",
"mask_filter_value",
"domain",
]