CUDA ConvGrad Kernel (#7227)

* ConvGrad CUDA impl

* Set up the test case for Deberta Conv1D

* Add fp16 test

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Sherlock 2021-04-06 22:09:06 -07:00 committed by GitHub
parent 8219518aa8
commit 4bc17ca04e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 386 additions and 30 deletions

View file

@ -175,8 +175,8 @@ Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const
ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
gsl::narrow_cast<int>(conv_attrs_.group),
CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>()));
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, gsl::narrow_cast<int>(conv_attrs_.group)));
if (context->InputCount() >= 3) {
const Tensor* B = context->Input<Tensor>(2);
@ -330,6 +330,7 @@ Status CudnnConvolutionDescriptor::Set(
const std::vector<int64_t>& pads,
const std::vector<int64_t>& strides,
const std::vector<int64_t>& dilations,
int groups,
cudnnConvolutionMode_t mode,
cudnnDataType_t data_type) {
if (!desc_)
@ -344,6 +345,10 @@ Status CudnnConvolutionDescriptor::Set(
dilation_dims[i] = gsl::narrow_cast<int>(dilations[i]);
}
// This piece of code is copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h
// Setting math_type to CUDNN_DATA_FLOAT for half input
cudnnDataType_t math_type = data_type;
if (data_type == CUDNN_DATA_HALF) math_type = CUDNN_DATA_FLOAT;
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionNdDescriptor(
desc_,
gsl::narrow_cast<int>(rank),
@ -351,7 +356,16 @@ Status CudnnConvolutionDescriptor::Set(
stride_dims.data(),
dilation_dims.data(),
mode,
data_type));
math_type));
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(desc_, groups));
// Copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h
// See Note [behavior of cudnnFind and cudnnGet] at /pytorch/aten/src/ATen/native/cudnn/Conv_v7.cpp
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH));
if (data_type == CUDNN_DATA_HALF) {
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH));
}
return Status::OK();
}

View file

@ -22,6 +22,7 @@ class CudnnConvolutionDescriptor final {
const std::vector<int64_t>& pads,
const std::vector<int64_t>& strides,
const std::vector<int64_t>& dilations,
int groups,
cudnnConvolutionMode_t mode,
cudnnDataType_t data_type);

View file

@ -107,10 +107,9 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType<CudaT>()));
cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides,
p.dilations, mode, CudnnTensor::GetDataType<CudaT>()));
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc,
gsl::narrow_cast<int>(conv_transpose_attrs_.group)));
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations,
gsl::narrow_cast<int>(conv_transpose_attrs_.group),
mode, CudnnTensor::GetDataType<CudaT>()));
if (has_bias) {
const auto& b_shape = p.B->Shape();

View file

@ -24,8 +24,8 @@
407435603592769928
],
[
"ConvGrad ai.onnx CPUExecutionProvider",
551027277226613536
"ConvGrad com.microsoft CPUExecutionProvider",
6051867985469399832
],
[
"DropoutGrad com.microsoft CPUExecutionProvider",

View file

@ -706,7 +706,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetConvGradient) {
}
return std::vector<NodeDef>{
NodeDef("ConvGrad",
NodeDef(OpDef{"ConvGrad", kMSDomain, 1},
{GO(0), I(0), I(1)},
outputs,
SrcNodeAttributes())};

View file

@ -468,11 +468,12 @@ void RegisterTrainingOpSchemas() {
"Constrain index tensor to int64");
ONNX_CONTRIB_OPERATOR_SCHEMA(ConvGrad)
.SinceVersion(9)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Input(0, "dY", "Gradient of output Y", "T")
.Input(1, "X", "Input tensor", "T")
.Input(2, "W", "Weight tensor", "T")
.Output(0, "dX", "Gradient of input X", "T", OpSchema::Optional)
.Output(0, "dX", "Gradient of X", "T", OpSchema::Optional)
.Output(1, "dW", "Gradient of W", "T", OpSchema::Optional)
.Output(2, "dB", "Gradient of B", "T", OpSchema::Optional)
.AllowUncheckedAttributes()

View file

@ -477,8 +477,19 @@ TEST(GradientCheckerTest, FlattenGrad) {
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"Flatten", kOnnxDomain, 11};
for (int axis = -3; axis < 3; ++axis) {
gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, {MakeAttribute("axis", int64_t(axis))});
const std::vector<std::pair<int, TensorShape>> axis_to_shape = {
{-3, {1, 24}},
{-2, {2, 12}},
{-1, {6, 4}},
{0, {1, 24}},
{1, {2, 12}},
{2, {6, 4}},
{3, {24, 1}}};
for (auto& pair : axis_to_shape) {
int axis = pair.first;
const TensorShape& output_shape = pair.second;
gradient_checker.ComputeGradientError(op_def, {shape}, {output_shape}, &max_error, {MakeAttribute("axis", int64_t(axis))});
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
}
}
@ -653,7 +664,6 @@ TEST(GradientCheckerTest, ReluGradDnnl) {
}
#endif // USE_DNNL
#ifndef USE_CUDA
TEST(GradientCheckerTest, CastGrad) {
// A dummy test that cast float to float
// TODO: add more test here
@ -767,7 +777,6 @@ void MaxpoolGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>
}
}
TEST(GradientCheckerTest, MaxPoolGrad) {
MaxpoolGradientCheckerTest(nullptr);
@ -805,10 +814,10 @@ void ConvGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>* e
// 1D convolution
{
TensorShape x_shape({2, 1, 5});
TensorShape w_shape({1, 1, 3});
TensorShape b_shape({1});
TensorShape y_shape({2, 1, 5});
TensorShape x_shape({2, 2, 5});
TensorShape w_shape({2, 2, 3});
TensorShape b_shape({2});
TensorShape y_shape({2, 2, 5});
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
{MakeAttribute("kernel_shape", std::vector<int64_t>{3}),
MakeAttribute("pads", std::vector<int64_t>{1, 1})},
@ -1201,7 +1210,6 @@ TEST(GradientCheckerTest, AveragePoolGrad) {
EXPECT_IS_TINY(max_error);
}
}
#endif
TEST(GradientCheckerTest, TransposeGrad) {
float max_error;
@ -1569,7 +1577,7 @@ void TestSoftmaxCrossEntropyGrad(const TensorShape& input_shape, const std::stri
GenerateRandomDataWithOneHot<float>(x_datas, {input_shape, input_shape}, {1});
gradient_checker.ComputeGradientError(op_def, {input_shape, {input_shape, false}},
{{1}, {input_shape, false}}, &max_error, x_datas,
{{}, {input_shape, false}}, &max_error, x_datas,
{MakeAttribute("reduction", reduction)});
EXPECT_IS_TINY(max_error);
}
@ -1598,7 +1606,7 @@ void TestSparseSoftmaxCrossEntropyGrad(const TensorShape& index_shape, const std
TensorInfo index_info(index_shape, false, &transformer_index, DataTypeImpl::GetTensorType<int64_t>());
gradient_checker.ComputeGradientError(op_def, {x_info, index_info},
{{1}, {logit_shape, false}}, &max_error,
{{}, {logit_shape, false}}, &max_error,
{MakeAttribute("reduction", reduction)});
EXPECT_IS_TINY(max_error);
}
@ -1613,7 +1621,7 @@ void TestSparseSoftmaxCrossEntropyGrad(const TensorShape& index_shape, const std
TensorInfo weight_info(index_shape, false, &transformer_weight);
gradient_checker.ComputeGradientError(op_def, {x_info, index_info, weight_info},
{{1}, {logit_shape, false}}, &max_error,
{{}, {logit_shape, false}}, &max_error,
{MakeAttribute("reduction", reduction)});
EXPECT_IS_TINY(max_error);
}
@ -2069,7 +2077,7 @@ TEST(GradientCheckerTest, SimplifiedLayerNormGrad) {
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
}
}
#endif
#endif //USE_CUDA
TEST(GradientUtilsTest, InPlaceAccumulatorFloat32) {
OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain);
@ -2100,7 +2108,7 @@ TEST(GradientUtilsTest, InPlaceAccumulatorFloat16) {
// Didn't implement mixed precision InPlaceAccumulator in CPU
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
}
#endif
#endif //defined(USE_CUDA) || defined(USE_ROCM)
TEST(GradientUtilsTest, ZeroGradientFloat32) {
OpTester test("ZeroGradient", 1, onnxruntime::kMSDomain);
@ -2133,7 +2141,7 @@ TEST(GradientUtilsTest, ZeroGradientFloat16) {
test.Run();
}
#endif
#endif // defined(USE_CUDA) || defined(USE_ROCM)
TEST(GradientCheckerTest, WhereGrad) {
float max_error;

View file

@ -19,6 +19,8 @@ from inspect import signature
from onnxruntime.training import _utils, ORTModule
import _test_helpers
# Import autocasting libs
from torch.cuda import amp
# PyTorch model definitions for tests
@ -507,6 +509,43 @@ def test_gradient_correctness():
assert torch.allclose(ort_prediction, pt_prediction)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
@pytest.mark.parametrize("use_fp16", [False, True])
def test_gradient_correctness_conv1d(use_fp16):
    # Checks ORTModule forward/backward parity with PyTorch for stacked Conv1d
    # layers (DeBERTa-style shapes), in fp32 and, via autocast, fp16.
    # NOTE(review): indentation reconstructed from a whitespace-mangled dump;
    # block boundaries (e.g. extent of the autocast context) assumed — confirm
    # against the original file.
    class NeuralNetConv1D(torch.nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, padding=0, groups=1):
            super(NeuralNetConv1D, self).__init__()
            self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=groups)
            self.conv2 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=groups)

        def forward(self, input):
            # Conv1d expects (N, C, L); the test data is (N, L, C), so permute
            # channels-last -> channels-first on the way in and back on the way out.
            out = self.conv1(input.permute(0, 2, 1).contiguous())
            out = self.conv2(out).permute(0, 2, 1).contiguous()
            return out

    device = 'cuda'
    N, seq_len, C_in, C_out, kernel_size = 32, 128, 1536, 1536, 3
    pt_model = NeuralNetConv1D(C_in, C_out, kernel_size, padding=1).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    def run_step(model, x):
        # autocast only when use_fp16 is True; loss is a plain sum so the
        # gradient of the prediction is all-ones.
        with amp.autocast(use_fp16):
            prediction = model(x)
            loss = prediction.sum()
        loss.backward()
        return prediction

    for step in range(10):
        x = torch.randn(N, seq_len, C_in, device=device, requires_grad=True)
        pt_prediction = run_step(pt_model, x)
        ort_prediction = run_step(ort_model, x)

        # fp16 needs looser tolerances than fp32.
        if use_fp16:
            assert torch.allclose(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-3)
            _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=1e-2, atol=1e-2)
        else:
            assert torch.allclose(ort_prediction, pt_prediction, atol=1e-5)
            _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-3, atol=1e-3)
def test_module_with_non_differential_output():
device = 'cuda'
N, D_in, H, D_out = 32, 128, 64, 10

View file

@ -32,7 +32,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad);
@ -128,7 +128,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad)>,

View file

@ -250,9 +250,11 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
return Status::OK();
}
ONNX_CPU_OPERATOR_KERNEL(
ONNX_OPERATOR_KERNEL_EX(
ConvGrad,
9,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ConvGrad<float>);

View file

@ -64,6 +64,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LogSoftmaxGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DropoutGrad);
@ -236,6 +239,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DivGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, DivGrad)>,

View file

@ -0,0 +1,229 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cuda/nn/conv_grad.h"
#include "core/providers/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
namespace onnxruntime {
namespace cuda {
// Registers the ConvGrad CUDA kernel for one element type T under the
// com.microsoft (kMSDomain) domain, opset 1, constraining "T" to that type.
#define REGISTER_GRADIENT_KERNEL_TYPED(T)                                       \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                                \
      ConvGrad,                                                                 \
      kMSDomain,                                                                \
      1,                                                                        \
      T,                                                                        \
      kCudaExecutionProvider,                                                   \
      KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      ConvGrad<T>);

// Instantiate for the three supported element types.
REGISTER_GRADIENT_KERNEL_TYPED(float)
REGISTER_GRADIENT_KERNEL_TYPED(double)
REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16)
// Queries cuDNN for the scratch-buffer size needed by the given
// backward-data algorithm with the descriptors currently held in `args`.
// The size in bytes is written to `sz`.
cudnnStatus_t getWorkspaceSize(const ConvolutionArgs& args,
                               cudnnConvolutionBwdDataAlgo_t algo,
                               size_t* sz) {
  return cudnnGetConvolutionBackwardDataWorkspaceSize(
      args.handle, args.w_desc, args.o_desc, args.c_desc, args.i_desc, algo, sz);
}
// Overload for the backward-filter path: queries cuDNN for the scratch-buffer
// size needed by `algo` with the descriptors currently held in `args`.
// The size in bytes is written to `sz`.
cudnnStatus_t getWorkspaceSize(const ConvolutionArgs& args,
                               cudnnConvolutionBwdFilterAlgo_t algo,
                               size_t* sz) {
  return cudnnGetConvolutionBackwardFilterWorkspaceSize(
      args.handle, args.i_desc, args.o_desc, args.c_desc, args.w_desc, algo, sz);
}
// TODO: we can cache the descriptors, and only update if the input shape changes
// Fills args_ (handle, cuDNN data type, and the input/output/weight/bias and
// convolution descriptors) from the shapes of `input` (X-shaped), `output`
// (dY-shaped) and `weight` (W-shaped), plus the optional `bias`.
// Must run before any Compute*Gradient helper, all of which read args_.
template <typename T>
Status ConvGrad<T>::PrepareArgs(const Tensor& input, const Tensor& output, const Tensor& weight, const Tensor* bias) const {
  const TensorShape& i_shape = input.Shape();
  std::vector<int64_t> i_dims = i_shape.GetDims();

  const TensorShape& o_shape = output.Shape();
  std::vector<int64_t> o_dims = o_shape.GetDims();

  const TensorShape& w_shape = weight.Shape();
  std::vector<int64_t> w_dims = w_shape.GetDims();

  // Update Attributes
  ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&input, &weight));

  std::vector<int64_t> kernel_shape;
  ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape));
  auto rank = kernel_shape.size();

  // Unset attributes default to pads=0, dilations=1, strides=1 per spatial dim.
  std::vector<int64_t> pads(conv_attrs_.pads);
  if (pads.empty()) {
    pads.resize(rank * 2, 0);
  }
  std::vector<int64_t> dilations(conv_attrs_.dilations);
  if (dilations.empty()) {
    dilations.resize(rank, 1);
  }
  std::vector<int64_t> strides(conv_attrs_.strides);
  if (strides.empty()) {
    strides.resize(rank, 1);
  }

  // cudnn only takes 4D or 5D input, so pad dimensions if needed
  if (rank < 2) {
    // Append a unit spatial dimension. pads is laid out [begin..., end...],
    // so a zero is inserted into both halves.
    i_dims.push_back(1);
    o_dims.push_back(1);
    w_dims.push_back(1);
    pads.insert(pads.begin() + rank, 0);
    pads.insert(pads.end(), 0);
    kernel_shape.push_back(1);
    strides.push_back(1);
    dilations.push_back(1);
  }

  args_.handle = CudnnHandle();
  args_.data_type = CudnnTensor::GetDataType<CudaT>();
  ORT_RETURN_IF_ERROR(args_.i_desc.Set(i_dims, args_.data_type));
  ORT_RETURN_IF_ERROR(args_.o_desc.Set(o_dims, args_.data_type));
  ORT_RETURN_IF_ERROR(args_.w_desc.Set(w_dims, args_.data_type));
  ORT_RETURN_IF_ERROR(args_.c_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                       gsl::narrow_cast<int>(conv_attrs_.group),
                                       CUDNN_CROSS_CORRELATION, args_.data_type));

  if (bias) {
    // cuDNN expects the bias described as a (1, C, 1, ...) tensor matching
    // the padded data rank.
    const TensorShape& b_shape = bias->Shape();
    ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
    std::vector<int64_t> b_dims(2 + kernel_shape.size(), 1);
    b_dims[1] = b_shape[0];
    ORT_RETURN_IF_ERROR(args_.b_desc.Set(b_dims, args_.data_type));
  }

  return Status::OK();
}
// Entry point: reads dY, X, W and produces the requested subset of
// dX, dW, dB. All three outputs are optional per the ConvGrad schema;
// each Compute*Gradient helper is a no-op for a null output.
template <typename T>
Status ConvGrad<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* dY = context->Input<Tensor>(0);
  const Tensor* X = context->Input<Tensor>(1);
  const Tensor* W = context->Input<Tensor>(2);
  const int64_t M = W->Shape()[0];  // number of output channels = bias length

  // Outputs may be null when the corresponding gradient is not requested.
  Tensor* dX = context->Output(0, X->Shape());
  Tensor* dW = context->Output(1, W->Shape());
  Tensor* dB = context->Output(2, {M});

  // Build the cuDNN descriptors from the *required* inputs (X, dY, W) rather
  // than from the optional outputs: dX/dW have identical shapes but may be
  // null, and dereferencing them here would crash before the per-gradient
  // null checks ever run.
  ORT_RETURN_IF_ERROR(PrepareArgs(*X, *dY, *W, dB));

  ORT_RETURN_IF_ERROR(ComputeWeightGradient(dW, dY, X));
  ORT_RETURN_IF_ERROR(ComputeInputGradient(dX, dY, W));
  ORT_RETURN_IF_ERROR(ComputeBiasGradient(dB, dY));

  return Status::OK();
}
// Computes dW = conv_backward_filter(X, dY) via cuDNN.
// No-op when the dW output was not requested by the caller.
template <typename T>
Status ConvGrad<T>::ComputeWeightGradient(Tensor* dW, const Tensor* dY, const Tensor* X) const {
  if (dW == nullptr) return Status::OK();

  // TODO: implement the algorithm search instead of a fixed default algo.
  cudnnConvolutionBwdFilterAlgoPerf_t perf;
  perf.algo = kDefaultConvBwdFilterAlgo;
  // Tensor-core math for fp16, default math otherwise.
  perf.mathType = (args_.data_type == CUDNN_DATA_HALF) ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
  CUDNN_RETURN_IF_ERROR(getWorkspaceSize(args_, perf.algo, &perf.memory));
  CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.c_desc, perf.mathType));

  IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(perf.memory);
  const auto alpha = Consts<CudaT>::One;
  const auto beta = Consts<CudaT>::Zero;
  CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardFilter(
      args_.handle,
      &alpha, args_.i_desc, X->template Data<T>(),
      args_.o_desc, dY->template Data<T>(),
      args_.c_desc, perf.algo, workspace.get(), perf.memory,
      &beta, args_.w_desc, dW->template MutableData<T>()));

  return Status::OK();
}
// Computes dX = conv_backward_data(dY, W) via cuDNN.
// No-op when the dX output was not requested by the caller.
template <typename T>
Status ConvGrad<T>::ComputeInputGradient(Tensor* dX, const Tensor* dY, const Tensor* W) const {
  if (dX == nullptr) return Status::OK();

  // TODO: implement the algorithm search instead of a fixed default algo.
  cudnnConvolutionBwdDataAlgoPerf_t perf;
  perf.algo = kDefaultConvBwdDataAlgo;
  // Tensor-core math for fp16, default math otherwise.
  perf.mathType = (args_.data_type == CUDNN_DATA_HALF) ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
  CUDNN_RETURN_IF_ERROR(getWorkspaceSize(args_, perf.algo, &perf.memory));
  CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.c_desc, perf.mathType));

  IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(perf.memory);
  const auto alpha = Consts<CudaT>::One;
  const auto beta = Consts<CudaT>::Zero;
  CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(
      args_.handle,
      &alpha, args_.w_desc, W->template Data<T>(),
      args_.o_desc, dY->template Data<T>(),
      args_.c_desc, perf.algo, workspace.get(), perf.memory,
      &beta, args_.i_desc, dX->template MutableData<T>()));

  return Status::OK();
}
// Computes dB by reducing dY with cudnnConvolutionBackwardBias.
// No-op when the dB output was not requested by the caller.
template <typename T>
Status ConvGrad<T>::ComputeBiasGradient(Tensor* dB, const Tensor* dY) const {
  if (dB == nullptr) return Status::OK();

  const auto alpha = Consts<CudaT>::One;
  const auto beta = Consts<CudaT>::Zero;
  CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardBias(
      args_.handle,
      &alpha, args_.o_desc, dY->template Data<T>(),
      &beta, args_.b_desc, dB->template MutableData<T>()));

  return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/cuda/nn/conv.h"
