Add FP16 support for Whisper model (#15427)

Currently, ORT can only run inference on the FP32 Whisper model. This PR
adds FP16 support.
This commit is contained in:
stevenlix 2023-04-08 21:36:10 -07:00 committed by GitHub
parent 34f22daf25
commit 6d126f8996
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 55 additions and 14 deletions

View file

@ -466,9 +466,9 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(float)</dt>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain to float tensors.</dd>
<dt><tt>F</tt> : tensor(float), tensor(int32)</dt>
<dt><tt>F</tt> : tensor(float), tensor(int32), tensor(float16)</dt>
<dd>Constrain input type to float or int tensors.</dd>
<dt><tt>I</tt> : tensor(int32)</dt>
<dd>Constrain to integer types</dd>

View file

@ -319,7 +319,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs<float>,
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
@ -340,7 +340,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
init_beam_state_fp16_func_,
device_copy_func_,
device_copy_int32_func_,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs<MLFloat16>,
update_decoder_feeds_fp16_func_,
expand_buffer_int32_func_,
expand_buffer_float_func_,

View file

@ -66,10 +66,28 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
"num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")");
auto* length_penalty_tensor = context->Input<Tensor>(5);
length_penalty = length_penalty_tensor ? static_cast<float>(*length_penalty_tensor->Data<float>()) : 1;
if (length_penalty_tensor) {
if (length_penalty_tensor->DataType() == DataTypeImpl::GetType<float>()) {
length_penalty = static_cast<float>(*length_penalty_tensor->Data<float>());
}
else {
length_penalty = static_cast<MLFloat16>(*length_penalty_tensor->Data<MLFloat16>());
}
} else {
length_penalty = 1.0f;
}
auto* repetition_penalty_tensor = context->Input<Tensor>(6);
repetition_penalty = repetition_penalty_tensor ? static_cast<float>(*repetition_penalty_tensor->Data<float>()) : 1.0f;
if (repetition_penalty_tensor) {
if (repetition_penalty_tensor->DataType() == DataTypeImpl::GetType<float>()) {
repetition_penalty = static_cast<float>(*repetition_penalty_tensor->Data<float>());
}
else {
repetition_penalty = static_cast<MLFloat16>(*repetition_penalty_tensor->Data<MLFloat16>());
}
} else {
repetition_penalty = 1.0f;
}
ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);
}

View file

@ -873,7 +873,7 @@ Status UpdateDecoderFeeds(
//------------------------------------------------
// Modified Encoder functions for Whisper Model
//------------------------------------------------
template <typename T>
Status CreateWhisperEncoderInputs(
const Tensor* original_encoder_input_features,
const OrtValue* attn_mask_value,
@ -895,9 +895,9 @@ Status CreateWhisperEncoderInputs(
// Current shape is (batch_size, sequence_length)
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
// To avoid cloning input_ids, we use const_cast here since this function does not change its content.
Tensor::InitOrtValue(DataTypeImpl::GetType<float>(),
Tensor::InitOrtValue(DataTypeImpl::GetType<T>(),
input_features_shape,
const_cast<Tensor*>(original_encoder_input_features)->MutableData<float>(),
const_cast<Tensor*>(original_encoder_input_features)->MutableData<T>(),
allocator->Info(),
encoder_input_features);
@ -1068,6 +1068,25 @@ template Status ExpandBuffer<MLFloat16>(
bool only_copy_shape,
int max_sequence_length);
template Status CreateWhisperEncoderInputs<float>(
const Tensor* original_encoder_input_features,
const OrtValue* attn_mask_value,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& encoder_input_features,
OrtValue& encoder_attention_mask,
OrtValue& decoder_input_ids);
template Status CreateWhisperEncoderInputs<MLFloat16>(
const Tensor* original_encoder_input_features,
const OrtValue* attn_mask_value,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& encoder_input_features,
OrtValue& encoder_attention_mask,
OrtValue& decoder_input_ids);
} // namespace GenerationCpuDeviceHelper
} // namespace contrib
} // namespace onnxruntime

View file

@ -318,7 +318,7 @@ Status UpdateDecoderFeeds(
// ---------------------------------------------------------------
// Functions for encoder-decoder model with float input like Whisper
// ---------------------------------------------------------------
template <typename T>
Status CreateWhisperEncoderInputs(
const Tensor* original_encoder_input_features,
const OrtValue* attn_mask_value,

View file

@ -74,8 +74,8 @@ Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgr
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type,
"encoder subgraph input 0 (encoder_input_features) shall have float32 type");
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type && subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float16_type,
"encoder subgraph input 0 (encoder_input_features) shall have float32 or float16 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"encoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,

View file

@ -211,6 +211,10 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
memcpy(destination, input.Get<Tensor>().Data<int32_t>(), bytes);
} else if (dataType == DataTypeImpl::GetType<int64_t>()) {
memcpy(destination, input.Get<Tensor>().Data<int64_t>(), bytes);
} else if (dataType == DataTypeImpl::GetType<float>()) {
memcpy(destination, input.Get<Tensor>().Data<float>(), bytes);
} else if (dataType == DataTypeImpl::GetType<MLFloat16>()) {
memcpy(destination, input.Get<Tensor>().Data<MLFloat16>(), bytes);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"AddToFeeds: An implementation for the input type ",

View file

@ -1102,8 +1102,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
"Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam."
"Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)",
"T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)"}, "Constrain to float tensors.")
.TypeConstraint("F", {"tensor(float)", "tensor(int32)"}, "Constrain input type to float or int tensors.")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.")
.TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.")
.TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {