From 39a7f96a4435eddb52fd1c496a9f348bfb9bf033 Mon Sep 17 00:00:00 2001 From: Suffian Khan Date: Thu, 17 Sep 2020 07:15:25 -0700 Subject: [PATCH] Fix softmax_warp_backward math when is_log_softmax = True and register LogSoftmax CUDA kernel (#5160) * register logsoftmax cuda kernel; fix logsoftmaxgrad cuda kernel; fix tests to invoke dispatch_softmax_* * forgot to remove axis check * add tests all axis Co-authored-by: suffian khan --- .../providers/cuda/cuda_execution_provider.cc | 18 +++ .../core/providers/cuda/math/softmax.cc | 2 +- .../test/training_ops/cuda/softmax_test.cc | 142 +++++++++++++----- .../training_ops/cuda/math/softmax_grad.cc | 2 +- .../cuda/math/softmax_grad_impl.cu | 27 +++- 5 files changed, 146 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 87443534a6..cabd267a18 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -378,6 +378,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, float, Pow); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, double, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow); @@ -735,6 +738,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Softmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, LogSoftmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, LogSoftmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK); @@ -925,6 +931,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Softmax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Softmax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, LogSoftmax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LogSoftmax); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LogSoftmax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Split); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Squeeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Unsqueeze); @@ -1052,6 +1061,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1412,6 +1424,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1598,6 +1613,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/softmax.cc b/onnxruntime/core/providers/cuda/math/softmax.cc index ab18b38a17..39e3b57539 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.cc +++ b/onnxruntime/core/providers/cuda/math/softmax.cc @@ -26,7 +26,7 @@ Status SoftMaxComputeHelper( // cudnnSoftmaxForward/Backward is not optimal implementation. // TODO: remove cudnn path completely in the future. 
- if (D == input_shape[normalized_axis] && D <= 1024 && D * sizeof(T) <= 4096) { + if (D <= 1024 && D * sizeof(T) <= 4096) { dispatch_softmax_forward, is_log_softmax>(Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N)); return Status::OK(); } diff --git a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc index 192c7cb2b7..6d103dc924 100644 --- a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc @@ -8,12 +8,14 @@ namespace test { static void TestSoftmax(const std::vector& X_dims, const std::vector& Y_dims, + int axis = 1, bool is_log_softmax=false, double per_sample_tolerance = 1e-4, double relative_per_sample_tolerance = 1e-4) { const char* op = is_log_softmax? "LogSoftmax" : "Softmax"; CompareOpTester test(op); + test.AddAttribute("axis", axis); // create rand inputs RandomValueGenerator random{}; @@ -26,39 +28,72 @@ static void TestSoftmax(const std::vector& X_dims, test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance); } -TEST(CudaKernelTest, Softmax_SmallTensor) { - std::vector X_dims{8, 2, 128, 128}; - std::vector Y_dims{8, 2, 128, 128}; - TestSoftmax(X_dims, Y_dims, false); +// small tensor to check softmax_warp_forward +// note: keep nelem <= 1024 to invoke softmax_warp_forward! 
+TEST(CudaKernelTest, Softmax_SmallTensor_LastAxis) { + std::vector X_dims{4, 2, 128}; + std::vector Y_dims{4, 2, 128}; + TestSoftmax(X_dims, Y_dims, 2, false); } -TEST(CudaKernelTest, Softmax_LargeTensor) { - std::vector X_dims{8, 16, 512, 512}; - std::vector Y_dims{8, 16, 512, 512}; - TestSoftmax(X_dims, Y_dims, false); +TEST(CudaKernelTest, Softmax_SmallTensor_AllAxis) { + std::vector X_dims{4, 2, 128}; + std::vector Y_dims{4, 2, 128}; + TestSoftmax(X_dims, Y_dims, 0, false); + TestSoftmax(X_dims, Y_dims, 1, false); } -TEST(CudaKernelTest, LogSoftmax_SmallTensor) { - std::vector X_dims{8, 2, 128, 128}; - std::vector Y_dims{8, 2, 128, 128}; - TestSoftmax(X_dims, Y_dims, true); +// large tensor to check cuda DNN softmax forward +TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis) { + std::vector X_dims{8, 16, 2048}; + std::vector Y_dims{8, 16, 2048}; + TestSoftmax(X_dims, Y_dims, 2, false); } -TEST(CudaKernelTest, LogSoftmax_LargeTensor) { - std::vector X_dims{8, 16, 512, 512}; - std::vector Y_dims{8, 16, 512, 512}; - TestSoftmax(X_dims, Y_dims, true); +TEST(CudaKernelTest, Softmax_LargeTensor_AllAxis) { + std::vector X_dims{8, 16, 512}; + std::vector Y_dims{8, 16, 512}; + TestSoftmax(X_dims, Y_dims, 0, false); + TestSoftmax(X_dims, Y_dims, 1, false); +} + +TEST(CudaKernelTest, LogSoftmax_SmallTensor_LastAxis) { + std::vector X_dims{4, 2, 128}; + std::vector Y_dims{4, 2, 128}; + TestSoftmax(X_dims, Y_dims, 2, true); +} + +TEST(CudaKernelTest, LogSoftmax_SmallTensor_AllAxis) { + std::vector X_dims{4, 2, 128}; + std::vector Y_dims{4, 2, 128}; + TestSoftmax(X_dims, Y_dims, 0, true); + TestSoftmax(X_dims, Y_dims, 1, true); +} + +TEST(CudaKernelTest, LogSoftmax_LargeTensor_LastAxis) { + std::vector X_dims{8, 16, 2048}; + std::vector Y_dims{8, 16, 2048}; + TestSoftmax(X_dims, Y_dims, 2, true); +} + +TEST(CudaKernelTest, LogSoftmax_LargeTensor_AllAxis) { + std::vector X_dims{8, 16, 512}; + std::vector Y_dims{8, 16, 512}; + TestSoftmax(X_dims, Y_dims, 0, true); + 
TestSoftmax(X_dims, Y_dims, 1, true); } static void TestSoftmaxGrad(const std::vector& dY_dims, const std::vector& Y_dims, const std::vector& dX_dims, + int axis = 1, bool is_log_softmax = false, double per_sample_tolerance = 1e-4, double relative_per_sample_tolerance = 1e-4) { const char* op = is_log_softmax? "LogSoftmaxGrad" : "SoftmaxGrad"; CompareOpTester test(op, 1, kMSDomain); + test.AddAttribute("axis", axis); // create rand inputs RandomValueGenerator random{}; @@ -75,32 +110,67 @@ static void TestSoftmaxGrad(const std::vector& dY_dims, test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance); } -TEST(CudaKernelTest, SoftmaxGrad_SmallTensor) { - std::vector dY_dims{8, 2, 128, 128}; - std::vector Y_dims{8, 2, 128, 128}; - std::vector dX_dims{8, 2, 128, 128}; - TestSoftmaxGrad(dY_dims, Y_dims, dX_dims); +// small tensor to check dispatch_softmax_backward +TEST(CudaKernelTest, SoftmaxGrad_SmallTensor_LastAxis) { + std::vector dY_dims{4, 2, 128}; + std::vector Y_dims{4, 2, 128}; + std::vector dX_dims{4, 2, 128}; + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2); } -TEST(CudaKernelTest, SoftmaxGrad_LargeTensor) { - std::vector dY_dims{8, 16, 512, 512}; - std::vector Y_dims{8, 16, 512, 512}; - std::vector dX_dims{8, 16, 512, 512}; - TestSoftmaxGrad(dY_dims, Y_dims, dX_dims); +TEST(CudaKernelTest, SoftmaxGrad_SmallTensor_AllAxis) { + std::vector dY_dims{4, 2, 128}; + std::vector Y_dims{4, 2, 128}; + std::vector dX_dims{4, 2, 128}; + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0); + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1); } -TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor) { - std::vector dY_dims{8, 2, 128, 128}; - std::vector Y_dims{8, 2, 128, 128}; - std::vector dX_dims{8, 2, 128, 128}; - TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, true); +// large tensor to check cuda DNN softmax backward +TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_LastAxis) { + std::vector dY_dims{8, 16, 2048}; + std::vector Y_dims{8, 16, 2048}; + 
std::vector dX_dims{8, 16, 2048}; + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2); } -TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor) { - std::vector dY_dims{8, 16, 512, 512}; - std::vector Y_dims{8, 16, 512, 512}; - std::vector dX_dims{8, 16, 512, 512}; - TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, true); +// large tensor to check cuda DNN softmax backward +TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_AllAxis) { + std::vector dY_dims{8, 16, 512}; + std::vector Y_dims{8, 16, 512}; + std::vector dX_dims{8, 16, 512}; + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0); + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1); +} + +TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor_LastAxis) { + std::vector dY_dims{4, 2, 128}; + std::vector Y_dims{4, 2, 128}; + std::vector dX_dims{4, 2, 128}; + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, true); +} + +TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor_AllAxis) { + std::vector dY_dims{4, 2, 128}; + std::vector Y_dims{4, 2, 128}; + std::vector dX_dims{4, 2, 128}; + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, true); + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, true); +} + +TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_LastAxis) { + std::vector dY_dims{8, 16, 2048}; + std::vector Y_dims{8, 16, 2048}; + std::vector dX_dims{8, 16, 2048}; + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, true); +} + +TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_AllAxis) { + std::vector dY_dims{8, 16, 512}; + std::vector Y_dims{8, 16, 512}; + std::vector dX_dims{8, 16, 512}; + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, true); + TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, true); } } // namespace test diff --git a/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc b/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc index 615ca9792c..c9351245a0 100644 --- a/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc @@ -29,7 +29,7 @@ Status 
SoftMaxGradComputeHelper( auto Y_data = reinterpret_cast(Y); auto dX_data = reinterpret_cast(dX); - if (D == input_shape[normalized_axis] && D <= 1024 && D * sizeof(T) <= 4096) { + if (D <= 1024 && D * sizeof(T) <= 4096) { dispatch_softmax_backward, is_log_softmax>(dX_data, dY_data, Y_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N)); return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu index 294f955ac1..50eae2c11f 100644 --- a/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu @@ -77,15 +77,28 @@ __global__ void softmax_warp_backward(output_t* gradInput, const input_t* grad, } acc_t sum[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_output_reg[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_output_reg[i][it]; + if (!is_log_softmax) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_output_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_output_reg[i][it]; + } } + warp_reduce(sum); + } + else { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); } - warp_reduce(sum); // store result #pragma unroll