namespace onnxruntime {
namespace cuda {
// Bundles the cuDNN handle and every descriptor the backward-convolution
// calls need. Populated by ConvGrad<T>::PrepareArgs before use.
struct ConvolutionArgs {
  cudnnHandle_t handle;                 // cuDNN handle used for all backward calls
  cudnnDataType_t data_type;            // cuDNN dtype mapped from the kernel's T
  CudnnTensor i_desc, o_desc, b_desc;   // input (X/dX), output (dY), bias (dB)
  CudnnFilterDescriptor w_desc;         // weight (W/dW)
  CudnnConvolutionDescriptor c_desc;    // pads/strides/dilations/group/mode
  ConvolutionArgs() {}
};
// CUDA (cuDNN-backed) kernel computing the gradient of Conv.
// Inputs: dY, X, W. Outputs dX, dW, dB are each optional; unrequested
// gradients are skipped by the corresponding Compute*Gradient helper.
template <typename T>
class ConvGrad final : public CudaKernel {
 public:
  using CudaT = typename ToCudaType<T>::MappedType;

  ConvGrad(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) {
    // pads come as [begin..., end...] pairs, so the count must be even.
    auto pads_size = conv_attrs_.pads.size();
    ORT_ENFORCE(pads_size % 2 == 0);
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 protected:
  // mutable: descriptors are rebuilt inside the const Compute path from the
  // current input shapes (see PrepareArgs).
  mutable ConvolutionArgs args_;
  Status PrepareArgs(const Tensor& input, const Tensor& output, const Tensor& weight, const Tensor* bias) const;
  ConvAttributes conv_attrs_;

  // Fixed default algorithms (no search yet); see the cuDNN tensor-op notes:
  // https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn_742/cudnn-developer-guide/index.html#tensor_ops
  static constexpr auto kDefaultConvBwdDataAlgo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  static constexpr auto kDefaultConvBwdFilterAlgo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;

 private:
  Status ComputeWeightGradient(Tensor* dW, const Tensor* dY, const Tensor* X) const;
  Status ComputeInputGradient(Tensor* dX, const Tensor* dY, const Tensor* W) const;
  Status ComputeBiasGradient(Tensor* dB, const Tensor* dY) const;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -198,6 +198,8 @@ training_ops_excluded_files = [
'math/softmax_grad.cc',
'nn/batch_norm_grad.cc',
'nn/batch_norm_grad.h',
'nn/conv_grad.cc',
'nn/conv_grad.h',
'optimizer/adam.cc',
'optimizer/adam.cu',
'reduction/reduction_all.cc',