From 7033346605e0e6bf619ac6049356d478f97e1237 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Thu, 23 Mar 2023 11:00:09 -0700 Subject: [PATCH] Support mask_filter_value attribute in DecoderMaskedMultiheadAttention (#15158) --- docs/ContribOperators.md | 2 ++ onnxruntime/core/graph/contrib_ops/bert_defs.cc | 6 +++++- onnxruntime/python/tools/transformers/convert_generation.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index a62d757aa9..656f0e86d2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1126,6 +1126,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
mask_filter_value : float
+
The value to be filled in the attention mask. Default value is -10000.0f
num_heads : int (required)
Number of attention heads
past_present_share_buffer : int
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index ae9e0c1324..b205b64954 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -473,6 +473,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("mask_filter_value", + "The value to be filled in the attention mask. Default value is -10000.0f", + AttributeProto::FLOAT, + OPTIONAL_VALUE) .Input(0, "input", "Input tensor with shape (batch_size, 1, input_hidden_size)", @@ -571,7 +575,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(1, "key", "Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size), " - "or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size)", + "or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size)", "T", OpSchema::Optional) .Input(2, diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index a7ea3b4e05..22690dc18e 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1106,6 +1106,7 @@ def update_decoder_subgraph_use_decoder_masked_multihead_attention( "past_present_share_buffer", "num_heads", "scale", + "mask_filter_value", "domain", ]