mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Longformer Attention non-determinism issue fix (#7574)
* Fix run-to-run not deterministic bug. * Remove non-deterministic logic in softmax * Fix value diff when removing non-deterministic issue. Co-authored-by: Lei Zhang <zhang.huanning@hotmail.com>
This commit is contained in:
parent
94c97ac8c2
commit
8a9ddfe963
1 changed files with 13 additions and 6 deletions
|
|
@ -103,7 +103,7 @@ __launch_bounds__(blockSize)
|
|||
if (is_local_row) {
|
||||
for (int g = tid; g < global_num; g += blockSize) {
|
||||
int i = global_index[g];
|
||||
if (i < col_start || i > col_end) {
|
||||
if (i < col_start || i >= col_end) {
|
||||
float x = input_block[i];
|
||||
x = x * scaler + (float)mask_block[i];
|
||||
if (max_input < x) {
|
||||
|
|
@ -130,7 +130,7 @@ __launch_bounds__(blockSize)
|
|||
if (is_local_row) {
|
||||
for (int g = tid; g < global_num; g += blockSize) {
|
||||
int i = global_index[g];
|
||||
if (i < col_start || i > col_end) {
|
||||
if (i < col_start || i >= col_end) {
|
||||
float x = input_block[i];
|
||||
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
|
||||
sum_input += x;
|
||||
|
|
@ -163,14 +163,21 @@ __launch_bounds__(blockSize)
|
|||
}
|
||||
|
||||
for (int i = tid + zero_start; i < zero_end; i += blockSize) {
|
||||
output_block[i] = (T)(0.);
|
||||
if (i < col_start || i >= col_end) {
|
||||
output_block[i] = (T)(0.);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_local_row) {
|
||||
for (int g = tid; g < global_num; g += blockSize) {
|
||||
int i = global_index[g];
|
||||
float x = input_block[i];
|
||||
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
|
||||
output_block[i] = (T)(recip_sum * x);
|
||||
if (i < col_start || i >= col_end) {
|
||||
float x = input_block[i];
|
||||
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
|
||||
output_block[i] = (T)(recip_sum * x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue