Fix softmax_warp_backward math when is_log_softmax = True and register LogSoftmax CUDA kernel (#5160)

* register LogSoftmax CUDA kernel; fix LogSoftmaxGrad CUDA kernel; fix tests to invoke dispatch_softmax_*

* forgot to remove axis check

* add tests for all axes

Co-authored-by: suffian khan <sukha@OrtTrainingDev1.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Suffian Khan 2020-09-17 07:15:25 -07:00 committed by Tianlei Wu
parent 8e650c5384
commit 39a7f96a44
5 changed files with 146 additions and 45 deletions

View file

@ -378,6 +378,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax);
// Forward declarations of the LogSoftmax kernel classes for the CUDA execution provider
// (ONNX domain, versions 1-10), one per element type: float, double, MLFloat16.
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, float, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, double, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow);
@ -735,6 +738,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Softmax);
// Forward declarations of the LogSoftmax kernel classes for the CUDA execution provider
// (ONNX domain, versions 11-12), one per element type: float, double, MLFloat16.
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK);
@ -925,6 +931,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Softmax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Softmax);
// Forward declarations of the LogSoftmax kernel classes for the CUDA execution provider
// (ONNX domain, version 13), one per element type: float, double, MLFloat16.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, LogSoftmax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LogSoftmax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LogSoftmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Split);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Unsqueeze);
@ -1052,6 +1061,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, float, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, double, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow)>,
@ -1412,6 +1424,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK)>,
@ -1598,6 +1613,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Unsqueeze)>,

View file

@ -26,7 +26,7 @@ Status SoftMaxComputeHelper(
// cudnnSoftmaxForward/Backward is not optimal implementation.
// TODO: remove cudnn path completely in the future.
if (D == input_shape[normalized_axis] && D <= 1024 && D * sizeof(T) <= 4096) {
if (D <= 1024 && D * sizeof(T) <= 4096) {
dispatch_softmax_forward<CudaT, CudaT, AccType<T>, is_log_softmax>(Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
return Status::OK();
}

View file

@ -8,12 +8,14 @@ namespace test {
static void TestSoftmax(const std::vector<int64_t>& X_dims,
const std::vector<int64_t>& Y_dims,
int axis = 1,
bool is_log_softmax=false,
double per_sample_tolerance = 1e-4,
double relative_per_sample_tolerance = 1e-4) {
const char* op = is_log_softmax? "LogSoftmax" : "Softmax";
CompareOpTester test(op);
test.AddAttribute<int64_t>("axis", axis);
// create rand inputs
RandomValueGenerator random{};
@ -26,39 +28,72 @@ static void TestSoftmax(const std::vector<int64_t>& X_dims,
test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
}
TEST(CudaKernelTest, Softmax_SmallTensor) {
std::vector<int64_t> X_dims{8, 2, 128, 128};
std::vector<int64_t> Y_dims{8, 2, 128, 128};
TestSoftmax(X_dims, Y_dims, false);
// small tensor to check softmax_warp_forward
// note: keep nelem <= 1024 to invoke softmax_warp_forward!
TEST(CudaKernelTest, Softmax_SmallTensor_LastAxis) {
  // Softmax over the trailing axis (axis = 2) of a 4 x 2 x 128 tensor;
  // TestSoftmax compares the CUDA provider's result against the CPU's.
  const std::vector<int64_t> input_shape = {4, 2, 128};
  const std::vector<int64_t> output_shape = input_shape;
  TestSoftmax(input_shape, output_shape, /*axis=*/2, /*is_log_softmax=*/false);
}
TEST(CudaKernelTest, Softmax_LargeTensor) {
std::vector<int64_t> X_dims{8, 16, 512, 512};
std::vector<int64_t> Y_dims{8, 16, 512, 512};
TestSoftmax(X_dims, Y_dims, false);
TEST(CudaKernelTest, Softmax_SmallTensor_AllAxis) {
  // Softmax on a 4 x 2 x 128 tensor for the non-trailing axes 0 and 1
  // (the trailing axis is covered by the *_LastAxis test).
  const std::vector<int64_t> input_shape = {4, 2, 128};
  const std::vector<int64_t> output_shape = input_shape;
  for (int axis : {0, 1}) {
    TestSoftmax(input_shape, output_shape, axis, /*is_log_softmax=*/false);
  }
}
TEST(CudaKernelTest, LogSoftmax_SmallTensor) {
std::vector<int64_t> X_dims{8, 2, 128, 128};
std::vector<int64_t> Y_dims{8, 2, 128, 128};
TestSoftmax(X_dims, Y_dims, true);
// large tensor to check cuda DNN softmax forward
TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis) {
  // Softmax over the trailing axis (axis = 2) of an 8 x 16 x 2048 tensor.
  // The trailing dimension (2048) exceeds the 1024 limit tested by the
  // forward helper's warp-kernel dispatch condition.
  const std::vector<int64_t> input_shape = {8, 16, 2048};
  const std::vector<int64_t> output_shape = input_shape;
  TestSoftmax(input_shape, output_shape, /*axis=*/2, /*is_log_softmax=*/false);
}
TEST(CudaKernelTest, LogSoftmax_LargeTensor) {
std::vector<int64_t> X_dims{8, 16, 512, 512};
std::vector<int64_t> Y_dims{8, 16, 512, 512};
TestSoftmax(X_dims, Y_dims, true);
TEST(CudaKernelTest, Softmax_LargeTensor_AllAxis) {
  // Softmax on an 8 x 16 x 512 tensor for the non-trailing axes 0 and 1.
  const std::vector<int64_t> input_shape = {8, 16, 512};
  const std::vector<int64_t> output_shape = input_shape;
  for (int axis : {0, 1}) {
    TestSoftmax(input_shape, output_shape, axis, /*is_log_softmax=*/false);
  }
}
TEST(CudaKernelTest, LogSoftmax_SmallTensor_LastAxis) {
  // LogSoftmax over the trailing axis (axis = 2) of a 4 x 2 x 128 tensor;
  // TestSoftmax compares the CUDA provider's result against the CPU's.
  const std::vector<int64_t> input_shape = {4, 2, 128};
  const std::vector<int64_t> output_shape = input_shape;
  TestSoftmax(input_shape, output_shape, /*axis=*/2, /*is_log_softmax=*/true);
}
TEST(CudaKernelTest, LogSoftmax_SmallTensor_AllAxis) {
  // LogSoftmax on a 4 x 2 x 128 tensor for the non-trailing axes 0 and 1.
  const std::vector<int64_t> input_shape = {4, 2, 128};
  const std::vector<int64_t> output_shape = input_shape;
  for (int axis : {0, 1}) {
    TestSoftmax(input_shape, output_shape, axis, /*is_log_softmax=*/true);
  }
}
TEST(CudaKernelTest, LogSoftmax_LargeTensor_LastAxis) {
  // LogSoftmax over the trailing axis (axis = 2) of an 8 x 16 x 2048 tensor.
  // The trailing dimension (2048) exceeds the 1024 limit tested by the
  // forward helper's warp-kernel dispatch condition.
  const std::vector<int64_t> input_shape = {8, 16, 2048};
  const std::vector<int64_t> output_shape = input_shape;
  TestSoftmax(input_shape, output_shape, /*axis=*/2, /*is_log_softmax=*/true);
}
TEST(CudaKernelTest, LogSoftmax_LargeTensor_AllAxis) {
  // LogSoftmax on an 8 x 16 x 512 tensor for the non-trailing axes 0 and 1.
  const std::vector<int64_t> input_shape = {8, 16, 512};
  const std::vector<int64_t> output_shape = input_shape;
  for (int axis : {0, 1}) {
    TestSoftmax(input_shape, output_shape, axis, /*is_log_softmax=*/true);
  }
}
static void TestSoftmaxGrad(const std::vector<int64_t>& dY_dims,
const std::vector<int64_t>& Y_dims,
const std::vector<int64_t>& dX_dims,
int axis = 1,
bool is_log_softmax = false,
double per_sample_tolerance = 1e-4,
double relative_per_sample_tolerance = 1e-4) {
const char* op = is_log_softmax? "LogSoftmaxGrad" : "SoftmaxGrad";
CompareOpTester test(op, 1, kMSDomain);
test.AddAttribute<int64_t>("axis", axis);
// create rand inputs
RandomValueGenerator random{};
@ -75,32 +110,67 @@ static void TestSoftmaxGrad(const std::vector<int64_t>& dY_dims,
test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
}
TEST(CudaKernelTest, SoftmaxGrad_SmallTensor) {
std::vector<int64_t> dY_dims{8, 2, 128, 128};
std::vector<int64_t> Y_dims{8, 2, 128, 128};
std::vector<int64_t> dX_dims{8, 2, 128, 128};
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims);
// small tensor to check dispatch_softmax_backward
TEST(CudaKernelTest, SoftmaxGrad_SmallTensor_LastAxis) {
  // SoftmaxGrad over the trailing axis (axis = 2); dY, Y and dX all share
  // the 4 x 2 x 128 shape. is_log_softmax is left at its default (false).
  const std::vector<int64_t> grad_out_shape = {4, 2, 128};
  const std::vector<int64_t> out_shape = grad_out_shape;
  const std::vector<int64_t> grad_in_shape = grad_out_shape;
  TestSoftmaxGrad(grad_out_shape, out_shape, grad_in_shape, /*axis=*/2);
}
TEST(CudaKernelTest, SoftmaxGrad_LargeTensor) {
std::vector<int64_t> dY_dims{8, 16, 512, 512};
std::vector<int64_t> Y_dims{8, 16, 512, 512};
std::vector<int64_t> dX_dims{8, 16, 512, 512};
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims);
TEST(CudaKernelTest, SoftmaxGrad_SmallTensor_AllAxis) {
  // SoftmaxGrad on 4 x 2 x 128 tensors for the non-trailing axes 0 and 1.
  // is_log_softmax is left at its default (false).
  const std::vector<int64_t> grad_out_shape = {4, 2, 128};
  const std::vector<int64_t> out_shape = grad_out_shape;
  const std::vector<int64_t> grad_in_shape = grad_out_shape;
  for (int axis : {0, 1}) {
    TestSoftmaxGrad(grad_out_shape, out_shape, grad_in_shape, axis);
  }
}
TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor) {
std::vector<int64_t> dY_dims{8, 2, 128, 128};
std::vector<int64_t> Y_dims{8, 2, 128, 128};
std::vector<int64_t> dX_dims{8, 2, 128, 128};
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, true);
// large tensor to check cuda DNN softmax backward
TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_LastAxis) {
  // SoftmaxGrad over the trailing axis (axis = 2) of 8 x 16 x 2048 tensors.
  // The trailing dimension (2048) exceeds the 1024 limit tested by the
  // backward helper's warp-kernel dispatch condition.
  const std::vector<int64_t> grad_out_shape = {8, 16, 2048};
  const std::vector<int64_t> out_shape = grad_out_shape;
  const std::vector<int64_t> grad_in_shape = grad_out_shape;
  TestSoftmaxGrad(grad_out_shape, out_shape, grad_in_shape, /*axis=*/2);
}
TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor) {
std::vector<int64_t> dY_dims{8, 16, 512, 512};
std::vector<int64_t> Y_dims{8, 16, 512, 512};
std::vector<int64_t> dX_dims{8, 16, 512, 512};
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, true);
// large tensor to check cuda DNN softmax backward
TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_AllAxis) {
  // SoftmaxGrad on 8 x 16 x 512 tensors for the non-trailing axes 0 and 1.
  // is_log_softmax is left at its default (false).
  const std::vector<int64_t> grad_out_shape = {8, 16, 512};
  const std::vector<int64_t> out_shape = grad_out_shape;
  const std::vector<int64_t> grad_in_shape = grad_out_shape;
  for (int axis : {0, 1}) {
    TestSoftmaxGrad(grad_out_shape, out_shape, grad_in_shape, axis);
  }
}
TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor_LastAxis) {
  // LogSoftmaxGrad over the trailing axis (axis = 2); dY, Y and dX all share
  // the 4 x 2 x 128 shape.
  const std::vector<int64_t> grad_out_shape = {4, 2, 128};
  const std::vector<int64_t> out_shape = grad_out_shape;
  const std::vector<int64_t> grad_in_shape = grad_out_shape;
  TestSoftmaxGrad(grad_out_shape, out_shape, grad_in_shape, /*axis=*/2, /*is_log_softmax=*/true);
}
TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor_AllAxis) {
  // LogSoftmaxGrad on 4 x 2 x 128 tensors for the non-trailing axes 0 and 1.
  const std::vector<int64_t> grad_out_shape = {4, 2, 128};
  const std::vector<int64_t> out_shape = grad_out_shape;
  const std::vector<int64_t> grad_in_shape = grad_out_shape;
  for (int axis : {0, 1}) {
    TestSoftmaxGrad(grad_out_shape, out_shape, grad_in_shape, axis, /*is_log_softmax=*/true);
  }
}
TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_LastAxis) {
  // LogSoftmaxGrad over the trailing axis (axis = 2) of 8 x 16 x 2048 tensors.
  // The trailing dimension (2048) exceeds the 1024 limit tested by the
  // backward helper's warp-kernel dispatch condition.
  const std::vector<int64_t> grad_out_shape = {8, 16, 2048};
  const std::vector<int64_t> out_shape = grad_out_shape;
  const std::vector<int64_t> grad_in_shape = grad_out_shape;
  TestSoftmaxGrad(grad_out_shape, out_shape, grad_in_shape, /*axis=*/2, /*is_log_softmax=*/true);
}
TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_AllAxis) {
  // LogSoftmaxGrad on 8 x 16 x 512 tensors for the non-trailing axes 0 and 1.
  const std::vector<int64_t> grad_out_shape = {8, 16, 512};
  const std::vector<int64_t> out_shape = grad_out_shape;
  const std::vector<int64_t> grad_in_shape = grad_out_shape;
  for (int axis : {0, 1}) {
    TestSoftmaxGrad(grad_out_shape, out_shape, grad_in_shape, axis, /*is_log_softmax=*/true);
  }
}
} // namespace test

View file

@ -29,7 +29,7 @@ Status SoftMaxGradComputeHelper(
auto Y_data = reinterpret_cast<const CudaT*>(Y);
auto dX_data = reinterpret_cast<CudaT*>(dX);
if (D == input_shape[normalized_axis] && D <= 1024 && D * sizeof(T) <= 4096) {
if (D <= 1024 && D * sizeof(T) <= 4096) {
dispatch_softmax_backward<CudaT, CudaT, AccType<T>, is_log_softmax>(dX_data, dY_data, Y_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
return Status::OK();
}

View file

@ -77,15 +77,28 @@ __global__ void softmax_warp_backward(output_t* gradInput, const input_t* grad,
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_output_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_output_reg[i][it];
if (!is_log_softmax) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_output_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_output_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
}
else {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll