mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Fix softmax_warp_backward math when is_log_softmax = True and register LogSoftmax CUDA kernel (#5160)
* register logsoftmax cuda kernel; fix logsoftmaxgrad cuda kernel; fix tests to invoke dispatch_softmax_*
* forgot to remove axis check
* add tests all axis

Co-authored-by: suffian khan <sukha@OrtTrainingDev1.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
8e650c5384
commit
39a7f96a44
5 changed files with 146 additions and 45 deletions
|
|
@ -378,6 +378,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, float, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, double, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow);
|
||||
|
|
@ -735,6 +738,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, LogSoftmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, LogSoftmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK);
|
||||
|
|
@ -925,6 +931,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Softmax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Softmax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, LogSoftmax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LogSoftmax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LogSoftmax);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Split);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Squeeze);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Unsqueeze);
|
||||
|
|
@ -1052,6 +1061,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, float, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, double, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow)>,
|
||||
|
|
@ -1412,6 +1424,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, LogSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, LogSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK)>,
|
||||
|
|
@ -1598,6 +1613,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, LogSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LogSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LogSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Squeeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Unsqueeze)>,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ Status SoftMaxComputeHelper(
|
|||
|
||||
// cudnnSoftmaxForward/Backward is not optimal implementation.
|
||||
// TODO: remove cudnn path completely in the future.
|
||||
if (D == input_shape[normalized_axis] && D <= 1024 && D * sizeof(T) <= 4096) {
|
||||
if (D <= 1024 && D * sizeof(T) <= 4096) {
|
||||
dispatch_softmax_forward<CudaT, CudaT, AccType<T>, is_log_softmax>(Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,12 +8,14 @@ namespace test {
|
|||
|
||||
static void TestSoftmax(const std::vector<int64_t>& X_dims,
|
||||
const std::vector<int64_t>& Y_dims,
|
||||
int axis = 1,
|
||||
bool is_log_softmax=false,
|
||||
double per_sample_tolerance = 1e-4,
|
||||
double relative_per_sample_tolerance = 1e-4) {
|
||||
|
||||
const char* op = is_log_softmax? "LogSoftmax" : "Softmax";
|
||||
CompareOpTester test(op);
|
||||
test.AddAttribute<int64_t>("axis", axis);
|
||||
|
||||
// create rand inputs
|
||||
RandomValueGenerator random{};
|
||||
|
|
@ -26,39 +28,72 @@ static void TestSoftmax(const std::vector<int64_t>& X_dims,
|
|||
test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, Softmax_SmallTensor) {
|
||||
std::vector<int64_t> X_dims{8, 2, 128, 128};
|
||||
std::vector<int64_t> Y_dims{8, 2, 128, 128};
|
||||
TestSoftmax(X_dims, Y_dims, false);
|
||||
// small tensor to check softmax_warp_forward
|
||||
// note: keep nelem <= 1024 to invoke softmax_warp_forward!
|
||||
TEST(CudaKernelTest, Softmax_SmallTensor_LastAxis) {
|
||||
std::vector<int64_t> X_dims{4, 2, 128};
|
||||
std::vector<int64_t> Y_dims{4, 2, 128};
|
||||
TestSoftmax(X_dims, Y_dims, 2, false);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, Softmax_LargeTensor) {
|
||||
std::vector<int64_t> X_dims{8, 16, 512, 512};
|
||||
std::vector<int64_t> Y_dims{8, 16, 512, 512};
|
||||
TestSoftmax(X_dims, Y_dims, false);
|
||||
TEST(CudaKernelTest, Softmax_SmallTensor_AllAxis) {
|
||||
std::vector<int64_t> X_dims{4, 2, 128};
|
||||
std::vector<int64_t> Y_dims{4, 2, 128};
|
||||
TestSoftmax(X_dims, Y_dims, 0, false);
|
||||
TestSoftmax(X_dims, Y_dims, 1, false);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmax_SmallTensor) {
|
||||
std::vector<int64_t> X_dims{8, 2, 128, 128};
|
||||
std::vector<int64_t> Y_dims{8, 2, 128, 128};
|
||||
TestSoftmax(X_dims, Y_dims, true);
|
||||
// large tensor to check cuda DNN softmax forward
|
||||
TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis) {
|
||||
std::vector<int64_t> X_dims{8, 16, 2048};
|
||||
std::vector<int64_t> Y_dims{8, 16, 2048};
|
||||
TestSoftmax(X_dims, Y_dims, 2, false);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmax_LargeTensor) {
|
||||
std::vector<int64_t> X_dims{8, 16, 512, 512};
|
||||
std::vector<int64_t> Y_dims{8, 16, 512, 512};
|
||||
TestSoftmax(X_dims, Y_dims, true);
|
||||
TEST(CudaKernelTest, Softmax_LargeTensor_AllAxis) {
|
||||
std::vector<int64_t> X_dims{8, 16, 512};
|
||||
std::vector<int64_t> Y_dims{8, 16, 512};
|
||||
TestSoftmax(X_dims, Y_dims, 0, false);
|
||||
TestSoftmax(X_dims, Y_dims, 1, false);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmax_SmallTensor_LastAxis) {
|
||||
std::vector<int64_t> X_dims{4, 2, 128};
|
||||
std::vector<int64_t> Y_dims{4, 2, 128};
|
||||
TestSoftmax(X_dims, Y_dims, 2, true);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmax_SmallTensor_AllAxis) {
|
||||
std::vector<int64_t> X_dims{4, 2, 128};
|
||||
std::vector<int64_t> Y_dims{4, 2, 128};
|
||||
TestSoftmax(X_dims, Y_dims, 0, true);
|
||||
TestSoftmax(X_dims, Y_dims, 1, true);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmax_LargeTensor_LastAxis) {
|
||||
std::vector<int64_t> X_dims{8, 16, 2048};
|
||||
std::vector<int64_t> Y_dims{8, 16, 2048};
|
||||
TestSoftmax(X_dims, Y_dims, 2, true);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmax_LargeTensor_AllAxis) {
|
||||
std::vector<int64_t> X_dims{8, 16, 512};
|
||||
std::vector<int64_t> Y_dims{8, 16, 512};
|
||||
TestSoftmax(X_dims, Y_dims, 0, true);
|
||||
TestSoftmax(X_dims, Y_dims, 1, true);
|
||||
}
|
||||
|
||||
static void TestSoftmaxGrad(const std::vector<int64_t>& dY_dims,
|
||||
const std::vector<int64_t>& Y_dims,
|
||||
const std::vector<int64_t>& dX_dims,
|
||||
int axis = 1,
|
||||
bool is_log_softmax = false,
|
||||
double per_sample_tolerance = 1e-4,
|
||||
double relative_per_sample_tolerance = 1e-4) {
|
||||
|
||||
const char* op = is_log_softmax? "LogSoftmaxGrad" : "SoftmaxGrad";
|
||||
CompareOpTester test(op, 1, kMSDomain);
|
||||
test.AddAttribute<int64_t>("axis", axis);
|
||||
|
||||
// create rand inputs
|
||||
RandomValueGenerator random{};
|
||||
|
|
@ -75,32 +110,67 @@ static void TestSoftmaxGrad(const std::vector<int64_t>& dY_dims,
|
|||
test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxGrad_SmallTensor) {
|
||||
std::vector<int64_t> dY_dims{8, 2, 128, 128};
|
||||
std::vector<int64_t> Y_dims{8, 2, 128, 128};
|
||||
std::vector<int64_t> dX_dims{8, 2, 128, 128};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims);
|
||||
// small tensor to check dispatch_softmax_backward
|
||||
TEST(CudaKernelTest, SoftmaxGrad_SmallTensor_LastAxis) {
|
||||
std::vector<int64_t> dY_dims{4, 2, 128};
|
||||
std::vector<int64_t> Y_dims{4, 2, 128};
|
||||
std::vector<int64_t> dX_dims{4, 2, 128};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxGrad_LargeTensor) {
|
||||
std::vector<int64_t> dY_dims{8, 16, 512, 512};
|
||||
std::vector<int64_t> Y_dims{8, 16, 512, 512};
|
||||
std::vector<int64_t> dX_dims{8, 16, 512, 512};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims);
|
||||
TEST(CudaKernelTest, SoftmaxGrad_SmallTensor_AllAxis) {
|
||||
std::vector<int64_t> dY_dims{4, 2, 128};
|
||||
std::vector<int64_t> Y_dims{4, 2, 128};
|
||||
std::vector<int64_t> dX_dims{4, 2, 128};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0);
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor) {
|
||||
std::vector<int64_t> dY_dims{8, 2, 128, 128};
|
||||
std::vector<int64_t> Y_dims{8, 2, 128, 128};
|
||||
std::vector<int64_t> dX_dims{8, 2, 128, 128};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, true);
|
||||
// large tensor to check cuda DNN softmax backward
|
||||
TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_LastAxis) {
|
||||
std::vector<int64_t> dY_dims{8, 16, 2048};
|
||||
std::vector<int64_t> Y_dims{8, 16, 2048};
|
||||
std::vector<int64_t> dX_dims{8, 16, 2048};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor) {
|
||||
std::vector<int64_t> dY_dims{8, 16, 512, 512};
|
||||
std::vector<int64_t> Y_dims{8, 16, 512, 512};
|
||||
std::vector<int64_t> dX_dims{8, 16, 512, 512};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, true);
|
||||
// large tensor to check cuda DNN softmax backward
|
||||
TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_AllAxis) {
|
||||
std::vector<int64_t> dY_dims{8, 16, 512};
|
||||
std::vector<int64_t> Y_dims{8, 16, 512};
|
||||
std::vector<int64_t> dX_dims{8, 16, 512};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0);
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor_LastAxis) {
|
||||
std::vector<int64_t> dY_dims{4, 2, 128};
|
||||
std::vector<int64_t> Y_dims{4, 2, 128};
|
||||
std::vector<int64_t> dX_dims{4, 2, 128};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, true);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor_AllAxis) {
|
||||
std::vector<int64_t> dY_dims{4, 2, 128};
|
||||
std::vector<int64_t> Y_dims{4, 2, 128};
|
||||
std::vector<int64_t> dX_dims{4, 2, 128};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, true);
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, true);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_LastAxis) {
|
||||
std::vector<int64_t> dY_dims{8, 16, 2048};
|
||||
std::vector<int64_t> Y_dims{8, 16, 2048};
|
||||
std::vector<int64_t> dX_dims{8, 16, 2048};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, true);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_AllAxis) {
|
||||
std::vector<int64_t> dY_dims{8, 16, 512};
|
||||
std::vector<int64_t> Y_dims{8, 16, 512};
|
||||
std::vector<int64_t> dX_dims{8, 16, 512};
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, true);
|
||||
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, true);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ Status SoftMaxGradComputeHelper(
|
|||
auto Y_data = reinterpret_cast<const CudaT*>(Y);
|
||||
auto dX_data = reinterpret_cast<CudaT*>(dX);
|
||||
|
||||
if (D == input_shape[normalized_axis] && D <= 1024 && D * sizeof(T) <= 4096) {
|
||||
if (D <= 1024 && D * sizeof(T) <= 4096) {
|
||||
dispatch_softmax_backward<CudaT, CudaT, AccType<T>, is_log_softmax>(dX_data, dY_data, Y_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,15 +77,28 @@ __global__ void softmax_warp_backward(output_t* gradInput, const input_t* grad,
|
|||
}
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_output_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_output_reg[i][it];
|
||||
if (!is_log_softmax) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_output_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_output_reg[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
}
|
||||
else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
#pragma unroll
|
||||
|
|
|
|||
Loading…
Reference in a new issue