diff --git a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc index 5ff75fb54b..6b8d1084ca 100644 --- a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc +++ b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc @@ -49,7 +49,6 @@ Status SoftmaxCrossEntropyLoss::ComputeInternal(OpKernelContext* ctx) co int64_t C; onnxruntime::contrib::GetNDCFromLogitAndLabelShape(logit_shape, label_shape, N_D, C); const TensorShape logit_reshape({N_D, C}); - const TensorShape label_reshape({N_D}); Tensor* total_loss = ctx->Output(0, reduction_ == ReductionType::NONE ? TensorShape(label.Shape()) : TensorShape({})); T* total_loss_data = total_loss->template MutableData(); T* tmp_loss_sample_buffer = nullptr; @@ -111,14 +110,15 @@ Status SoftmaxCrossEntropyLoss::ComputeInternal(OpKernelContext* ctx) co CUDA_RETURN_IF_ERROR(cudaMemsetAsync(weight_data_nd_data, 0, N_D * sizeof(T), Stream())); ComputeWeightsSoftmaxCrossEntropyImpl(Stream(), label_data, weight_data, N_D, C, ignore_index_, weight_data_nd_data); + // Compute buffer size in byte for reduction APIs. + const auto buffer_size = + compute_reduction_buffer_size(static_cast(N_D)); + // Allocate reduction buffer whose size is buffer_size bytes, or nullptr if no reduction. + IAllocatorUniquePtr reduction_buffer = GetScratchBuffer( + reduction_ != ReductionType::NONE ? buffer_size : 0); + auto normalize_factor_data = GetScratchBuffer(1); if (reduction_ == ReductionType::MEAN) { - // Compute buffer size in byte for reduction APIs. - const auto buffer_size = - compute_reduction_buffer_size(static_cast(N_D)); - // Allocate reduction buffer whose size is buffer_size bytes. - IAllocatorUniquePtr reduction_buffer = GetScratchBuffer( - buffer_size); ORT_RETURN_IF_ERROR(reduce_sum( Stream(), weight_data_nd_data, @@ -157,14 +157,13 @@ Status SoftmaxCrossEntropyLoss::ComputeInternal(OpKernelContext* ctx) co if (reduction_ != ReductionType::NONE) { // ReduceSum on loss_per_sample - std::vector output_dims(1, 1); - ReduceKernelShared( + ORT_RETURN_IF_ERROR(reduce_sum( + Stream(), tmp_loss_sample_buffer, - label_reshape, total_loss_data, - TensorShape({}), - CUDNN_REDUCE_TENSOR_ADD, - output_dims); + static_cast(N_D), + reduction_buffer.get(), + buffer_size)); } return Status::OK(); diff --git a/orttraining/orttraining/training_ops/cuda/loss/softmaxcrossentropy_impl.cc b/orttraining/orttraining/training_ops/cuda/loss/softmaxcrossentropy_impl.cc index 441a39d21d..a4b1b55330 100644 --- a/orttraining/orttraining/training_ops/cuda/loss/softmaxcrossentropy_impl.cc +++ b/orttraining/orttraining/training_ops/cuda/loss/softmaxcrossentropy_impl.cc @@ -167,6 +167,13 @@ Status SparseSoftmaxCrossEntropy::ComputeInternal(OpKernelContext* ctx) weight_data = weight.template Data(); } + // Compute buffer size in byte for reduction APIs. + const auto buffer_size = + compute_reduction_buffer_size(static_cast(N)); + // Allocate reduction buffer whose size is buffer_size bytes. + IAllocatorUniquePtr reduction_buffer = GetScratchBuffer( + buffer_size); + auto normalize_factor_data = GetScratchBuffer(1); if (reduction_ == ReductionType::SUM) { const T normalize_factor = static_cast(1); @@ -176,12 +183,6 @@ Status SparseSoftmaxCrossEntropy::ComputeInternal(OpKernelContext* ctx) const T normalize_factor = static_cast(N); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(T), cudaMemcpyHostToDevice, Stream())); } else { - // Compute buffer size in byte for reduction APIs. - const auto buffer_size = - compute_reduction_buffer_size(static_cast(N)); - // Allocate reduction buffer whose size is buffer_size bytes. - IAllocatorUniquePtr reduction_buffer = GetScratchBuffer( - buffer_size); ORT_RETURN_IF_ERROR(reduce_sum( Stream(), weight_data, @@ -202,14 +203,13 @@ Status SparseSoftmaxCrossEntropy::ComputeInternal(OpKernelContext* ctx) D); // ReduceSum on loss_per_sample - std::vector output_dims(1, 1); - return ReduceKernelShared( + return reduce_sum( + Stream(), tmp_loss_sample.get(), - label_reshape, total_loss_data, - TensorShape({}), - CUDNN_REDUCE_TENSOR_ADD, - output_dims); + static_cast(N), + reduction_buffer.get(), + buffer_size); } template diff --git a/tools/ci_build/github/pai/pai-excluded-tests.txt b/tools/ci_build/github/pai/pai-excluded-tests.txt index 37d0470127..7bee65b29a 100644 --- a/tools/ci_build/github/pai/pai-excluded-tests.txt +++ b/tools/ci_build/github/pai/pai-excluded-tests.txt @@ -1,4 +1,3 @@ -CudaKernelTest.SparseSoftmaxCrossEntropy_LargeSizeTensor CudaKernelTest.NegativeLogLikelihoodLoss_TinySizeTensor CudaKernelTest.NegativeLogLikelihoodLoss_SmallSizeTensor CudaKernelTest.NegativeLogLikelihoodLoss_MediumSizeTensor