From d6280e26bd23a0ce029b12a436b645943e4aeb74 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 3 May 2024 09:43:53 -0700 Subject: [PATCH] check rotary_embedding with seq length (#20547) ### Description with past/present shared same buffer, the present seq length is different with total sequence length. The size of cos/sin cache should be checked with sequence length. ### Motivation and Context --- .../contrib_ops/cpu/bert/group_query_attention_helper.h | 8 ++++---- .../contrib_ops/cuda/bert/group_query_attention_helper.h | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index e2a615eba4..860a4355a3 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -192,13 +192,13 @@ Status CheckInputs(const Tensor* query, "head_size shall be a multiple of 16. Got head_size % 16 == ", head_size % 16); } - if (cos_dims[0] < present_sequence_length) { + if (cos_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be of max_sequence_length."); + "cos_cache dimension 0 should be not be less than total_sequence_length."); } - if (sin_dims[0] < present_sequence_length) { + if (sin_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be of max_sequence_length."); + "sin_cache dimension 0 should be not be less than total_sequence_length."); } if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 1a7c3fcea3..8352397f68 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -215,13 +215,13 @@ Status CheckInputs(const Tensor* query, "head_size shall be a multiple of 16. Got head_size % 16 == ", head_size % 16); } - if (cos_dims[0] < present_sequence_length) { + if (cos_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be of max_sequence_length."); + "cos_cache dimension 0 should be not be less than total_sequence_length."); } - if (sin_dims[0] < present_sequence_length) { + if (sin_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be of max_sequence_length."); + "sin_cache dimension 0 should be not be less than total_sequence_length."); } if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,