mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
check rotary_embedding with seq length (#20547)
### Description <!-- Describe your changes. --> When the past and present key/value caches share the same buffer, the present sequence length differs from the total sequence length. The size of the cos/sin cache should therefore be checked against the total sequence length rather than the present sequence length. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
e540423179
commit
d6280e26bd
2 changed files with 8 additions and 8 deletions
|
|
@ -192,13 +192,13 @@ Status CheckInputs(const Tensor* query,
|
|||
"head_size shall be a multiple of 16. Got head_size % 16 == ",
|
||||
head_size % 16);
|
||||
}
|
||||
if (cos_dims[0] < present_sequence_length) {
|
||||
if (cos_dims[0] < total_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache dimension 0 should be of max_sequence_length.");
|
||||
"cos_cache dimension 0 should be not be less than total_sequence_length.");
|
||||
}
|
||||
if (sin_dims[0] < present_sequence_length) {
|
||||
if (sin_dims[0] < total_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"sin_cache dimension 0 should be of max_sequence_length.");
|
||||
"sin_cache dimension 0 should be not be less than total_sequence_length.");
|
||||
}
|
||||
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
|
|
|
|||
|
|
@ -215,13 +215,13 @@ Status CheckInputs(const Tensor* query,
|
|||
"head_size shall be a multiple of 16. Got head_size % 16 == ",
|
||||
head_size % 16);
|
||||
}
|
||||
if (cos_dims[0] < present_sequence_length) {
|
||||
if (cos_dims[0] < total_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache dimension 0 should be of max_sequence_length.");
|
||||
"cos_cache dimension 0 should be not be less than total_sequence_length.");
|
||||
}
|
||||
if (sin_dims[0] < present_sequence_length) {
|
||||
if (sin_dims[0] < total_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"sin_cache dimension 0 should be of max_sequence_length.");
|
||||
"sin_cache dimension 0 should be not be less than total_sequence_length.");
|
||||
}
|
||||
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
|
|
|
|||
Loading…
Reference in a new issue