diff --git a/orttraining/orttraining/training_ops/cuda/loss/softmaxcrossentropy_impl.cu b/orttraining/orttraining/training_ops/cuda/loss/softmaxcrossentropy_impl.cu index f723d582b6..cbb430418b 100644 --- a/orttraining/orttraining/training_ops/cuda/loss/softmaxcrossentropy_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/loss/softmaxcrossentropy_impl.cu @@ -101,7 +101,11 @@ __global__ void _SparseSoftmaxCrossEntropy( CUDA_LONG D) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N); CUDA_KERNEL_ASSERT(label_data[i] >= 0 && label_data[i] < D); - output_data[i] = -log_prob_data[i * D + label_data[i]] / (*normalize_factor_data); + if (*normalize_factor_data == 0) { + output_data[i] = 0; + } else { + output_data[i] = -log_prob_data[i * D + label_data[i]] / (*normalize_factor_data); + } } template @@ -115,7 +119,11 @@ __global__ void _WeightedSparseSoftmaxCrossEntropy( CUDA_LONG D) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N); CUDA_KERNEL_ASSERT(label_data[i] >= 0 && label_data[i] < D); - output_data[i] = -log_prob_data[i * D + label_data[i]] * weight_data[i] / (*normalize_factor_data); + if (*normalize_factor_data == 0) { + output_data[i] = 0; + } else { + output_data[i] = -log_prob_data[i * D + label_data[i]] * weight_data[i] / (*normalize_factor_data); + } } template @@ -175,7 +183,11 @@ __global__ void _SparseSoftmaxCrossEntropyGrad( CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N * D); int row = i / D; int d = i % D; - output_data[i] = (*dY) * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor); + if (*normalize_factor == 0) { + output_data[i] = 0; + } else { + output_data[i] = (*dY) * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor); + } } template @@ -191,7 +203,11 @@ __global__ void _WeightedSparseSoftmaxCrossEntropyGrad( CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N * D); int row = i / D; int d = i % D; - output_data[i] = (*dY) * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor); + if (*normalize_factor == 0) { + output_data[i] = 0; + } else { + output_data[i] = (*dY) * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor); + } } template