Longformer Attention non-determinism issue fix (#7574)

* Fix run-to-run non-determinism bug.

* Remove non-deterministic logic in softmax

* Fix the value difference introduced when removing the non-deterministic logic.

Co-authored-by: Lei Zhang <zhang.huanning@hotmail.com>
This commit is contained in:
Ye Wang 2021-05-05 09:54:25 -07:00 committed by GitHub
parent 94c97ac8c2
commit 8a9ddfe963
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -103,7 +103,7 @@ __launch_bounds__(blockSize)
if (is_local_row) {
for (int g = tid; g < global_num; g += blockSize) {
int i = global_index[g];
if (i < col_start || i > col_end) {
if (i < col_start || i >= col_end) {
float x = input_block[i];
x = x * scaler + (float)mask_block[i];
if (max_input < x) {
@@ -130,7 +130,7 @@ __launch_bounds__(blockSize)
if (is_local_row) {
for (int g = tid; g < global_num; g += blockSize) {
int i = global_index[g];
if (i < col_start || i > col_end) {
if (i < col_start || i >= col_end) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
sum_input += x;
@@ -163,14 +163,21 @@ __launch_bounds__(blockSize)
}
for (int i = tid + zero_start; i < zero_end; i += blockSize) {
output_block[i] = (T)(0.);
if (i < col_start || i >= col_end) {
output_block[i] = (T)(0.);
}
}
}
__syncthreads();
if (is_local_row) {
for (int g = tid; g < global_num; g += blockSize) {
int i = global_index[g];
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
output_block[i] = (T)(recip_sum * x);
if (i < col_start || i >= col_end) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
output_block[i] = (T)(recip_sum * x);
}
}
}