diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index ed70ac5971..f22190e029 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -466,9 +466,9 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float)
+- T : tensor(float), tensor(float16)
- Constrain to float tensors.
-- F : tensor(float), tensor(int32)
+- F : tensor(float), tensor(int32), tensor(float16)
- Constrain input type to float or int tensors.
- I : tensor(int32)
- Constrain to integer types
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index 7bc288a034..5159a86612 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -319,7 +319,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy,
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy,
- create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
+ create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds,
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer,
@@ -340,7 +340,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
init_beam_state_fp16_func_,
device_copy_func_,
device_copy_int32_func_,
- create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
+ create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
update_decoder_feeds_fp16_func_,
expand_buffer_int32_func_,
expand_buffer_float_func_,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index f79f9b1dbf..0f3cbdf488 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -66,10 +66,28 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
"num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")");
auto* length_penalty_tensor = context->Input(5);
- length_penalty = length_penalty_tensor ? static_cast(*length_penalty_tensor->Data()) : 1;
+ if (length_penalty_tensor) {
+ if (length_penalty_tensor->DataType() == DataTypeImpl::GetType()) {
+ length_penalty = static_cast(*length_penalty_tensor->Data());
+ }
+ else {
+ length_penalty = static_cast(*length_penalty_tensor->Data());
+ }
+ } else {
+ length_penalty = 1.0f;
+ }
auto* repetition_penalty_tensor = context->Input(6);
- repetition_penalty = repetition_penalty_tensor ? static_cast(*repetition_penalty_tensor->Data()) : 1.0f;
+ if (repetition_penalty_tensor) {
+ if (repetition_penalty_tensor->DataType() == DataTypeImpl::GetType()) {
+ repetition_penalty = static_cast(*repetition_penalty_tensor->Data());
+ }
+ else {
+ repetition_penalty = static_cast(*repetition_penalty_tensor->Data());
+ }
+ } else {
+ repetition_penalty = 1.0f;
+ }
ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
index cf98886e9d..76630a30df 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
@@ -873,7 +873,7 @@ Status UpdateDecoderFeeds(
//------------------------------------------------
// Modified Encoder functions for Whisper Model
//------------------------------------------------
-
+template
Status CreateWhisperEncoderInputs(
const Tensor* original_encoder_input_features,
const OrtValue* attn_mask_value,
@@ -895,9 +895,9 @@ Status CreateWhisperEncoderInputs(
// Current shape is (batch_size, sequence_length)
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
// To avoid cloning input_ids, we use const_cast here since this function does not change its content.
- Tensor::InitOrtValue(DataTypeImpl::GetType(),
+ Tensor::InitOrtValue(DataTypeImpl::GetType(),
input_features_shape,
- const_cast(original_encoder_input_features)->MutableData(),
+ const_cast(original_encoder_input_features)->MutableData(),
allocator->Info(),
encoder_input_features);
@@ -1068,6 +1068,25 @@ template Status ExpandBuffer(
bool only_copy_shape,
int max_sequence_length);
+template Status CreateWhisperEncoderInputs<float>(  // explicit instantiation for float32 encoder_input_features
+ const Tensor* original_encoder_input_features,
+ const OrtValue* attn_mask_value,
+ int pad_token_id,
+ int start_token_id,
+ AllocatorPtr allocator,
+ OrtValue& encoder_input_features,
+ OrtValue& encoder_attention_mask,
+ OrtValue& decoder_input_ids);
+
+template Status CreateWhisperEncoderInputs<MLFloat16>(  // explicit instantiation for float16 encoder_input_features (NOTE(review): MLFloat16 assumed as the fp16 element type - confirm)
+ const Tensor* original_encoder_input_features,
+ const OrtValue* attn_mask_value,
+ int pad_token_id,
+ int start_token_id,
+ AllocatorPtr allocator,
+ OrtValue& encoder_input_features,
+ OrtValue& encoder_attention_mask,
+ OrtValue& decoder_input_ids);
} // namespace GenerationCpuDeviceHelper
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
index 3a70bb1926..ce6c35c5f3 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
@@ -318,7 +318,7 @@ Status UpdateDecoderFeeds(
// ---------------------------------------------------------------
// Functions for encoder-decoder model with float input like Whisper
// ---------------------------------------------------------------
-
+template
Status CreateWhisperEncoderInputs(
const Tensor* original_encoder_input_features,
const OrtValue* attn_mask_value,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc
index 8b2dde9518..3d9a3a52df 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc
@@ -74,8 +74,8 @@ Status WhisperEncoderSubgraph::Validate(const std::vector& subgr
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
- ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type,
- "encoder subgraph input 0 (encoder_input_features) shall have float32 type");
+ ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type && subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float16_type,
+ "encoder subgraph input 0 (encoder_input_features) shall have float32 or float16 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"encoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
index 4ed8884227..1c37c14c83 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -211,6 +211,10 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
memcpy(destination, input.Get().Data(), bytes);
} else if (dataType == DataTypeImpl::GetType()) {
memcpy(destination, input.Get().Data(), bytes);
+ } else if (dataType == DataTypeImpl::GetType()) {
+ memcpy(destination, input.Get().Data(), bytes);
+ } else if (dataType == DataTypeImpl::GetType()) {
+ memcpy(destination, input.Get().Data(), bytes);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"AddToFeeds: An implementation for the input type ",
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 95bb6b5937..637600dcaf 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1102,8 +1102,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
"Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam."
"Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)",
"T", OpSchema::Optional)
- .TypeConstraint("T", {"tensor(float)"}, "Constrain to float tensors.")
- .TypeConstraint("F", {"tensor(float)", "tensor(int32)"}, "Constrain input type to float or int tensors.")
+ .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.")
+ .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.")
.TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {