Fix divide-by-zero for SSCE kernel when normalize factor is zero. (#4911)

* Change SSCE to handle the case where all tokens are ignored.
This commit is contained in:
harshithapv 2020-08-26 17:12:17 -07:00 committed by GitHub
parent cac25751bd
commit 00fe718264
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -101,7 +101,11 @@ __global__ void _SparseSoftmaxCrossEntropy(
CUDA_LONG D) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N);
CUDA_KERNEL_ASSERT(label_data[i] >= 0 && label_data[i] < D);
output_data[i] = -log_prob_data[i * D + label_data[i]] / (*normalize_factor_data);
if (*normalize_factor_data == 0) {
output_data[i] = 0;
} else {
output_data[i] = -log_prob_data[i * D + label_data[i]] / (*normalize_factor_data);
}
}
template <typename T, typename Tin>
@ -115,7 +119,11 @@ __global__ void _WeightedSparseSoftmaxCrossEntropy(
CUDA_LONG D) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N);
CUDA_KERNEL_ASSERT(label_data[i] >= 0 && label_data[i] < D);
output_data[i] = -log_prob_data[i * D + label_data[i]] * weight_data[i] / (*normalize_factor_data);
if (*normalize_factor_data == 0) {
output_data[i] = 0;
} else {
output_data[i] = -log_prob_data[i * D + label_data[i]] * weight_data[i] / (*normalize_factor_data);
}
}
template <typename T, typename Tin>
@ -175,7 +183,11 @@ __global__ void _SparseSoftmaxCrossEntropyGrad(
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N * D);
int row = i / D;
int d = i % D;
output_data[i] = (*dY) * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
if (*normalize_factor == 0) {
output_data[i] = 0;
} else {
output_data[i] = (*dY) * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
}
}
template <typename T, typename Tin>
@ -191,7 +203,11 @@ __global__ void _WeightedSparseSoftmaxCrossEntropyGrad(
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N * D);
int row = i / D;
int d = i % D;
output_data[i] = (*dY) * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
if (*normalize_factor == 0) {
output_data[i] = 0;
} else {
output_data[i] = (*dY) * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
}
}
template <typename T, typename Tin>