diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index e2a615eba4..860a4355a3 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -192,13 +192,13 @@ Status CheckInputs(const Tensor* query, "head_size shall be a multiple of 16. Got head_size % 16 == ", head_size % 16); } - if (cos_dims[0] < present_sequence_length) { + if (cos_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be of max_sequence_length."); + "cos_cache dimension 0 should be not be less than total_sequence_length."); } - if (sin_dims[0] < present_sequence_length) { + if (sin_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be of max_sequence_length."); + "sin_cache dimension 0 should be not be less than total_sequence_length."); } if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 1a7c3fcea3..8352397f68 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -215,13 +215,13 @@ Status CheckInputs(const Tensor* query, "head_size shall be a multiple of 16. Got head_size % 16 == ", head_size % 16); } - if (cos_dims[0] < present_sequence_length) { + if (cos_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be of max_sequence_length."); + "cos_cache dimension 0 should be not be less than total_sequence_length."); } - if (sin_dims[0] < present_sequence_length) { + if (sin_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be of max_sequence_length."); + "sin_cache dimension 0 should be not be less than total_sequence_length."); } if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,