From 0f8e66d905e796f93da372396331bb15c5108a04 Mon Sep 17 00:00:00 2001 From: Zhang Lei Date: Thu, 18 May 2023 15:38:31 -0700 Subject: [PATCH] optimization for whisper model with decoder masked multihead attention (#15827) * graph tools update * cuda kernel update * operator spec update and implementation update * greed search bug fix on wrong assumption for cross/self attention input length * avoid use of "" name in value info when loading graph which historically in many model --- docs/ContribOperators.md | 4 +- docs/OperatorKernels.md | 2 +- .../cpu/transformers/beam_search_impl_base.h | 8 + .../cpu/transformers/beam_search_impl_t5.h | 5 + .../transformers/beam_search_impl_whisper.h | 5 + .../transformers/subgraph_whisper_decoder.cc | 9 +- .../decoder_masked_multihead_attention.cc | 14 +- ...decoder_masked_multihead_attention_impl.cu | 53 +++---- .../core/graph/contrib_ops/bert_defs.cc | 5 + onnxruntime/core/graph/graph.cc | 4 +- .../python/tools/symbolic_shape_infer.py | 42 +++++- .../tools/transformers/convert_generation.py | 138 +++++++++++++++--- .../models/whisper/convert_to_onnx.py | 3 +- .../models/whisper/whisper_chain.py | 19 ++- .../python/tools/transformers/onnx_model.py | 7 +- 15 files changed, 246 insertions(+), 72 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 07e892cabd..beb7029895 100755 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1134,7 +1134,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-#### Inputs (1 - 10) +#### Inputs (1 - 11)
query : T
@@ -1157,6 +1157,8 @@ This version of the operator has been available since version 1 of the 'com.micr
The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.
cache_indirection (optional) : M
A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies which beam the 'k'-th token came from for the 'j'-th beam for batch 'i' in the current iteration
+
bias (optional) : T
+
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
#### Outputs (1 - 3) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 71a6a3afeb..c23209d4fe 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -802,7 +802,7 @@ Do not modify directly.* |ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)| -|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| +|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)| |DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index 804cf2cf10..52b1efdeb2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -62,6 +62,14 @@ struct BeamSearchState : IBeamSearchState { } } + void EnsurePastStateReorderStagingBuffer(AllocatorPtr allocator, int64_t sz) { + auto current_buffer_size = this->staging_for_past_state_reorder.Shape().Size(); + if (sz > current_buffer_size) { + TensorShape buffer_shape = {sz}; + this->staging_for_past_state_reorder = Tensor(DataTypeImpl::GetType(), buffer_shape, allocator); + } + } + private: BufferUniquePtr next_token_logits_buffer_; BufferUniquePtr next_token_scores_buffer_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 5c4b7e2753..6385511785 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -254,6 +254,11 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches if (decoder_subgraph_.has_decoder_masked_attention_) { size_t offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()); + // Need to check cross attention's past key tensor size, suppose all layers cross attention key size are same + auto first_cross_attention_key = decoder_feeds[offset + 2 * static_cast(decoder_subgraph_.num_layers)].GetMutable(); + auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size(); + beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz); + // Here we only need to reorder the past key for self-attention and cross-attention. 
for (size_t i = 0; i < 2 * static_cast(decoder_subgraph_.num_layers); ++i) { ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index a5178422c0..5a7c154d31 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -245,6 +245,11 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe if (decoder_subgraph_.has_decoder_masked_attention_) { size_t offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()); + // Need to check cross attention's past key tensor size, suppose all layers cross attention key size are same + auto first_cross_attention_key = decoder_feeds[offset + 2 * static_cast(decoder_subgraph_.num_layers)].GetMutable(); + auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size(); + beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz); + // Here we only need to reorder the past key for self-attention and cross-attention. 
for (size_t i = 0; i < 2 * static_cast(decoder_subgraph_.num_layers); ++i) { ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index 37f3e5a43d..887a6a8984 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -106,9 +106,9 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type, "decoder subgraph input 0 (input_ids) shall have int32 type"); - auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type(); + auto float_type = subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type(); ORT_RETURN_IF(float_type != float32_type && float_type != float16_type, - "decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type"); + "decoder subgraph input 1 (encoder_hidden_states) shall have float or float16 type"); for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) { ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, @@ -202,6 +202,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( // When first_past_input_index_ == 2, the encoder_hidden_states and past states are copied from the second output // of encoder. // When first_past_input_index_ == 1, the past states are copied from the second output of encoder. 
+ // TODO: MAKE IT MORE READABLE for (size_t j = static_cast(3) - first_past_input_index_; j < encoder_fetches.size(); j++) { if (j == 1) { ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false"); @@ -226,7 +227,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( decoder_feeds.push_back(expanded_hidden_states); } else { // past key/value for cross attention does not need to be initialized with max_seq_len since they are static. - bool use_max_seq_len = (j - first_past_input_index_) < 2 * static_cast(num_layers); + bool use_max_seq_len = (j - first_past_input_index_) <= 2 * static_cast(num_layers); OrtValue expanded_cache; if (is_output_float16_) { @@ -252,7 +253,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( if (past_present_share_buffer_) { // Past sequence length feed - ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 1)); + ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, cur_len - 1)); // Add beam search specific inputs if (need_cache_indir) { const int64_t batch_size = static_cast(batch_beam_size / num_beam); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index d5053e1d85..4bdc6db30b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -22,6 +22,7 @@ static constexpr int kBeamWidthInputIndex = 8; static constexpr int kCacheIndirectionInputIndex = 9; static constexpr int kPastInputIndex = 5; static constexpr int kPresentOutputIndex = 1; +static constexpr int kBiasIndex = 10; #define REGISTER_KERNEL_TYPED(T1, T2) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ @@ -63,6 +64,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); const Tensor* 
beam_width = context->Input(kBeamWidthInputIndex); const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex); + const Tensor* bias = context->Input(kBiasIndex); auto& device_prop = GetDeviceProp(); DecoderMaskedMultiHeadAttentionParams parameters; @@ -70,7 +72,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, - nullptr, // bias + bias, mask_index, relative_position_bias, past_key, @@ -84,6 +86,13 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* is_dmmha_packing, // dmmha_packing device_prop.maxThreadsPerBlock)); + if (bias) { + const T1* bias_data = bias->Data(); + parameters.q_bias = const_cast(bias_data); + parameters.k_bias = const_cast(bias_data + parameters.hidden_size); + parameters.v_bias = const_cast(bias_data + 2LL * parameters.hidden_size); + } + int batch_size = parameters.batch_size; int sequence_length = parameters.sequence_length; @@ -142,6 +151,9 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* // parameters.k and paraneters.v are nullptr parameters.k_cache = const_cast(key->Data()); parameters.v_cache = const_cast(value->Data()); + parameters.k_bias = nullptr; + parameters.v_bias = nullptr; + } else { // Sanity check ORT_ENFORCE(past_present_share_buffer_); diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 5b63621a59..f1ddf13f8b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -161,40 +161,19 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio q = 
vec_conversion(*reinterpret_cast(&reinterpret_cast(params.q)[qk_offset])); } - Qk_vec_k k; - - if (!params.is_cross_attention) { - zero(k); - - if (!is_masked) { - k = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k)[qk_offset])); - } - } + // The offset in the bias buffer. + int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE; // Trigger the loads from the Q and K bias buffers. - Qk_vec_k q_bias; - Qk_vec_k k_bias; - if (!params.is_mha) { - // The offset in the bias buffer. - int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE; + if (params.q_bias && !is_masked) { + Qk_vec_k q_bias; - zero(q_bias); + q_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.q_bias)[qk_bias_offset])); - if (!is_masked) { - q_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.q_bias)[qk_bias_offset])); - } - - zero(k_bias); - - if (!is_masked) { - k_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k_bias)[qk_bias_offset])); - } - - // Computes the Q/K values with bias. q = add_vec(q, q_bias); - k = add_vec(k, k_bias); } + T* params_k_cache = reinterpret_cast(params.k_cache); const float inv_sqrt_dh = params.scale; @@ -205,6 +184,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio } if (!params.is_cross_attention) { + Qk_vec_k k; + + zero(k); + + if (!is_masked) { + k = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k)[qk_offset])); + + if (params.k_bias) { + Qk_vec_k k_bias; + + k_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k_bias)[qk_bias_offset])); + + k = add_vec(k, k_bias); + } + } + if (!is_masked) { // Write the K values to the global memory cache. // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory @@ -438,7 +433,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // One group of threads computes the product(s) for the current timestep. 
V_vec_k v_bias; - if (!params.is_mha) { + if (params.v_bias && !params.is_cross_attention) { zero(v_bias); T* params_v_bias = reinterpret_cast(params.v_bias); @@ -478,7 +473,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio V_vec_k v; v = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.v)[v_offset])); - if (!params.is_mha) { + if (params.v_bias) { v = add_vec(v, v_bias); } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 8d9e560d4e..7895eff00f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -676,6 +676,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration", "M", OpSchema::Optional) + .Input(10, + "bias", + "Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, v_hidden_size)", diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 22f82baf5c..74d525324f 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1301,7 +1301,9 @@ Graph::Graph(const Model& owning_model, for (auto& node_arg : graph_proto_->value_info()) { if (utils::HasName(node_arg) && utils::HasType(node_arg)) { - name_to_type_map[node_arg.name()] = node_arg.type(); + if (node_arg.name().size() > 0) { + name_to_type_map[node_arg.name()] = node_arg.type(); + } } } diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index d388c42218..d1c14de9ee 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -195,6 +195,7 @@ class SymbolicShapeInference: "RestorePadding": self._infer_RestorePadding, "BiasGelu": 
self._infer_BiasGelu, "MultiHeadAttention": self._infer_MultiHeadAttention, + "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention, "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, "FastGelu": self._infer_FastGelu, "Gelu": self._infer_Gelu, @@ -2233,10 +2234,33 @@ class SymbolicShapeInference: present_shape = [batch_size, num_heads, total_sequence_length, head_size] assert output_dtype is not None - vi = self.known_vi_[node.output[1]] - vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) - vi = self.known_vi_[node.output[2]] - vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) + if len(node.output) > 2 and node.output[1] and node.output[2]: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) + + def _infer_DecoderMaskedMultiHeadAttention(self, node): # noqa: N802 + # Output 0 has shape (batch_size, 1, v_hidden_size) + # Q, K and V without packing: + # Input 0 (query) has shape (batch_size, 1, hidden_size) + # Input 5 (past_key) if exists has shape (batch_size, num_heads, max_sequence_length, head_size) + + query_shape = self._get_shape(node, 0) + if query_shape is not None: + output_shape = query_shape + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + assert output_dtype is not None + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + if len(node.output) > 2 and node.output[1] and node.output[2]: + past_shape = self._try_get_shape(node, 5) + if past_shape is not None: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, 
output_dtype, past_shape)) def _infer_FastGelu(self, node): # noqa: N802 self._propagate_shape_and_type(node) @@ -2659,10 +2683,16 @@ class SymbolicShapeInference: logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name) logger.debug("node inputs:") for i in node.input: - logger.debug(self.known_vi_[i]) + if i in self.known_vi_: + logger.debug(self.known_vi_[i]) + else: + logger.debug(f"not in knwon_vi_ for {i}") logger.debug("node outputs:") for o in node.output: - logger.debug(self.known_vi_[o]) + if o in self.known_vi_: + logger.debug(self.known_vi_[o]) + else: + logger.debug(f"not in knwon_vi_ for {o}") if self.auto_merge_ and not out_type_undefined: logger.debug("Merging: " + str(self.suggested_merge_)) return False diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 96812dff16..162713e54b 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1207,8 +1207,72 @@ def update_decoder_subgraph_use_decoder_masked_attention( return True +def find_past_seq_len_usage(subg: GraphProto): + """Correct graph which originally use dim of past_seq_len from input_ids's shape which is fixed to max_seq_len after + shared past/present buffer + + Args: + subg (GraphProto): GraphProto of the decoder subgraph + return: + tensor_names_to_rename : set of tensor names which is equal to past_sequence_length + nodes_to_remove : list of node to remove + """ + tensor_names_to_rename = set() + nodes_to_remove = [] + + graph_intput_names = {inp.name: index for index, inp in enumerate(subg.input)} + + input_name_to_nodes = {} + output_name_to_node = {} + for node in subg.node: + for input_name in node.input: + if input_name: + if input_name not in input_name_to_nodes: + input_name_to_nodes[input_name] = [node] + else: + input_name_to_nodes[input_name].append(node) + for output_name 
in node.output: + if output_name: + output_name_to_node[output_name] = node + + for node in subg.node: + # find "Shape(past_key_self..) --> Gather(*, 2)" + if node.op_type == "Gather": + if not node.input[1] or not node.input[0]: + continue + shape_tensor_name, shape_index_name = (node.input[0], node.input[1]) + ini_gather_indices = None + for tensor in subg.initializer: + if tensor.name == shape_index_name: + ini_gather_indices = tensor + break + if ini_gather_indices is None: + continue + gather_indices_arr = onnx.numpy_helper.to_array(ini_gather_indices) + if gather_indices_arr.size == 1 and gather_indices_arr.item() == 2 and node.input[0] in output_name_to_node: + shape_node = output_name_to_node[shape_tensor_name] + if ( + shape_node.op_type == "Shape" + and shape_node.input[0] + and shape_node.input[0] in graph_intput_names + and ( + shape_node.input[0].startswith("past_key_self_") + or shape_node.input[0].startswith("past_value_self_") + ) + ): + tensor_names_to_rename.add(node.output[0]) + nodes_to_remove.append(node) + if len(input_name_to_nodes[shape_node.output[0]]) == 1: + nodes_to_remove.append(shape_node) + return tensor_names_to_rename, nodes_to_remove + + def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphProto): - input_self_past_0 = 2 + input_self_past_0 = 1 + # w/wo attention mask, w/wo hidden_state + graph_input_names = [gi.name for gi in subg.input] + while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"): + input_self_past_0 += 1 output_self_past_0 = 1 num_layers = int((len(subg.input) - input_self_past_0) / 4) @@ -1232,9 +1296,6 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP rel_pos_bias_node = node break - if rel_pos_bias_node is None: - return False - decoder_masked_attention_supported_attr = [ "past_present_share_buffer", "num_heads", @@ -1243,8 +1304,31 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP 
"domain", ] + target_squeezed_past_seq_name = "past_sequence_length_squeezed_int64" + tensor_names_to_rename, nodes_to_remove = find_past_seq_len_usage(subg) + if len(tensor_names_to_rename) > 0: + for name_to_rename in tensor_names_to_rename: + print(f"Found tensor name {name_to_rename} to be renamed to {target_squeezed_past_seq_name}") + for nr in nodes_to_remove: + print(f"Found node to removed: type:{nr.op_type}, name:{nr.name}") + + squeeze_node = onnx.helper.make_node( + "Squeeze", + ["past_sequence_length"], + ["past_sequence_length_squeezed"], + name="node_past_sequence_length_squeeze", + ) + cast_node = onnx.helper.make_node( + "Cast", + ["past_sequence_length_squeezed"], + [target_squeezed_past_seq_name], + name="node_past_sequence_length_squeeze_cast", + to=TensorProto.INT64, + ) + new_nodes.extend([squeeze_node, cast_node]) + for node in subg.node: - if len(node.output) > 0 and node.output[0] == rel_pos_bias_node.input[1]: + if len(node.output) > 0 and rel_pos_bias_node is not None and node.output[0] == rel_pos_bias_node.input[1]: cast_node = onnx.helper.make_node( "Cast", ["past_sequence_length"], @@ -1266,20 +1350,16 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP node.input[0], # query node.input[1], # key node.input[2], # value - node.input[4], # 2D mask - node.input[5], # relative_position_bias ] - if len(node.input) > 6: - nis.extend([node.input[6]]) # past_key - nis.extend([node.input[7]]) # past_value - else: - nis.extend([""]) # past_key - nis.extend([""]) # past_value - + nis.extend([node.input[4] if len(node.input) > 4 else ""]) # 2D mask + nis.extend([node.input[5] if len(node.input) > 5 else ""]) # relative_position_bias + nis.extend([node.input[6] if len(node.input) > 6 else ""]) # past_key + nis.extend([node.input[7] if len(node.input) > 7 else ""]) # past_value nis.extend(["past_sequence_length"]) # past_sequence_length nis.extend(["beam_width"]) # beam_width nis.extend(["cache_indirection"]) # 
cache_indirection + nis.extend([node.input[3] if len(node.input) > 3 else ""]) # bias kwargs["past_present_share_buffer"] = 1 @@ -1287,10 +1367,15 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP "DecoderMaskedMultiHeadAttention", nis, node.output, name=node.name, **kwargs ) - new_nodes.extend([node]) + if node not in nodes_to_remove: + for index, name in enumerate(node.input): + if name in tensor_names_to_rename: + node.input[index] = target_squeezed_past_seq_name + new_nodes.extend([node]) subg.ClearField("node") subg.node.extend(new_nodes) + orig_input_names = [inp.name for inp in subg.input] new_inputs = [] for i, vi in enumerate(subg.input): @@ -1302,15 +1387,20 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP shape=[shape[0], shape[1], "max_seq_len", shape[3]], ) new_inputs.extend([vi]) - new_inputs.extend([onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])]) - new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])]) - new_inputs.extend( - [ - onnx.helper.make_tensor_value_info( - "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"] - ) - ] - ) + if "past_sequence_length" not in orig_input_names: + new_inputs.extend( + [onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])] + ) + if "beam_width" not in orig_input_names: + new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])]) + if "cache_indirection" not in orig_input_names: + new_inputs.extend( + [ + onnx.helper.make_tensor_value_info( + "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"] + ) + ] + ) subg.ClearField("input") subg.input.extend(new_inputs) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py 
b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 9beae2cd38..fa5adbb5fb 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -216,11 +216,12 @@ def export_onnx_models( ) config = models["decoder"].config - if (not use_external_data_format) and (config.num_layers > 24): + if (not use_external_data_format) and (config.num_hidden_layers > 24): logger.info("Try use_external_data_format when model size > 2GB") output_paths = [] for name, model in models.items(): + print(f"========> Handling {name} model......") model.to(device) filename_suffix = "_" + name diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index de7134eae5..bfb8a5e5fc 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -6,7 +6,11 @@ from onnx import TensorProto, helper from transformers import WhisperConfig sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from convert_generation import get_shared_initializers # noqa: E402 +from benchmark_helper import Precision # noqa: E402 +from convert_generation import ( # noqa: E402 + get_shared_initializers, + update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, +) def chain_model(args): @@ -54,8 +58,12 @@ def chain_model(args): ) # beam graph inputs + float_data_type = TensorProto.FLOAT + if args.precision != Precision.FLOAT32: + float_data_type = TensorProto.FLOAT16 + input_features = helper.make_tensor_value_info( - "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"] + "input_features", float_data_type, ["batch_size", "feature_size", "sequence_length"] ) max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1]) min_length 
= helper.make_tensor_value_info("min_length", TensorProto.INT32, [1]) @@ -89,6 +97,12 @@ def chain_model(args): ) graph_outputs = [sequences] + if hasattr(args, "use_gpu") and args.use_gpu: + if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): + print("*****update whisper decoder subgraph successfully!!!*****") + else: + print("*****DecoderMaskedMultiHeadAttention is not applied to whisper decoder*****") + # Initializers/opsets # Delete shared data between decoder/encoder and move to larger graph initializers initializers = get_shared_initializers(encoder_model, decoder_model) @@ -98,6 +112,7 @@ def chain_model(args): helper.make_attribute("encoder", encoder_model.graph), ] ) + opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] beam_graph = helper.make_graph([node], "beam-search-test", graph_inputs, graph_outputs, initializers) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 39111b16f0..3b1c624720 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -946,6 +946,7 @@ class OnnxModel: sorted_node_set_len = -1 graph_nodes = graph.node if not is_deterministic else sorted(graph.node, key=lambda x: x.name) + last_node_name = None while len(sorted_node_set) != len(graph_nodes): if len(sorted_node_set) == sorted_node_set_len: @@ -959,7 +960,8 @@ class OnnxModel: sorted_nodes.append(node) sorted_node_set.add(node_idx) for output in node.output: - deps_set.add(output) + if output: + deps_set.add(output) continue failed = False for input_name in node.input: @@ -970,7 +972,8 @@ class OnnxModel: sorted_nodes.append(node) sorted_node_set.add(node_idx) for output in node.output: - deps_set.add(output) + if output: + deps_set.add(output) else: continue