diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 07e892cabd..beb7029895 100755
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1134,7 +1134,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-#### Inputs (1 - 10)
+#### Inputs (1 - 11)
- query : T
@@ -1157,6 +1157,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.
- cache_indirection (optional) : M
- A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration
+- bias (optional) : T
+- Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
#### Outputs (1 - 3)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 71a6a3afeb..c23209d4fe 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -802,7 +802,7 @@ Do not modify directly.*
|ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
-|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
+|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)|
|DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
index 804cf2cf10..52b1efdeb2 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -62,6 +62,14 @@ struct BeamSearchState : IBeamSearchState {
}
}
+ void EnsurePastStateReorderStagingBuffer(AllocatorPtr allocator, int64_t sz) {
+ auto current_buffer_size = this->staging_for_past_state_reorder.Shape().Size();
+ if (sz > current_buffer_size) {
+ TensorShape buffer_shape = {sz};
+ this->staging_for_past_state_reorder = Tensor(DataTypeImpl::GetType<T>(), buffer_shape, allocator);
+ }
+ }
+
private:
BufferUniquePtr next_token_logits_buffer_;
BufferUniquePtr next_token_scores_buffer_;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index 5c4b7e2753..6385511785 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -254,6 +254,11 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches
if (decoder_subgraph_.has_decoder_masked_attention_) {
size_t offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex());
+ // Need to check the cross attention's past key tensor size; assume all layers' cross attention keys have the same size
+ auto first_cross_attention_key = decoder_feeds[offset + 2 * static_cast<size_t>(decoder_subgraph_.num_layers)].GetMutable<Tensor>();
+ auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
+ beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);
+
// Here we only need to reorder the past key for self-attention and cross-attention.
for (size_t i = 0; i < 2 * static_cast(decoder_subgraph_.num_layers); ++i) {
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
index a5178422c0..5a7c154d31 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
@@ -245,6 +245,11 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe
if (decoder_subgraph_.has_decoder_masked_attention_) {
size_t offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex());
+ // Need to check the cross attention's past key tensor size; assume all layers' cross attention keys have the same size
+ auto first_cross_attention_key = decoder_feeds[offset + 2 * static_cast<size_t>(decoder_subgraph_.num_layers)].GetMutable<Tensor>();
+ auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
+ beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);
+
// Here we only need to reorder the past key for self-attention and cross-attention.
for (size_t i = 0; i < 2 * static_cast(decoder_subgraph_.num_layers); ++i) {
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc
index 37f3e5a43d..887a6a8984 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc
@@ -106,9 +106,9 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 0 (input_ids) shall have int32 type");
- auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
+ auto float_type = subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
- "decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
+ "decoder subgraph input 1 (encoder_hidden_states) shall have float or float16 type");
for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) {
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
@@ -202,6 +202,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds(
// When first_past_input_index_ == 2, the encoder_hidden_states and past states are copied from the second output
// of encoder.
// When first_past_input_index_ == 1, the past states are copied from the second output of encoder.
+ // TODO: Make this first_past_input_index_-based start-offset logic more readable
for (size_t j = static_cast(3) - first_past_input_index_; j < encoder_fetches.size(); j++) {
if (j == 1) {
ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
@@ -226,7 +227,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds(
decoder_feeds.push_back(expanded_hidden_states);
} else {
// past key/value for cross attention does not need to be initialized with max_seq_len since they are static.
- bool use_max_seq_len = (j - first_past_input_index_) < 2 * static_cast(num_layers);
+ bool use_max_seq_len = (j - first_past_input_index_) <= 2 * static_cast(num_layers);
OrtValue expanded_cache;
if (is_output_float16_) {
@@ -252,7 +253,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds(
if (past_present_share_buffer_) {
// Past sequence length feed
- ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 1));
+ ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, cur_len - 1));
// Add beam search specific inputs
if (need_cache_indir) {
const int64_t batch_size = static_cast(batch_beam_size / num_beam);
diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc
index d5053e1d85..4bdc6db30b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc
@@ -22,6 +22,7 @@ static constexpr int kBeamWidthInputIndex = 8;
static constexpr int kCacheIndirectionInputIndex = 9;
static constexpr int kPastInputIndex = 5;
static constexpr int kPresentOutputIndex = 1;
+static constexpr int kBiasIndex = 10;
#define REGISTER_KERNEL_TYPED(T1, T2) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
@@ -63,6 +64,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext*
const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex);
const Tensor* beam_width = context->Input(kBeamWidthInputIndex);
const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex);
+ const Tensor* bias = context->Input(kBiasIndex);
auto& device_prop = GetDeviceProp();
DecoderMaskedMultiHeadAttentionParams parameters;
@@ -70,7 +72,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext*
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query,
key,
value,
- nullptr, // bias
+ bias,
mask_index,
relative_position_bias,
past_key,
@@ -84,6 +86,13 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext*
is_dmmha_packing, // dmmha_packing
device_prop.maxThreadsPerBlock));
+ if (bias) {
+ const T1* bias_data = bias->Data<T1>();
+ parameters.q_bias = const_cast<T1*>(bias_data);
+ parameters.k_bias = const_cast<T1*>(bias_data + parameters.hidden_size);
+ parameters.v_bias = const_cast<T1*>(bias_data + 2LL * parameters.hidden_size);
+ }
+
int batch_size = parameters.batch_size;
int sequence_length = parameters.sequence_length;
@@ -142,6 +151,9 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext*
// parameters.k and paraneters.v are nullptr
parameters.k_cache = const_cast(key->Data());
parameters.v_cache = const_cast(value->Data());
+ parameters.k_bias = nullptr;
+ parameters.v_bias = nullptr;
+
} else {
// Sanity check
ORT_ENFORCE(past_present_share_buffer_);
diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
index 5b63621a59..f1ddf13f8b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
@@ -161,40 +161,19 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
q = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.q)[qk_offset]));
}
- Qk_vec_k k;
-
- if (!params.is_cross_attention) {
- zero(k);
-
- if (!is_masked) {
- k = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]));
- }
- }
+ // The offset in the bias buffer.
+ int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE;
// Trigger the loads from the Q and K bias buffers.
- Qk_vec_k q_bias;
- Qk_vec_k k_bias;
- if (!params.is_mha) {
- // The offset in the bias buffer.
- int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE;
+ if (params.q_bias && !is_masked) {
+ Qk_vec_k q_bias;
- zero(q_bias);
+ q_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.q_bias)[qk_bias_offset]));
- if (!is_masked) {
- q_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.q_bias)[qk_bias_offset]));
- }
-
- zero(k_bias);
-
- if (!is_masked) {
- k_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k_bias)[qk_bias_offset]));
- }
-
- // Computes the Q/K values with bias.
q = add_vec(q, q_bias);
- k = add_vec(k, k_bias);
}
+
T* params_k_cache = reinterpret_cast(params.k_cache);
const float inv_sqrt_dh = params.scale;
@@ -205,6 +184,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
}
if (!params.is_cross_attention) {
+ Qk_vec_k k;
+
+ zero(k);
+
+ if (!is_masked) {
+ k = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]));
+
+ if (params.k_bias) {
+ Qk_vec_k k_bias;
+
+ k_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k_bias)[qk_bias_offset]));
+
+ k = add_vec(k, k_bias);
+ }
+ }
+
if (!is_masked) {
// Write the K values to the global memory cache.
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
@@ -438,7 +433,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
// One group of threads computes the product(s) for the current timestep.
V_vec_k v_bias;
- if (!params.is_mha) {
+ if (params.v_bias && !params.is_cross_attention) {
zero(v_bias);
T* params_v_bias = reinterpret_cast(params.v_bias);
@@ -478,7 +473,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
V_vec_k v;
v = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.v)[v_offset]));
- if (!params.is_mha) {
+ if (params.v_bias) {
v = add_vec(v, v_bias);
}
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 8d9e560d4e..7895eff00f 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -676,6 +676,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration",
"M",
OpSchema::Optional)
+ .Input(10,
+ "bias",
+ "Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection",
+ "T",
+ OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, v_hidden_size)",
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 22f82baf5c..74d525324f 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -1301,7 +1301,9 @@ Graph::Graph(const Model& owning_model,
for (auto& node_arg : graph_proto_->value_info()) {
if (utils::HasName(node_arg) && utils::HasType(node_arg)) {
- name_to_type_map[node_arg.name()] = node_arg.type();
+ if (node_arg.name().size() > 0) {
+ name_to_type_map[node_arg.name()] = node_arg.type();
+ }
}
}
diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index d388c42218..d1c14de9ee 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -195,6 +195,7 @@ class SymbolicShapeInference:
"RestorePadding": self._infer_RestorePadding,
"BiasGelu": self._infer_BiasGelu,
"MultiHeadAttention": self._infer_MultiHeadAttention,
+ "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
"EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
"FastGelu": self._infer_FastGelu,
"Gelu": self._infer_Gelu,
@@ -2233,10 +2234,33 @@ class SymbolicShapeInference:
present_shape = [batch_size, num_heads, total_sequence_length, head_size]
assert output_dtype is not None
- vi = self.known_vi_[node.output[1]]
- vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
- vi = self.known_vi_[node.output[2]]
- vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+ if len(node.output) > 2 and node.output[1] and node.output[2]:
+ vi = self.known_vi_[node.output[1]]
+ vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+ vi = self.known_vi_[node.output[2]]
+ vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+
+ def _infer_DecoderMaskedMultiHeadAttention(self, node): # noqa: N802
+ # Output 0 has shape (batch_size, 1, v_hidden_size)
+ # Q, K and V without packing:
+ # Input 0 (query) has shape (batch_size, 1, hidden_size)
+ # Input 5 (past_key) if exists has shape (batch_size, num_heads, max_sequence_length, head_size)
+
+ query_shape = self._get_shape(node, 0)
+ if query_shape is not None:
+ output_shape = query_shape
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+ assert output_dtype is not None
+ vi = self.known_vi_[node.output[0]]
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+ if len(node.output) > 2 and node.output[1] and node.output[2]:
+ past_shape = self._try_get_shape(node, 5)
+ if past_shape is not None:
+ vi = self.known_vi_[node.output[1]]
+ vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+ vi = self.known_vi_[node.output[2]]
+ vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
def _infer_FastGelu(self, node): # noqa: N802
self._propagate_shape_and_type(node)
@@ -2659,10 +2683,16 @@ class SymbolicShapeInference:
logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
logger.debug("node inputs:")
for i in node.input:
- logger.debug(self.known_vi_[i])
+ if i in self.known_vi_:
+ logger.debug(self.known_vi_[i])
+ else:
+ logger.debug(f"not in known_vi_ for {i}")
logger.debug("node outputs:")
for o in node.output:
- logger.debug(self.known_vi_[o])
+ if o in self.known_vi_:
+ logger.debug(self.known_vi_[o])
+ else:
+ logger.debug(f"not in known_vi_ for {o}")
if self.auto_merge_ and not out_type_undefined:
logger.debug("Merging: " + str(self.suggested_merge_))
return False
diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index 96812dff16..162713e54b 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -1207,8 +1207,72 @@ def update_decoder_subgraph_use_decoder_masked_attention(
return True
+def find_past_seq_len_usage(subg: GraphProto):
+ """Correct a graph that originally reads the past_seq_len dim from input_ids's shape, which becomes
+ fixed to max_seq_len once the past/present buffer is shared
+
+ Args:
+ subg (GraphProto): GraphProto of the decoder subgraph
+ Returns:
+ tensor_names_to_rename : set of tensor names whose value equals past_sequence_length
+ nodes_to_remove : list of nodes to remove
+ """
+ tensor_names_to_rename = set()
+ nodes_to_remove = []
+
+ graph_input_names = {inp.name: index for index, inp in enumerate(subg.input)}
+
+ input_name_to_nodes = {}
+ output_name_to_node = {}
+ for node in subg.node:
+ for input_name in node.input:
+ if input_name:
+ if input_name not in input_name_to_nodes:
+ input_name_to_nodes[input_name] = [node]
+ else:
+ input_name_to_nodes[input_name].append(node)
+ for output_name in node.output:
+ if output_name:
+ output_name_to_node[output_name] = node
+
+ for node in subg.node:
+ # find "Shape(past_key_self..) --> Gather(*, 2)"
+ if node.op_type == "Gather":
+ if not node.input[1] or not node.input[0]:
+ continue
+ shape_tensor_name, shape_index_name = (node.input[0], node.input[1])
+ ini_gather_indices = None
+ for tensor in subg.initializer:
+ if tensor.name == shape_index_name:
+ ini_gather_indices = tensor
+ break
+ if ini_gather_indices is None:
+ continue
+ gather_indices_arr = onnx.numpy_helper.to_array(ini_gather_indices)
+ if gather_indices_arr.size == 1 and gather_indices_arr.item() == 2 and node.input[0] in output_name_to_node:
+ shape_node = output_name_to_node[shape_tensor_name]
+ if (
+ shape_node.op_type == "Shape"
+ and shape_node.input[0]
+ and shape_node.input[0] in graph_input_names
+ and (
+ shape_node.input[0].startswith("past_key_self_")
+ or shape_node.input[0].startswith("past_value_self_")
+ )
+ ):
+ tensor_names_to_rename.add(node.output[0])
+ nodes_to_remove.append(node)
+ if len(input_name_to_nodes[shape_node.output[0]]) == 1:
+ nodes_to_remove.append(shape_node)
+ return tensor_names_to_rename, nodes_to_remove
+
+
def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphProto):
- input_self_past_0 = 2
+ input_self_past_0 = 1
+ # w/wo attention mask, w/wo hidden_state
+ graph_input_names = [gi.name for gi in subg.input]
+ while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
+ input_self_past_0 += 1
output_self_past_0 = 1
num_layers = int((len(subg.input) - input_self_past_0) / 4)
@@ -1232,9 +1296,6 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
rel_pos_bias_node = node
break
- if rel_pos_bias_node is None:
- return False
-
decoder_masked_attention_supported_attr = [
"past_present_share_buffer",
"num_heads",
@@ -1243,8 +1304,31 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
"domain",
]
+ target_squeezed_past_seq_name = "past_sequence_length_squeezed_int64"
+ tensor_names_to_rename, nodes_to_remove = find_past_seq_len_usage(subg)
+ if len(tensor_names_to_rename) > 0:
+ for name_to_rename in tensor_names_to_rename:
+ print(f"Found tensor name {name_to_rename} to be renamed to {target_squeezed_past_seq_name}")
+ for nr in nodes_to_remove:
+ print(f"Found node to remove: type:{nr.op_type}, name:{nr.name}")
+
+ squeeze_node = onnx.helper.make_node(
+ "Squeeze",
+ ["past_sequence_length"],
+ ["past_sequence_length_squeezed"],
+ name="node_past_sequence_length_squeeze",
+ )
+ cast_node = onnx.helper.make_node(
+ "Cast",
+ ["past_sequence_length_squeezed"],
+ [target_squeezed_past_seq_name],
+ name="node_past_sequence_length_squeeze_cast",
+ to=TensorProto.INT64,
+ )
+ new_nodes.extend([squeeze_node, cast_node])
+
for node in subg.node:
- if len(node.output) > 0 and node.output[0] == rel_pos_bias_node.input[1]:
+ if len(node.output) > 0 and rel_pos_bias_node is not None and node.output[0] == rel_pos_bias_node.input[1]:
cast_node = onnx.helper.make_node(
"Cast",
["past_sequence_length"],
@@ -1266,20 +1350,16 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
node.input[0], # query
node.input[1], # key
node.input[2], # value
- node.input[4], # 2D mask
- node.input[5], # relative_position_bias
]
- if len(node.input) > 6:
- nis.extend([node.input[6]]) # past_key
- nis.extend([node.input[7]]) # past_value
- else:
- nis.extend([""]) # past_key
- nis.extend([""]) # past_value
-
+ nis.extend([node.input[4] if len(node.input) > 4 else ""]) # 2D mask
+ nis.extend([node.input[5] if len(node.input) > 5 else ""]) # relative_position_bias
+ nis.extend([node.input[6] if len(node.input) > 6 else ""]) # past_key
+ nis.extend([node.input[7] if len(node.input) > 7 else ""]) # past_value
nis.extend(["past_sequence_length"]) # past_sequence_length
nis.extend(["beam_width"]) # beam_width
nis.extend(["cache_indirection"]) # cache_indirection
+ nis.extend([node.input[3] if len(node.input) > 3 else ""]) # bias
kwargs["past_present_share_buffer"] = 1
@@ -1287,10 +1367,15 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
"DecoderMaskedMultiHeadAttention", nis, node.output, name=node.name, **kwargs
)
- new_nodes.extend([node])
+ if node not in nodes_to_remove:
+ for index, name in enumerate(node.input):
+ if name in tensor_names_to_rename:
+ node.input[index] = target_squeezed_past_seq_name
+ new_nodes.extend([node])
subg.ClearField("node")
subg.node.extend(new_nodes)
+ orig_input_names = [inp.name for inp in subg.input]
new_inputs = []
for i, vi in enumerate(subg.input):
@@ -1302,15 +1387,20 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
shape=[shape[0], shape[1], "max_seq_len", shape[3]],
)
new_inputs.extend([vi])
- new_inputs.extend([onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])])
- new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
- new_inputs.extend(
- [
- onnx.helper.make_tensor_value_info(
- "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
- )
- ]
- )
+ if "past_sequence_length" not in orig_input_names:
+ new_inputs.extend(
+ [onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])]
+ )
+ if "beam_width" not in orig_input_names:
+ new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
+ if "cache_indirection" not in orig_input_names:
+ new_inputs.extend(
+ [
+ onnx.helper.make_tensor_value_info(
+ "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
+ )
+ ]
+ )
subg.ClearField("input")
subg.input.extend(new_inputs)
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index 9beae2cd38..fa5adbb5fb 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -216,11 +216,12 @@ def export_onnx_models(
)
config = models["decoder"].config
- if (not use_external_data_format) and (config.num_layers > 24):
+ if (not use_external_data_format) and (config.num_hidden_layers > 24):
logger.info("Try use_external_data_format when model size > 2GB")
output_paths = []
for name, model in models.items():
+ print(f"========> Handling {name} model......")
model.to(device)
filename_suffix = "_" + name
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
index de7134eae5..bfb8a5e5fc 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -6,7 +6,11 @@ from onnx import TensorProto, helper
from transformers import WhisperConfig
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
-from convert_generation import get_shared_initializers # noqa: E402
+from benchmark_helper import Precision # noqa: E402
+from convert_generation import ( # noqa: E402
+ get_shared_initializers,
+ update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
+)
def chain_model(args):
@@ -54,8 +58,12 @@ def chain_model(args):
)
# beam graph inputs
+ float_data_type = TensorProto.FLOAT
+ if args.precision != Precision.FLOAT32:
+ float_data_type = TensorProto.FLOAT16
+
input_features = helper.make_tensor_value_info(
- "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
+ "input_features", float_data_type, ["batch_size", "feature_size", "sequence_length"]
)
max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
@@ -89,6 +97,12 @@ def chain_model(args):
)
graph_outputs = [sequences]
+ if hasattr(args, "use_gpu") and args.use_gpu:
+ if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
+ print("*****Updated whisper decoder subgraph successfully!!!*****")
+ else:
+ print("*****DecoderMaskedMultiHeadAttention is not applied to whisper decoder*****")
+
# Initializers/opsets
# Delete shared data between decoder/encoder and move to larger graph initializers
initializers = get_shared_initializers(encoder_model, decoder_model)
@@ -98,6 +112,7 @@ def chain_model(args):
helper.make_attribute("encoder", encoder_model.graph),
]
)
+
opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]
beam_graph = helper.make_graph([node], "beam-search-test", graph_inputs, graph_outputs, initializers)
diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py
index 39111b16f0..3b1c624720 100644
--- a/onnxruntime/python/tools/transformers/onnx_model.py
+++ b/onnxruntime/python/tools/transformers/onnx_model.py
@@ -946,6 +946,7 @@ class OnnxModel:
sorted_node_set_len = -1
graph_nodes = graph.node if not is_deterministic else sorted(graph.node, key=lambda x: x.name)
+
last_node_name = None
while len(sorted_node_set) != len(graph_nodes):
if len(sorted_node_set) == sorted_node_set_len:
@@ -959,7 +960,8 @@ class OnnxModel:
sorted_nodes.append(node)
sorted_node_set.add(node_idx)
for output in node.output:
- deps_set.add(output)
+ if output:
+ deps_set.add(output)
continue
failed = False
for input_name in node.input:
@@ -970,7 +972,8 @@ class OnnxModel:
sorted_nodes.append(node)
sorted_node_set.add(node_idx)
for output in node.output:
- deps_set.add(output)
+ if output:
+ deps_set.add(output)
else:
continue