diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index ed70ac5971..f22190e029 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -466,9 +466,9 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float)
+
T : tensor(float), tensor(float16)
Constrain to float tensors.
-
F : tensor(float), tensor(int32)
+
F : tensor(float), tensor(int32), tensor(float16)
Constrain input type to float or int tensors.
I : tensor(int32)
Constrain to integer types
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 7bc288a034..5159a86612 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -319,7 +319,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState, device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy, - create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs, + create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs, update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds, expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, @@ -340,7 +340,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { init_beam_state_fp16_func_, device_copy_func_, device_copy_int32_func_, - create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs, + create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs, update_decoder_feeds_fp16_func_, expand_buffer_int32_func_, expand_buffer_float_func_, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index f79f9b1dbf..0f3cbdf488 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -66,10 +66,28 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { "num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")"); auto* length_penalty_tensor = context->Input(5); - length_penalty = length_penalty_tensor ? static_cast(*length_penalty_tensor->Data()) : 1; + if (length_penalty_tensor) { + if (length_penalty_tensor->DataType() == DataTypeImpl::GetType()) { + length_penalty = static_cast(*length_penalty_tensor->Data()); + } + else { + length_penalty = static_cast(*length_penalty_tensor->Data()); + } + } else { + length_penalty = 1.0f; + } auto* repetition_penalty_tensor = context->Input(6); - repetition_penalty = repetition_penalty_tensor ? static_cast(*repetition_penalty_tensor->Data()) : 1.0f; + if (repetition_penalty_tensor) { + if (repetition_penalty_tensor->DataType() == DataTypeImpl::GetType()) { + repetition_penalty = static_cast(*repetition_penalty_tensor->Data()); + } + else { + repetition_penalty = static_cast(*repetition_penalty_tensor->Data()); + } + } else { + repetition_penalty = 1.0f; + } ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index cf98886e9d..76630a30df 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -873,7 +873,7 @@ Status UpdateDecoderFeeds( //------------------------------------------------ // Modified Encoder functions for Whisper Model //------------------------------------------------ - +template Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* attn_mask_value, @@ -895,9 +895,9 @@ Status CreateWhisperEncoderInputs( // Current shape is (batch_size, sequence_length) // Note that we will expand it to (batch_size * num_beams, sequence_length) later. // To avoid cloning input_ids, we use const_cast here since this function does not change its content. - Tensor::InitOrtValue(DataTypeImpl::GetType(), + Tensor::InitOrtValue(DataTypeImpl::GetType(), input_features_shape, - const_cast(original_encoder_input_features)->MutableData(), + const_cast(original_encoder_input_features)->MutableData(), allocator->Info(), encoder_input_features); @@ -1068,6 +1068,25 @@ template Status ExpandBuffer( bool only_copy_shape, int max_sequence_length); +template Status CreateWhisperEncoderInputs( + const Tensor* original_encoder_input_features, + const OrtValue* attn_mask_value, + int pad_token_id, + int start_token_id, + AllocatorPtr allocator, + OrtValue& encoder_input_features, + OrtValue& encoder_attention_mask, + OrtValue& decoder_input_ids); + +template Status CreateWhisperEncoderInputs( + const Tensor* original_encoder_input_features, + const OrtValue* attn_mask_value, + int pad_token_id, + int start_token_id, + AllocatorPtr allocator, + OrtValue& encoder_input_features, + OrtValue& encoder_attention_mask, + OrtValue& decoder_input_ids); } // namespace GenerationCpuDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 3a70bb1926..ce6c35c5f3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -318,7 +318,7 @@ Status UpdateDecoderFeeds( // --------------------------------------------------------------- // Functions for encoder-decoder model with float input like Whisper // --------------------------------------------------------------- - +template Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* attn_mask_value, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc index 8b2dde9518..3d9a3a52df 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc @@ -74,8 +74,8 @@ Status WhisperEncoderSubgraph::Validate(const std::vector& subgr constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT; constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16; - ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type, - "encoder subgraph input 0 (encoder_input_features) shall have float32 type"); + ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type && subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float16_type, + "encoder subgraph input 0 (encoder_input_features) shall have float32 or float16 type"); ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type, "encoder subgraph input 1 (encoder_attention_mask) shall have int32 type"); ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 4ed8884227..1c37c14c83 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -211,6 +211,10 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, memcpy(destination, input.Get().Data(), bytes); } else if (dataType == DataTypeImpl::GetType()) { memcpy(destination, input.Get().Data(), bytes); + } else if (dataType == DataTypeImpl::GetType()) { + memcpy(destination, input.Get().Data(), bytes); + } else if (dataType == DataTypeImpl::GetType()) { + memcpy(destination, input.Get().Data(), bytes); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "AddToFeeds: An implementation for the input type ", diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 95bb6b5937..637600dcaf 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1102,8 +1102,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) - .TypeConstraint("T", {"tensor(float)"}, "Constrain to float tensors.") - .TypeConstraint("F", {"tensor(float)", "tensor(int32)"}, "Constrain input type to float or int tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") + .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {