optimization for whisper model with decoder masked multihead attention (#15827)

* graph tools update
* cuda kernel update
* operator spec update and implementation update
* greedy search bug fix on a wrong assumption about cross/self attention input length
* avoid use of "" names in value_info when loading a graph, which historically appear in many models
Zhang Lei 2023-05-18 15:38:31 -07:00 committed by GitHub
parent be6c0bb53c
commit 0f8e66d905
15 changed files with 246 additions and 72 deletions


@@ -1134,7 +1134,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
</dl>
#### Inputs (1 - 10)
#### Inputs (1 - 11)
<dl>
<dt><tt>query</tt> : T</dt>
@@ -1157,6 +1157,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
</dl>
#### Outputs (1 - 3)
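For reference, a minimal sketch (not part of this commit) of how a graph-building script could wire the new optional bias input (index 10) into a DecoderMaskedMultiHeadAttention node via onnx.helper; all tensor names and attribute values below are hypothetical:

import onnx
from onnx import helper

# Empty strings skip optional inputs that are not supplied.
dmmha_node = helper.make_node(
    "DecoderMaskedMultiHeadAttention",
    inputs=[
        "query", "key", "value",
        "",                          # mask_index (optional)
        "",                          # relative_position_bias (optional)
        "past_key", "past_value",
        "past_sequence_length", "beam_width", "cache_indirection",
        "qkv_bias",                  # new optional bias, shape (hidden_size + hidden_size + v_hidden_size)
    ],
    outputs=["output", "present_key", "present_value"],
    name="dmmha_0",
    domain="com.microsoft",
    num_heads=8,
    past_present_share_buffer=1,
)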


@@ -802,7 +802,7 @@ Do not modify directly.*
|ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderMaskedSelfAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(float16)|
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|


@@ -62,6 +62,14 @@ struct BeamSearchState : IBeamSearchState<T> {
}
}
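// Grow-only scratch buffer for past-state reordering: reallocate only when the requested
// size exceeds the current capacity; smaller requests reuse the existing allocation.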
void EnsurePastStateReorderStagingBuffer(AllocatorPtr allocator, int64_t sz) {
auto current_buffer_size = this->staging_for_past_state_reorder.Shape().Size();
if (sz > current_buffer_size) {
TensorShape buffer_shape = {sz};
this->staging_for_past_state_reorder = Tensor(DataTypeImpl::GetType<T>(), buffer_shape, allocator);
}
}
private:
BufferUniquePtr next_token_logits_buffer_;
BufferUniquePtr next_token_scores_buffer_;


@@ -254,6 +254,11 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
if (decoder_subgraph_.has_decoder_masked_attention_) {
size_t offset = static_cast<size_t>(decoder_subgraph_.GetFirstPastInputIndex());
// Need to check the cross attention past key tensor size; assume all layers' cross attention keys have the same size
auto first_cross_attention_key = decoder_feeds[offset + 2 * static_cast<size_t>(decoder_subgraph_.num_layers)].GetMutable<Tensor>();
auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);
// Here we only need to reorder the past key for self-attention and cross-attention.
for (size_t i = 0; i < 2 * static_cast<size_t>(decoder_subgraph_.num_layers); ++i) {
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,


@@ -245,6 +245,11 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
if (decoder_subgraph_.has_decoder_masked_attention_) {
size_t offset = static_cast<size_t>(decoder_subgraph_.GetFirstPastInputIndex());
// Need to check the cross attention past key tensor size; assume all layers' cross attention keys have the same size
auto first_cross_attention_key = decoder_feeds[offset + 2 * static_cast<size_t>(decoder_subgraph_.num_layers)].GetMutable<Tensor>();
auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);
// Here we only need to reorder the past key for self-attention and cross-attention.
for (size_t i = 0; i < 2 * static_cast<size_t>(decoder_subgraph_.num_layers); ++i) {
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,


@@ -106,9 +106,9 @@ Status WhisperDecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgr
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 0 (input_ids) shall have int32 type");
auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
auto float_type = subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
"decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
"decoder subgraph input 1 (encoder_hidden_states) shall have float or float16 type");
for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) {
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
@@ -202,6 +202,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds(
// When first_past_input_index_ == 2, the encoder_hidden_states and past states are copied from the second output
// of the encoder.
// When first_past_input_index_ == 1, the past states are copied from the second output of the encoder.
// TODO: MAKE IT MORE READABLE
for (size_t j = static_cast<size_t>(3) - first_past_input_index_; j < encoder_fetches.size(); j++) {
if (j == 1) {
ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expansion: has_hidden_state_ == false");
@@ -226,7 +227,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds(
decoder_feeds.push_back(expanded_hidden_states);
} else {
// past key/value for cross attention does not need to be initialized with max_seq_len since they are static.
bool use_max_seq_len = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
bool use_max_seq_len = (j - first_past_input_index_) <= 2 * static_cast<size_t>(num_layers);
OrtValue expanded_cache;
if (is_output_float16_) {
@@ -252,7 +253,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds(
if (past_present_share_buffer_) {
// Past sequence length feed
ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 1));
ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, cur_len - 1));
// Add beam search specific inputs
if (need_cache_indir) {
const int64_t batch_size = static_cast<int64_t>(batch_beam_size / num_beam);


@@ -22,6 +22,7 @@ static constexpr int kBeamWidthInputIndex = 8;
static constexpr int kCacheIndirectionInputIndex = 9;
static constexpr int kPastInputIndex = 5;
static constexpr int kPresentOutputIndex = 1;
static constexpr int kBiasIndex = 10;
#define REGISTER_KERNEL_TYPED(T1, T2) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
@@ -63,6 +64,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
const Tensor* beam_width = context->Input<Tensor>(kBeamWidthInputIndex);
const Tensor* cache_indir = context->Input<Tensor>(kCacheIndirectionInputIndex);
const Tensor* bias = context->Input<Tensor>(kBiasIndex);
auto& device_prop = GetDeviceProp();
DecoderMaskedMultiHeadAttentionParams parameters;
@@ -70,7 +72,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
value,
nullptr, // bias
bias,
mask_index,
relative_position_bias,
past_key,
@@ -84,6 +86,13 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
is_dmmha_packing, // dmmha_packing
device_prop.maxThreadsPerBlock));
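// The combined input-projection bias is laid out as (hidden_size + hidden_size + v_hidden_size):
// Q bias at offset 0, K bias at hidden_size, V bias at 2 * hidden_size.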
if (bias) {
const T1* bias_data = bias->Data<T1>();
parameters.q_bias = const_cast<T1*>(bias_data);
parameters.k_bias = const_cast<T1*>(bias_data + parameters.hidden_size);
parameters.v_bias = const_cast<T1*>(bias_data + 2LL * parameters.hidden_size);
}
int batch_size = parameters.batch_size;
int sequence_length = parameters.sequence_length;
@@ -142,6 +151,9 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
// parameters.k and parameters.v are nullptr
parameters.k_cache = const_cast<T1*>(key->Data<T1>());
parameters.v_cache = const_cast<T1*>(value->Data<T1>());
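// Cross attention reads K/V directly from the encoder-derived caches, which are already
// fully projected, so the K/V biases must not be applied a second time.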
parameters.k_bias = nullptr;
parameters.v_bias = nullptr;
} else {
// Sanity check
ORT_ENFORCE(past_present_share_buffer_);


@@ -161,40 +161,19 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.q)[qk_offset]));
}
Qk_vec_k k;
if (!params.is_cross_attention) {
zero(k);
if (!is_masked) {
k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k)[qk_offset]));
}
}
// The offset in the bias buffer.
int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE;
// Trigger the loads from the Q and K bias buffers.
Qk_vec_k q_bias;
Qk_vec_k k_bias;
if (!params.is_mha) {
// The offset in the bias buffer.
int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE;
if (params.q_bias && !is_masked) {
Qk_vec_k q_bias;
zero(q_bias);
q_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.q_bias)[qk_bias_offset]));
if (!is_masked) {
q_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.q_bias)[qk_bias_offset]));
}
zero(k_bias);
if (!is_masked) {
k_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k_bias)[qk_bias_offset]));
}
// Computes the Q/K values with bias.
q = add_vec(q, q_bias);
k = add_vec(k, k_bias);
}
T* params_k_cache = reinterpret_cast<T*>(params.k_cache);
const float inv_sqrt_dh = params.scale;
@@ -205,6 +184,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
}
if (!params.is_cross_attention) {
Qk_vec_k k;
zero(k);
if (!is_masked) {
k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k)[qk_offset]));
if (params.k_bias) {
Qk_vec_k k_bias;
k_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k_bias)[qk_bias_offset]));
k = add_vec(k, k_bias);
}
}
if (!is_masked) {
// Write the K values to the global memory cache.
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
@@ -438,7 +433,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
// One group of threads computes the product(s) for the current timestep.
V_vec_k v_bias;
if (!params.is_mha) {
if (params.v_bias && !params.is_cross_attention) {
zero(v_bias);
T* params_v_bias = reinterpret_cast<T*>(params.v_bias);
@@ -478,7 +473,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
V_vec_k v;
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&reinterpret_cast<T*>(params.v)[v_offset]));
if (!params.is_mha) {
if (params.v_bias) {
v = add_vec(v, v_bias);
}


@@ -676,6 +676,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration",
"M",
OpSchema::Optional)
.Input(10,
"bias",
"Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection",
"T",
OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, v_hidden_size)",


@@ -1301,7 +1301,9 @@ Graph::Graph(const Model& owning_model,
for (auto& node_arg : graph_proto_->value_info()) {
if (utils::HasName(node_arg) && utils::HasType(node_arg)) {
name_to_type_map[node_arg.name()] = node_arg.type();
if (node_arg.name().size() > 0) {
name_to_type_map[node_arg.name()] = node_arg.type();
}
}
}
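The same empty-name artifact can also be scrubbed from a model offline; a small sketch using the onnx python API, where the file paths are placeholders:

import onnx

model = onnx.load("decoder.onnx")
kept = [vi for vi in model.graph.value_info if vi.name]
del model.graph.value_info[:]  # clear the repeated field in place
model.graph.value_info.extend(kept)
onnx.save(model, "decoder.cleaned.onnx")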


@@ -195,6 +195,7 @@ class SymbolicShapeInference:
"RestorePadding": self._infer_RestorePadding,
"BiasGelu": self._infer_BiasGelu,
"MultiHeadAttention": self._infer_MultiHeadAttention,
"DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
"EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
"FastGelu": self._infer_FastGelu,
"Gelu": self._infer_Gelu,
@@ -2233,10 +2234,33 @@ class SymbolicShapeInference:
present_shape = [batch_size, num_heads, total_sequence_length, head_size]
assert output_dtype is not None
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
vi = self.known_vi_[node.output[2]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
if len(node.output) > 2 and node.output[1] and node.output[2]:
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
vi = self.known_vi_[node.output[2]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
def _infer_DecoderMaskedMultiHeadAttention(self, node): # noqa: N802
# Output 0 has shape (batch_size, 1, v_hidden_size)
# Q, K and V without packing:
# Input 0 (query) has shape (batch_size, 1, hidden_size)
# Input 5 (past_key) if exists has shape (batch_size, num_heads, max_sequence_length, head_size)
query_shape = self._get_shape(node, 0)
if query_shape is not None:
output_shape = query_shape
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
assert output_dtype is not None
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
if len(node.output) > 2 and node.output[1] and node.output[2]:
past_shape = self._try_get_shape(node, 5)
if past_shape is not None:
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
vi = self.known_vi_[node.output[2]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
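To exercise the new inference path end to end, the tool can be driven programmatically; a sketch assuming the standard onnxruntime package layout and a hypothetical model path:

import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("whisper_decoder.onnx")  # hypothetical model containing DecoderMaskedMultiHeadAttention
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "whisper_decoder_shaped.onnx")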
def _infer_FastGelu(self, node): # noqa: N802
self._propagate_shape_and_type(node)
@@ -2659,10 +2683,16 @@ class SymbolicShapeInference:
logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
logger.debug("node inputs:")
for i in node.input:
logger.debug(self.known_vi_[i])
if i in self.known_vi_:
logger.debug(self.known_vi_[i])
else:
logger.debug(f"not in knwon_vi_ for {i}")
logger.debug("node outputs:")
for o in node.output:
logger.debug(self.known_vi_[o])
if o in self.known_vi_:
logger.debug(self.known_vi_[o])
else:
logger.debug(f"not in knwon_vi_ for {o}")
if self.auto_merge_ and not out_type_undefined:
logger.debug("Merging: " + str(self.suggested_merge_))
return False


@@ -1207,8 +1207,72 @@ def update_decoder_subgraph_use_decoder_masked_attention(
return True
def find_past_seq_len_usage(subg: GraphProto):
"""Correct graph which originally use dim of past_seq_len from input_ids's shape which is fixed to max_seq_len after
shared past/present buffer
Args:
subg (GraphProto): GraphProto of the decoder subgraph
return:
tensor_names_to_rename : set of tensor names which is equal to past_sequence_length
nodes_to_remove : list of node to remove
"""
tensor_names_to_rename = set()
nodes_to_remove = []
graph_input_names = {inp.name: index for index, inp in enumerate(subg.input)}
input_name_to_nodes = {}
output_name_to_node = {}
for node in subg.node:
for input_name in node.input:
if input_name:
if input_name not in input_name_to_nodes:
input_name_to_nodes[input_name] = [node]
else:
input_name_to_nodes[input_name].append(node)
for output_name in node.output:
if output_name:
output_name_to_node[output_name] = node
for node in subg.node:
# find "Shape(past_key_self..) --> Gather(*, 2)"
if node.op_type == "Gather":
if not node.input[1] or not node.input[0]:
continue
shape_tensor_name, shape_index_name = (node.input[0], node.input[1])
ini_gather_indices = None
for tensor in subg.initializer:
if tensor.name == shape_index_name:
ini_gather_indices = tensor
break
if ini_gather_indices is None:
continue
gather_indices_arr = onnx.numpy_helper.to_array(ini_gather_indices)
if gather_indices_arr.size == 1 and gather_indices_arr.item() == 2 and node.input[0] in output_name_to_node:
shape_node = output_name_to_node[shape_tensor_name]
if (
shape_node.op_type == "Shape"
and shape_node.input[0]
and shape_node.input[0] in graph_input_names
and (
shape_node.input[0].startswith("past_key_self_")
or shape_node.input[0].startswith("past_value_self_")
)
):
tensor_names_to_rename.add(node.output[0])
nodes_to_remove.append(node)
if len(input_name_to_nodes[shape_node.output[0]]) == 1:
nodes_to_remove.append(shape_node)
return tensor_names_to_rename, nodes_to_remove
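The subgraph pattern this helper matches can be reproduced in isolation; a minimal sketch (hypothetical tensor names, run in the same module so find_past_seq_len_usage is in scope) that builds the Shape -> Gather(indices=2) chain over a past_key_self_* input:

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

past = helper.make_tensor_value_info(
    "past_key_self_0", TensorProto.FLOAT, ["batch", 8, "max_seq_len", 64]
)
gather_idx = numpy_helper.from_array(np.array(2, dtype=np.int64), name="gather_idx")
shape_node = helper.make_node("Shape", ["past_key_self_0"], ["pk_shape"], name="shape_0")
gather_node = helper.make_node("Gather", ["pk_shape", "gather_idx"], ["past_len"], name="gather_0")
toy_graph = helper.make_graph(
    [shape_node, gather_node], "toy", [past], [], initializer=[gather_idx]
)

names, removable = find_past_seq_len_usage(toy_graph)
# Expect names == {"past_len"}; both nodes are flagged for removal since the
# Shape output feeds only the Gather.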
def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphProto):
input_self_past_0 = 2
input_self_past_0 = 1
# w/wo attention mask, w/wo hidden_state
graph_input_names = [gi.name for gi in subg.input]
while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
input_self_past_0 += 1
output_self_past_0 = 1
num_layers = int((len(subg.input) - input_self_past_0) / 4)
@@ -1232,9 +1296,6 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
rel_pos_bias_node = node
break
if rel_pos_bias_node is None:
return False
decoder_masked_attention_supported_attr = [
"past_present_share_buffer",
"num_heads",
@@ -1243,8 +1304,31 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
"domain",
]
target_squeezed_past_seq_name = "past_sequence_length_squeezed_int64"
tensor_names_to_rename, nodes_to_remove = find_past_seq_len_usage(subg)
if len(tensor_names_to_rename) > 0:
for name_to_rename in tensor_names_to_rename:
print(f"Found tensor name {name_to_rename} to be renamed to {target_squeezed_past_seq_name}")
for nr in nodes_to_remove:
print(f"Found node to removed: type:{nr.op_type}, name:{nr.name}")
squeeze_node = onnx.helper.make_node(
"Squeeze",
["past_sequence_length"],
["past_sequence_length_squeezed"],
name="node_past_sequence_length_squeeze",
)
cast_node = onnx.helper.make_node(
"Cast",
["past_sequence_length_squeezed"],
[target_squeezed_past_seq_name],
name="node_past_sequence_length_squeeze_cast",
to=TensorProto.INT64,
)
new_nodes.extend([squeeze_node, cast_node])
for node in subg.node:
if len(node.output) > 0 and node.output[0] == rel_pos_bias_node.input[1]:
if len(node.output) > 0 and rel_pos_bias_node is not None and node.output[0] == rel_pos_bias_node.input[1]:
cast_node = onnx.helper.make_node(
"Cast",
["past_sequence_length"],
@@ -1266,20 +1350,16 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
node.input[0], # query
node.input[1], # key
node.input[2], # value
node.input[4], # 2D mask
node.input[5], # relative_position_bias
]
if len(node.input) > 6:
nis.extend([node.input[6]]) # past_key
nis.extend([node.input[7]]) # past_value
else:
nis.extend([""]) # past_key
nis.extend([""]) # past_value
nis.extend([node.input[4] if len(node.input) > 4 else ""]) # 2D mask
nis.extend([node.input[5] if len(node.input) > 5 else ""]) # relative_position_bias
nis.extend([node.input[6] if len(node.input) > 6 else ""]) # past_key
nis.extend([node.input[7] if len(node.input) > 7 else ""]) # past_value
nis.extend(["past_sequence_length"]) # past_sequence_length
nis.extend(["beam_width"]) # beam_width
nis.extend(["cache_indirection"]) # cache_indirection
nis.extend([node.input[3] if len(node.input) > 3 else ""]) # bias
kwargs["past_present_share_buffer"] = 1
@@ -1287,10 +1367,15 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
"DecoderMaskedMultiHeadAttention", nis, node.output, name=node.name, **kwargs
)
new_nodes.extend([node])
if node not in nodes_to_remove:
for index, name in enumerate(node.input):
if name in tensor_names_to_rename:
node.input[index] = target_squeezed_past_seq_name
new_nodes.extend([node])
subg.ClearField("node")
subg.node.extend(new_nodes)
orig_input_names = [inp.name for inp in subg.input]
new_inputs = []
for i, vi in enumerate(subg.input):
@@ -1302,15 +1387,20 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphP
shape=[shape[0], shape[1], "max_seq_len", shape[3]],
)
new_inputs.extend([vi])
new_inputs.extend([onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])])
new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
new_inputs.extend(
[
onnx.helper.make_tensor_value_info(
"cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
)
]
)
if "past_sequence_length" not in orig_input_names:
new_inputs.extend(
[onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])]
)
if "beam_width" not in orig_input_names:
new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
if "cache_indirection" not in orig_input_names:
new_inputs.extend(
[
onnx.helper.make_tensor_value_info(
"cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
)
]
)
subg.ClearField("input")
subg.input.extend(new_inputs)


@@ -216,11 +216,12 @@ def export_onnx_models(
)
config = models["decoder"].config
if (not use_external_data_format) and (config.num_layers > 24):
if (not use_external_data_format) and (config.num_hidden_layers > 24):
logger.info("Try use_external_data_format when model size > 2GB")
output_paths = []
for name, model in models.items():
print(f"========> Handling {name} model......")
model.to(device)
filename_suffix = "_" + name


@@ -6,7 +6,11 @@ from onnx import TensorProto, helper
from transformers import WhisperConfig
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from convert_generation import get_shared_initializers # noqa: E402
from benchmark_helper import Precision # noqa: E402
from convert_generation import ( # noqa: E402
get_shared_initializers,
update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
)
def chain_model(args):
@@ -54,8 +58,12 @@ def chain_model(args):
)
# beam graph inputs
float_data_type = TensorProto.FLOAT
if args.precision != Precision.FLOAT32:
float_data_type = TensorProto.FLOAT16
input_features = helper.make_tensor_value_info(
"input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
"input_features", float_data_type, ["batch_size", "feature_size", "sequence_length"]
)
max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
@@ -89,6 +97,12 @@
)
graph_outputs = [sequences]
if hasattr(args, "use_gpu") and args.use_gpu:
if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
print("*****update whisper decoder subgraph successfully!!!*****")
else:
print("*****DecoderMaskedMultiHeadAttention is not applied to whisper decoder*****")
# Initializers/opsets
# Delete shared data between decoder/encoder and move to larger graph initializers
initializers = get_shared_initializers(encoder_model, decoder_model)
@@ -98,6 +112,7 @@
helper.make_attribute("encoder", encoder_model.graph),
]
)
opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]
beam_graph = helper.make_graph([node], "beam-search-test", graph_inputs, graph_outputs, initializers)
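Presumably the opset list is consumed when the graph is wrapped into a model right after this point; a sketch of that final step (assuming onnx is imported at module level, with a hypothetical output path):

beam_model = helper.make_model(beam_graph, opset_imports=opset_import)
onnx.save(beam_model, "whisper_beamsearch.onnx")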


@@ -946,6 +946,7 @@ class OnnxModel:
sorted_node_set_len = -1
graph_nodes = graph.node if not is_deterministic else sorted(graph.node, key=lambda x: x.name)
last_node_name = None
while len(sorted_node_set) != len(graph_nodes):
if len(sorted_node_set) == sorted_node_set_len:
@@ -959,7 +960,8 @@
sorted_nodes.append(node)
sorted_node_set.add(node_idx)
for output in node.output:
deps_set.add(output)
if output:
deps_set.add(output)
continue
failed = False
for input_name in node.input:
@@ -970,7 +972,8 @@
sorted_nodes.append(node)
sorted_node_set.add(node_idx)
for output in node.output:
deps_set.add(output)
if output:
deps_set.add(output)
else:
continue