mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
optimization for whisper model with decoder masked multihead attention (#15827)
* graph tools update * cuda kernel update * operator spec update and implementation update * greed search bug fix on wrong assumption for cross/self attention input length * avoid use of "" name in value info when loading graph which historically in many model
This commit is contained in:
parent
be6c0bb53c
commit
0f8e66d905
15 changed files with 246 additions and 72 deletions
|
|
@ -1134,7 +1134,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (1 - 10)
|
||||
#### Inputs (1 - 11)
|
||||
|
||||
<dl>
|
||||
<dt><tt>query</tt> : T</dt>
|
||||
|
|
@ -1157,6 +1157,8 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
|
||||
<dt><tt>cache_indirection</tt> (optional) : M</dt>
|
||||
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
|
||||
<dt><tt>bias</tt> (optional) : T</dt>
|
||||
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs (1 - 3)
|
||||
|
|
|
|||
|
|
@ -802,7 +802,7 @@ Do not modify directly.*
|
|||
|ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|
||||
|DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|DecoderMaskedSelfAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(float16)|
|
||||
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|
||||
|
|
|
|||
|
|
@ -62,6 +62,14 @@ struct BeamSearchState : IBeamSearchState<T> {
|
|||
}
|
||||
}
|
||||
|
||||
void EnsurePastStateReorderStagingBuffer(AllocatorPtr allocator, int64_t sz) {
|
||||
auto current_buffer_size = this->staging_for_past_state_reorder.Shape().Size();
|
||||
if (sz > current_buffer_size) {
|
||||
TensorShape buffer_shape = {sz};
|
||||
this->staging_for_past_state_reorder = Tensor(DataTypeImpl::GetType<T>(), buffer_shape, allocator);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
BufferUniquePtr next_token_logits_buffer_;
|
||||
BufferUniquePtr next_token_scores_buffer_;
|
||||
|
|
|
|||
|
|
@ -254,6 +254,11 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
|
|||
|
||||
if (decoder_subgraph_.has_decoder_masked_attention_) {
|
||||
size_t offset = static_cast<size_t>(decoder_subgraph_.GetFirstPastInputIndex());
|
||||
// Need to check cross attention's past key tensor size, suppose all layers cross attention key size are same
|
||||
auto first_cross_attention_key = decoder_feeds[offset + 2 * static_cast<size_t>(decoder_subgraph_.num_layers)].GetMutable<Tensor>();
|
||||
auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
|
||||
beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);
|
||||
|
||||
// Here we only need to reorder the past key for self-attention and cross-attention.
|
||||
for (size_t i = 0; i < 2 * static_cast<size_t>(decoder_subgraph_.num_layers); ++i) {
|
||||
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
|
||||
|
|
|
|||
|
|
@ -245,6 +245,11 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
|
|||
|
||||
if (decoder_subgraph_.has_decoder_masked_attention_) {
|
||||
size_t offset = static_cast<size_t>(decoder_subgraph_.GetFirstPastInputIndex());
|
||||
// Need to check cross attention's past key tensor size, suppose all layers cross attention key size are same
|
||||
auto first_cross_attention_key = decoder_feeds[offset + 2 * static_cast<size_t>(decoder_subgraph_.num_layers)].GetMutable<Tensor>();
|
||||
auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
|
||||
beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);
|
||||
|
||||
// Here we only need to reorder the past key for self-attention and cross-attention.
|
||||
for (size_t i = 0; i < 2 * static_cast<size_t>(decoder_subgraph_.num_layers); ++i) {
|
||||
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
|
||||
|
|
|
|||
|
|
@ -106,9 +106,9 @@ Status WhisperDecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgr
|
|||
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
|
||||
"decoder subgraph input 0 (input_ids) shall have int32 type");
|
||||
|
||||
auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
|
||||
auto float_type = subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type();
|
||||
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
|
||||
"decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
|
||||
"decoder subgraph input 1 (encoder_hidden_states) shall have float or float16 type");
|
||||
|
||||
for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) {
|
||||
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
|
||||
|
|
@ -202,6 +202,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds(
|
|||
// When first_past_input_index_ == 2, the encoder_hidden_states and past states are copied from the second output
|
||||
// of encoder.
|
||||
// When first_past_input_index_ == 1, the past states are copied from the second output of encoder.
|
||||
// TODO: MAKE IT MORE READABLE
|
||||
for (size_t j = static_cast<size_t>(3) - first_past_input_index_; j < encoder_fetches.size(); j++) {
|
||||
if (j == 1) {
|
||||
ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
|
||||
|
|
@ -226,7 +227,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds(
|
|||
decoder_feeds.push_back(expanded_hidden_states);
|
||||
} else {
|
||||
// past key/value for cross attention does not need to be initialized with max_seq_len since they are static.
|
||||
bool use_max_seq_len = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
|
||||
bool use_max_seq_len = (j - first_past_input_index_) <= 2 * static_cast<size_t>(num_layers);
|
||||
|
||||
OrtValue expanded_cache;
|
||||
if (is_output_float16_) {
|
||||
|
|
@ -252,7 +253,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds(
|
|||
|
||||
if (past_present_share_buffer_) {
|
||||
// Past sequence length feed
|
||||
ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 1));
|
||||
ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, cur_len - 1));
|
||||
// Add beam search specific inputs
|
||||
if (need_cache_indir) {
|
||||
const int64_t batch_size = static_cast<int64_t>(batch_beam_size / num_beam);
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ static constexpr int kBeamWidthInputIndex = 8;
|
|||
static constexpr int kCacheIndirectionInputIndex = 9;
|
||||
static constexpr int kPastInputIndex = 5;
|
||||
static constexpr int kPresentOutputIndex = 1;
|
||||
static constexpr int kBiasIndex = 10;
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T1, T2) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
|
|
@ -63,6 +64,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
|
|||
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
|
||||
const Tensor* beam_width = context->Input<Tensor>(kBeamWidthInputIndex);
|
||||
const Tensor* cache_indir = context->Input<Tensor>(kCacheIndirectionInputIndex);
|
||||
const Tensor* bias = context->Input<Tensor>(kBiasIndex);
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
DecoderMaskedMultiHeadAttentionParams parameters;
|
||||
|
|
@ -70,7 +72,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
|
|||
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
|
||||
key,
|
||||
value,
|
||||
nullptr, // bias
|
||||
bias,
|
||||
mask_index,
|
||||
relative_position_bias,
|
||||
past_key,
|
||||
|
|
@ -84,6 +86,13 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
|
|||
is_dmmha_packing, // dmmha_packing
|
||||
device_prop.maxThreadsPerBlock));
|
||||
|
||||
if (bias) {
|
||||
const T1* bias_data = bias->Data<T1>();
|
||||
parameters.q_bias = const_cast<T1*>(bias_data);
|
||||
parameters.k_bias = const_cast<T1*>(bias_data + parameters.hidden_size);
|
||||
parameters.v_bias = const_cast<T1*>(bias_data + 2LL * parameters.hidden_size);
|
||||
}
|
||||
|
||||
int batch_size = parameters.batch_size;
|
||||
int sequence_length = parameters.sequence_length;
|
||||
|
||||
|
|
@ -142,6 +151,9 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
|
|||
// parameters.k and paraneters.v are nullptr
|
||||
parameters.k_cache = const_cast<T1*>(key->Data<T1>());
|
||||
parameters.v_cache = const_cast<T1*>(value->Data<T1>());
|
||||
parameters.k_bias = nullptr;
|
||||
parameters.v_bias = nullptr;
|
||||
|
||||
} else {
|
||||
// Sanity check
|
||||
ORT_ENFORCE(past_present_share_buffer_);
|
||||
|
|
|
|||
|
|
@ -161,40 +161,19 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
|
|||
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.q)[qk_offset]));
|
||||
}
|
||||
|
||||
Qk_vec_k k;
|
||||
|
||||
if (!params.is_cross_attention) {
|
||||
zero(k);
|
||||
|
||||
if (!is_masked) {
|
||||
k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k)[qk_offset]));
|
||||
}
|
||||
}
|
||||
// The offset in the bias buffer.
|
||||
int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE;
|
||||
|
||||
// Trigger the loads from the Q and K bias buffers.
|
||||
Qk_vec_k q_bias;
|
||||
Qk_vec_k k_bias;
|
||||
if (!params.is_mha) {
|
||||
// The offset in the bias buffer.
|
||||
int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE;
|
||||
if (params.q_bias && !is_masked) {
|
||||
Qk_vec_k q_bias;
|
||||
|
||||
zero(q_bias);
|
||||
q_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.q_bias)[qk_bias_offset]));
|
||||
|
||||
if (!is_masked) {
|
||||
q_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.q_bias)[qk_bias_offset]));
|
||||
}
|
||||
|
||||
zero(k_bias);
|
||||
|
||||
if (!is_masked) {
|
||||
k_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k_bias)[qk_bias_offset]));
|
||||
}
|
||||
|
||||
// Computes the Q/K values with bias.
|
||||
q = add_vec(q, q_bias);
|
||||
k = add_vec(k, k_bias);
|
||||
}
|
||||
|
||||
|
||||
T* params_k_cache = reinterpret_cast<T*>(params.k_cache);
|
||||
|
||||
const float inv_sqrt_dh = params.scale;
|
||||
|
|
@ -205,6 +184,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
|
|||
}
|
||||
|
||||
if (!params.is_cross_attention) {
|
||||
Qk_vec_k k;
|
||||
|
||||
zero(k);
|
||||
|
||||
if (!is_masked) {
|
||||
k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k)[qk_offset]));
|
||||
|
||||
if (params.k_bias) {
|
||||
Qk_vec_k k_bias;
|
||||
|
||||
k_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k_bias)[qk_bias_offset]));
|
||||
|
||||
k = add_vec(k, k_bias);
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_masked) {
|
||||
// Write the K values to the global memory cache.
|
||||
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
|
||||
|
|
@ -438,7 +433,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
|
|||
|
||||
// One group of threads computes the product(s) for the current timestep.
|
||||
V_vec_k v_bias;
|
||||
if (!params.is_mha) {
|
||||
if (params.v_bias && !params.is_cross_attention) {
|
||||
zero(v_bias);
|
||||
|
||||
T* params_v_bias = reinterpret_cast<T*>(params.v_bias);
|
||||
|
|
@ -478,7 +473,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
|
|||
|
||||
V_vec_k v;
|
||||
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&reinterpret_cast<T*>(params.v)[v_offset]));
|
||||
if (!params.is_mha) {
|
||||
if (params.v_bias) {
|
||||
v = add_vec(v, v_bias);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -676,6 +676,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
"which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration",
|
||||
"M",
|
||||
OpSchema::Optional)
|
||||
.Input(10,
|
||||
"bias",
|
||||
"Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Output(0,
|
||||
"output",
|
||||
"3D output tensor with shape (batch_size, sequence_length, v_hidden_size)",
|
||||
|
|
|
|||
|
|
@ -1301,7 +1301,9 @@ Graph::Graph(const Model& owning_model,
|
|||
|
||||
for (auto& node_arg : graph_proto_->value_info()) {
|
||||
if (utils::HasName(node_arg) && utils::HasType(node_arg)) {
|
||||
name_to_type_map[node_arg.name()] = node_arg.type();
|
||||
if (node_arg.name().size() > 0) {
|
||||
name_to_type_map[node_arg.name()] = node_arg.type();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -195,6 +195,7 @@ class SymbolicShapeInference:
|
|||
"RestorePadding": self._infer_RestorePadding,
|
||||
"BiasGelu": self._infer_BiasGelu,
|
||||
"MultiHeadAttention": self._infer_MultiHeadAttention,
|
||||
"DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
|
||||
"EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
|
||||
"FastGelu": self._infer_FastGelu,
|
||||
"Gelu": self._infer_Gelu,
|
||||
|
|
@ -2233,10 +2234,33 @@ class SymbolicShapeInference:
|
|||
present_shape = [batch_size, num_heads, total_sequence_length, head_size]
|
||||
|
||||
assert output_dtype is not None
|
||||
vi = self.known_vi_[node.output[1]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
|
||||
vi = self.known_vi_[node.output[2]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
|
||||
if len(node.output) > 2 and node.output[1] and node.output[2]:
|
||||
vi = self.known_vi_[node.output[1]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
|
||||
vi = self.known_vi_[node.output[2]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
|
||||
|
||||
def _infer_DecoderMaskedMultiHeadAttention(self, node): # noqa: N802
|
||||
# Output 0 has shape (batch_size, 1, v_hidden_size)
|
||||
# Q, K and V without packing:
|
||||
# Input 0 (query) has shape (batch_size, 1, hidden_size)
|
||||
# Input 5 (past_key) if exists has shape (batch_size, num_heads, max_sequence_length, head_size)
|
||||
|
||||
query_shape = self._get_shape(node, 0)
|
||||
if query_shape is not None:
|
||||
output_shape = query_shape
|
||||
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
|
||||
assert output_dtype is not None
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
|
||||
|
||||
if len(node.output) > 2 and node.output[1] and node.output[2]:
|
||||
past_shape = self._try_get_shape(node, 5)
|
||||
if past_shape is not None:
|
||||
vi = self.known_vi_[node.output[1]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
|
||||
vi = self.known_vi_[node.output[2]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
|
||||
|
||||
def _infer_FastGelu(self, node): # noqa: N802
|
||||
self._propagate_shape_and_type(node)
|
||||
|
|
@ -2659,10 +2683,16 @@ class SymbolicShapeInference:
|
|||
logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
|
||||
logger.debug("node inputs:")
|
||||
for i in node.input:
|
||||
logger.debug(self.known_vi_[i])
|
||||
if i in self.known_vi_:
|
||||
logger.debug(self.known_vi_[i])
|
||||
else:
|
||||
logger.debug(f"not in knwon_vi_ for {i}")
|
||||
logger.debug("node outputs:")
|
||||
for o in node.output:
|
||||
logger.debug(self.known_vi_[o])
|
||||
if o in self.known_vi_:
|
||||
logger.debug(self.known_vi_[o])
|
||||
else:
|
||||
logger.debug(f"not in knwon_vi_ for {o}")
|
||||
if self.auto_merge_ and not out_type_undefined:
|
||||
logger.debug("Merging: " + str(self.suggested_merge_))
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -1207,8 +1207,72 @@ def update_decoder_subgraph_use_decoder_masked_attention(
|
|||
return True
|
||||
|
||||
|
||||
def find_past_seq_len_usage(subg: GraphProto):
|
||||
"""Correct graph which originally use dim of past_seq_len from input_ids's shape which is fixed to max_seq_len after
|
||||
shared past/present buffer
|
||||
|
||||
Args:
|
||||
subg (GraphProto): GraphProto of the decoder subgraph
|
||||
return:
|
||||
tensor_names_to_rename : set of tensor names which is equal to past_sequence_length
|
||||
nodes_to_remove : list of node to remove
|
||||
"""
|
||||
tensor_names_to_rename = set()
|
||||
nodes_to_remove = []
|
||||
|
||||
graph_intput_names = {inp.name: index for index, inp in enumerate(subg.input)}
|
||||
|
||||
input_name_to_nodes = {}
|
||||
output_name_to_node = {}
|
||||
for node in subg.node:
|
||||
for input_name in node.input:
|
||||
if input_name:
|
||||
if input_name not in input_name_to_nodes:
|
||||
input_name_to_nodes[input_name] = [node]
|
||||
else:
|
||||
input_name_to_nodes[input_name].append(node)
|
||||
for output_name in node.output:
|
||||
if output_name:
|
||||
output_name_to_node[output_name] = node
|
||||
|
||||
for node in subg.node:
|
||||
# find "Shape(past_key_self..) --> Gather(*, 2)"
|
||||
if node.op_type == "Gather":
|
||||
if not node.input[1] or not node.input[0]:
|
||||
continue
|
||||
shape_tensor_name, shape_index_name = (node.input[0], node.input[1])
|
||||
ini_gather_indices = None
|
||||
for tensor in subg.initializer:
|
||||
if tensor.name == shape_index_name:
|
||||
ini_gather_indices = tensor
|
||||
break
|
||||
if ini_gather_indices is None:
|
||||
continue
|
||||
gather_indices_arr = onnx.numpy_helper.to_array(ini_gather_indices)
|
||||
if gather_indices_arr.size == 1 and gather_indices_arr.item() == 2 and node.input[0] in output_name_to_node:
|
||||
shape_node = output_name_to_node[shape_tensor_name]
|
||||
if (
|
||||
shape_node.op_type == "Shape"
|
||||
and shape_node.input[0]
|
||||
and shape_node.input[0] in graph_intput_names
|
||||
and (
|
||||
shape_node.input[0].startswith("past_key_self_")
|
||||
or shape_node.input[0].startswith("past_value_self_")
|
||||
)
|
||||
):
|
||||
tensor_names_to_rename.add(node.output[0])
|
||||
nodes_to_remove.append(node)
|
||||
if len(input_name_to_nodes[shape_node.output[0]]) == 1:
|
||||
nodes_to_remove.append(shape_node)
|
||||
return tensor_names_to_rename, nodes_to_remove
|
||||
|
||||
|
||||
def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphProto):
|
||||
input_self_past_0 = 2
|
||||
input_self_past_0 = 1
|
||||
# w/wo attention mask, w/wo hidden_state
|
||||
graph_input_names = [gi.name for gi in subg.input]
|
||||
while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
|
||||
input_self_past_0 += 1
|
||||
output_self_past_0 = 1
|
||||
|
||||
num_layers = int((len(subg.input) - input_self_past_0) / 4)
|
||||
|
|
@ -1232,9 +1296,6 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
|
|||
rel_pos_bias_node = node
|
||||
break
|
||||
|
||||
if rel_pos_bias_node is None:
|
||||
return False
|
||||
|
||||
decoder_masked_attention_supported_attr = [
|
||||
"past_present_share_buffer",
|
||||
"num_heads",
|
||||
|
|
@ -1243,8 +1304,31 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
|
|||
"domain",
|
||||
]
|
||||
|
||||
target_squeezed_past_seq_name = "past_sequence_length_squeezed_int64"
|
||||
tensor_names_to_rename, nodes_to_remove = find_past_seq_len_usage(subg)
|
||||
if len(tensor_names_to_rename) > 0:
|
||||
for name_to_rename in tensor_names_to_rename:
|
||||
print(f"Found tensor name {name_to_rename} to be renamed to {target_squeezed_past_seq_name}")
|
||||
for nr in nodes_to_remove:
|
||||
print(f"Found node to removed: type:{nr.op_type}, name:{nr.name}")
|
||||
|
||||
squeeze_node = onnx.helper.make_node(
|
||||
"Squeeze",
|
||||
["past_sequence_length"],
|
||||
["past_sequence_length_squeezed"],
|
||||
name="node_past_sequence_length_squeeze",
|
||||
)
|
||||
cast_node = onnx.helper.make_node(
|
||||
"Cast",
|
||||
["past_sequence_length_squeezed"],
|
||||
[target_squeezed_past_seq_name],
|
||||
name="node_past_sequence_length_squeeze_cast",
|
||||
to=TensorProto.INT64,
|
||||
)
|
||||
new_nodes.extend([squeeze_node, cast_node])
|
||||
|
||||
for node in subg.node:
|
||||
if len(node.output) > 0 and node.output[0] == rel_pos_bias_node.input[1]:
|
||||
if len(node.output) > 0 and rel_pos_bias_node is not None and node.output[0] == rel_pos_bias_node.input[1]:
|
||||
cast_node = onnx.helper.make_node(
|
||||
"Cast",
|
||||
["past_sequence_length"],
|
||||
|
|
@ -1266,20 +1350,16 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
|
|||
node.input[0], # query
|
||||
node.input[1], # key
|
||||
node.input[2], # value
|
||||
node.input[4], # 2D mask
|
||||
node.input[5], # relative_position_bias
|
||||
]
|
||||
|
||||
if len(node.input) > 6:
|
||||
nis.extend([node.input[6]]) # past_key
|
||||
nis.extend([node.input[7]]) # past_value
|
||||
else:
|
||||
nis.extend([""]) # past_key
|
||||
nis.extend([""]) # past_value
|
||||
|
||||
nis.extend([node.input[4] if len(node.input) > 4 else ""]) # 2D mask
|
||||
nis.extend([node.input[5] if len(node.input) > 5 else ""]) # relative_position_bias
|
||||
nis.extend([node.input[6] if len(node.input) > 6 else ""]) # past_key
|
||||
nis.extend([node.input[7] if len(node.input) > 7 else ""]) # past_value
|
||||
nis.extend(["past_sequence_length"]) # past_sequence_length
|
||||
nis.extend(["beam_width"]) # beam_width
|
||||
nis.extend(["cache_indirection"]) # cache_indirection
|
||||
nis.extend([node.input[3] if len(node.input) > 3 else ""]) # bias
|
||||
|
||||
kwargs["past_present_share_buffer"] = 1
|
||||
|
||||
|
|
@ -1287,10 +1367,15 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
|
|||
"DecoderMaskedMultiHeadAttention", nis, node.output, name=node.name, **kwargs
|
||||
)
|
||||
|
||||
new_nodes.extend([node])
|
||||
if node not in nodes_to_remove:
|
||||
for index, name in enumerate(node.input):
|
||||
if name in tensor_names_to_rename:
|
||||
node.input[index] = target_squeezed_past_seq_name
|
||||
new_nodes.extend([node])
|
||||
|
||||
subg.ClearField("node")
|
||||
subg.node.extend(new_nodes)
|
||||
orig_input_names = [inp.name for inp in subg.input]
|
||||
|
||||
new_inputs = []
|
||||
for i, vi in enumerate(subg.input):
|
||||
|
|
@ -1302,15 +1387,20 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
|
|||
shape=[shape[0], shape[1], "max_seq_len", shape[3]],
|
||||
)
|
||||
new_inputs.extend([vi])
|
||||
new_inputs.extend([onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])])
|
||||
new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
|
||||
new_inputs.extend(
|
||||
[
|
||||
onnx.helper.make_tensor_value_info(
|
||||
"cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
|
||||
)
|
||||
]
|
||||
)
|
||||
if "past_sequence_length" not in orig_input_names:
|
||||
new_inputs.extend(
|
||||
[onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])]
|
||||
)
|
||||
if "beam_width" not in orig_input_names:
|
||||
new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
|
||||
if "cache_indirection" not in orig_input_names:
|
||||
new_inputs.extend(
|
||||
[
|
||||
onnx.helper.make_tensor_value_info(
|
||||
"cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
|
||||
)
|
||||
]
|
||||
)
|
||||
subg.ClearField("input")
|
||||
subg.input.extend(new_inputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -216,11 +216,12 @@ def export_onnx_models(
|
|||
)
|
||||
config = models["decoder"].config
|
||||
|
||||
if (not use_external_data_format) and (config.num_layers > 24):
|
||||
if (not use_external_data_format) and (config.num_hidden_layers > 24):
|
||||
logger.info("Try use_external_data_format when model size > 2GB")
|
||||
|
||||
output_paths = []
|
||||
for name, model in models.items():
|
||||
print(f"========> Handling {name} model......")
|
||||
model.to(device)
|
||||
filename_suffix = "_" + name
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,11 @@ from onnx import TensorProto, helper
|
|||
from transformers import WhisperConfig
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
from convert_generation import get_shared_initializers # noqa: E402
|
||||
from benchmark_helper import Precision # noqa: E402
|
||||
from convert_generation import ( # noqa: E402
|
||||
get_shared_initializers,
|
||||
update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
|
||||
)
|
||||
|
||||
|
||||
def chain_model(args):
|
||||
|
|
@ -54,8 +58,12 @@ def chain_model(args):
|
|||
)
|
||||
|
||||
# beam graph inputs
|
||||
float_data_type = TensorProto.FLOAT
|
||||
if args.precision != Precision.FLOAT32:
|
||||
float_data_type = TensorProto.FLOAT16
|
||||
|
||||
input_features = helper.make_tensor_value_info(
|
||||
"input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
|
||||
"input_features", float_data_type, ["batch_size", "feature_size", "sequence_length"]
|
||||
)
|
||||
max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
|
||||
min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
|
||||
|
|
@ -89,6 +97,12 @@ def chain_model(args):
|
|||
)
|
||||
graph_outputs = [sequences]
|
||||
|
||||
if hasattr(args, "use_gpu") and args.use_gpu:
|
||||
if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
|
||||
print("*****update whisper decoder subgraph successfully!!!*****")
|
||||
else:
|
||||
print("*****DecoderMaskedMultiHeadAttention is not applied to whisper decoder*****")
|
||||
|
||||
# Initializers/opsets
|
||||
# Delete shared data between decoder/encoder and move to larger graph initializers
|
||||
initializers = get_shared_initializers(encoder_model, decoder_model)
|
||||
|
|
@ -98,6 +112,7 @@ def chain_model(args):
|
|||
helper.make_attribute("encoder", encoder_model.graph),
|
||||
]
|
||||
)
|
||||
|
||||
opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]
|
||||
|
||||
beam_graph = helper.make_graph([node], "beam-search-test", graph_inputs, graph_outputs, initializers)
|
||||
|
|
|
|||
|
|
@ -946,6 +946,7 @@ class OnnxModel:
|
|||
|
||||
sorted_node_set_len = -1
|
||||
graph_nodes = graph.node if not is_deterministic else sorted(graph.node, key=lambda x: x.name)
|
||||
|
||||
last_node_name = None
|
||||
while len(sorted_node_set) != len(graph_nodes):
|
||||
if len(sorted_node_set) == sorted_node_set_len:
|
||||
|
|
@ -959,7 +960,8 @@ class OnnxModel:
|
|||
sorted_nodes.append(node)
|
||||
sorted_node_set.add(node_idx)
|
||||
for output in node.output:
|
||||
deps_set.add(output)
|
||||
if output:
|
||||
deps_set.add(output)
|
||||
continue
|
||||
failed = False
|
||||
for input_name in node.input:
|
||||
|
|
@ -970,7 +972,8 @@ class OnnxModel:
|
|||
sorted_nodes.append(node)
|
||||
sorted_node_set.add(node_idx)
|
||||
for output in node.output:
|
||||
deps_set.add(output)
|
||||
if output:
|
||||
deps_set.add(output)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue