mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
CUDA ConvGrad Kernel (#7227)
* ConvGrad CUDA impl * Set up the test case for Deberta Conv1D * Add fp16 test Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
8219518aa8
commit
4bc17ca04e
14 changed files with 386 additions and 30 deletions
|
|
@ -175,8 +175,8 @@ Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const
|
|||
ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
|
||||
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
|
||||
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group),
|
||||
CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>()));
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, gsl::narrow_cast<int>(conv_attrs_.group)));
|
||||
|
||||
if (context->InputCount() >= 3) {
|
||||
const Tensor* B = context->Input<Tensor>(2);
|
||||
|
|
@ -330,6 +330,7 @@ Status CudnnConvolutionDescriptor::Set(
|
|||
const std::vector<int64_t>& pads,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& dilations,
|
||||
int groups,
|
||||
cudnnConvolutionMode_t mode,
|
||||
cudnnDataType_t data_type) {
|
||||
if (!desc_)
|
||||
|
|
@ -344,6 +345,10 @@ Status CudnnConvolutionDescriptor::Set(
|
|||
dilation_dims[i] = gsl::narrow_cast<int>(dilations[i]);
|
||||
}
|
||||
|
||||
// This piece of code is copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h
|
||||
// Setting math_type to CUDNN_DATA_FLOAT for half input
|
||||
cudnnDataType_t math_type = data_type;
|
||||
if (data_type == CUDNN_DATA_HALF) math_type = CUDNN_DATA_FLOAT;
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionNdDescriptor(
|
||||
desc_,
|
||||
gsl::narrow_cast<int>(rank),
|
||||
|
|
@ -351,7 +356,16 @@ Status CudnnConvolutionDescriptor::Set(
|
|||
stride_dims.data(),
|
||||
dilation_dims.data(),
|
||||
mode,
|
||||
data_type));
|
||||
math_type));
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(desc_, groups));
|
||||
|
||||
// Copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h
|
||||
// See Note [behavior of cudnnFind and cudnnGet] at /pytorch/aten/src/ATen/native/cudnn/Conv_v7.cpp
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH));
|
||||
if (data_type == CUDNN_DATA_HALF) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class CudnnConvolutionDescriptor final {
|
|||
const std::vector<int64_t>& pads,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& dilations,
|
||||
int groups,
|
||||
cudnnConvolutionMode_t mode,
|
||||
cudnnDataType_t data_type);
|
||||
|
||||
|
|
|
|||
|
|
@ -107,10 +107,9 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
|
|||
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType<CudaT>()));
|
||||
|
||||
cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
|
||||
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides,
|
||||
p.dilations, mode, CudnnTensor::GetDataType<CudaT>()));
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc,
|
||||
gsl::narrow_cast<int>(conv_transpose_attrs_.group)));
|
||||
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations,
|
||||
gsl::narrow_cast<int>(conv_transpose_attrs_.group),
|
||||
mode, CudnnTensor::GetDataType<CudaT>()));
|
||||
|
||||
if (has_bias) {
|
||||
const auto& b_shape = p.B->Shape();
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@
|
|||
407435603592769928
|
||||
],
|
||||
[
|
||||
"ConvGrad ai.onnx CPUExecutionProvider",
|
||||
551027277226613536
|
||||
"ConvGrad com.microsoft CPUExecutionProvider",
|
||||
6051867985469399832
|
||||
],
|
||||
[
|
||||
"DropoutGrad com.microsoft CPUExecutionProvider",
|
||||
|
|
|
|||
|
|
@ -706,7 +706,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetConvGradient) {
|
|||
}
|
||||
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("ConvGrad",
|
||||
NodeDef(OpDef{"ConvGrad", kMSDomain, 1},
|
||||
{GO(0), I(0), I(1)},
|
||||
outputs,
|
||||
SrcNodeAttributes())};
|
||||
|
|
|
|||
|
|
@ -468,11 +468,12 @@ void RegisterTrainingOpSchemas() {
|
|||
"Constrain index tensor to int64");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ConvGrad)
|
||||
.SinceVersion(9)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.Input(0, "dY", "Gradient of output Y", "T")
|
||||
.Input(1, "X", "Input tensor", "T")
|
||||
.Input(2, "W", "Weight tensor", "T")
|
||||
.Output(0, "dX", "Gradient of input X", "T", OpSchema::Optional)
|
||||
.Output(0, "dX", "Gradient of X", "T", OpSchema::Optional)
|
||||
.Output(1, "dW", "Gradient of W", "T", OpSchema::Optional)
|
||||
.Output(2, "dB", "Gradient of B", "T", OpSchema::Optional)
|
||||
.AllowUncheckedAttributes()
|
||||
|
|
|
|||
|
|
@ -477,8 +477,19 @@ TEST(GradientCheckerTest, FlattenGrad) {
|
|||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"Flatten", kOnnxDomain, 11};
|
||||
|
||||
for (int axis = -3; axis < 3; ++axis) {
|
||||
gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, {MakeAttribute("axis", int64_t(axis))});
|
||||
const std::vector<std::pair<int, TensorShape>> axis_to_shape = {
|
||||
{-3, {1, 24}},
|
||||
{-2, {2, 12}},
|
||||
{-1, {6, 4}},
|
||||
{0, {1, 24}},
|
||||
{1, {2, 12}},
|
||||
{2, {6, 4}},
|
||||
{3, {24, 1}}};
|
||||
|
||||
for (auto& pair : axis_to_shape) {
|
||||
int axis = pair.first;
|
||||
const TensorShape& output_shape = pair.second;
|
||||
gradient_checker.ComputeGradientError(op_def, {shape}, {output_shape}, &max_error, {MakeAttribute("axis", int64_t(axis))});
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
}
|
||||
|
|
@ -653,7 +664,6 @@ TEST(GradientCheckerTest, ReluGradDnnl) {
|
|||
}
|
||||
#endif // USE_DNNL
|
||||
|
||||
#ifndef USE_CUDA
|
||||
TEST(GradientCheckerTest, CastGrad) {
|
||||
// A dummy test that cast float to float
|
||||
// TODO: add more test here
|
||||
|
|
@ -767,7 +777,6 @@ void MaxpoolGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(GradientCheckerTest, MaxPoolGrad) {
|
||||
MaxpoolGradientCheckerTest(nullptr);
|
||||
|
||||
|
|
@ -805,10 +814,10 @@ void ConvGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>* e
|
|||
|
||||
// 1D convolution
|
||||
{
|
||||
TensorShape x_shape({2, 1, 5});
|
||||
TensorShape w_shape({1, 1, 3});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 5});
|
||||
TensorShape x_shape({2, 2, 5});
|
||||
TensorShape w_shape({2, 2, 3});
|
||||
TensorShape b_shape({2});
|
||||
TensorShape y_shape({2, 2, 5});
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{1, 1})},
|
||||
|
|
@ -1201,7 +1210,6 @@ TEST(GradientCheckerTest, AveragePoolGrad) {
|
|||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(GradientCheckerTest, TransposeGrad) {
|
||||
float max_error;
|
||||
|
|
@ -1569,7 +1577,7 @@ void TestSoftmaxCrossEntropyGrad(const TensorShape& input_shape, const std::stri
|
|||
GenerateRandomDataWithOneHot<float>(x_datas, {input_shape, input_shape}, {1});
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {input_shape, {input_shape, false}},
|
||||
{{1}, {input_shape, false}}, &max_error, x_datas,
|
||||
{{}, {input_shape, false}}, &max_error, x_datas,
|
||||
{MakeAttribute("reduction", reduction)});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
|
@ -1598,7 +1606,7 @@ void TestSparseSoftmaxCrossEntropyGrad(const TensorShape& index_shape, const std
|
|||
TensorInfo index_info(index_shape, false, &transformer_index, DataTypeImpl::GetTensorType<int64_t>());
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, index_info},
|
||||
{{1}, {logit_shape, false}}, &max_error,
|
||||
{{}, {logit_shape, false}}, &max_error,
|
||||
{MakeAttribute("reduction", reduction)});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
|
@ -1613,7 +1621,7 @@ void TestSparseSoftmaxCrossEntropyGrad(const TensorShape& index_shape, const std
|
|||
TensorInfo weight_info(index_shape, false, &transformer_weight);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, index_info, weight_info},
|
||||
{{1}, {logit_shape, false}}, &max_error,
|
||||
{{}, {logit_shape, false}}, &max_error,
|
||||
{MakeAttribute("reduction", reduction)});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
|
@ -2069,7 +2077,7 @@ TEST(GradientCheckerTest, SimplifiedLayerNormGrad) {
|
|||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif //USE_CUDA
|
||||
|
||||
TEST(GradientUtilsTest, InPlaceAccumulatorFloat32) {
|
||||
OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain);
|
||||
|
|
@ -2100,7 +2108,7 @@ TEST(GradientUtilsTest, InPlaceAccumulatorFloat16) {
|
|||
// Didn't implement mixed precision InPlaceAccumulator in CPU
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
|
||||
}
|
||||
#endif
|
||||
#endif //defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
TEST(GradientUtilsTest, ZeroGradientFloat32) {
|
||||
OpTester test("ZeroGradient", 1, onnxruntime::kMSDomain);
|
||||
|
|
@ -2133,7 +2141,7 @@ TEST(GradientUtilsTest, ZeroGradientFloat16) {
|
|||
|
||||
test.Run();
|
||||
}
|
||||
#endif
|
||||
#endif // defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
TEST(GradientCheckerTest, WhereGrad) {
|
||||
float max_error;
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ from inspect import signature
|
|||
from onnxruntime.training import _utils, ORTModule
|
||||
import _test_helpers
|
||||
|
||||
# Import autocasting libs
|
||||
from torch.cuda import amp
|
||||
|
||||
# PyTorch model definitions for tests
|
||||
|
||||
|
|
@ -507,6 +509,43 @@ def test_gradient_correctness():
|
|||
assert torch.allclose(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
@pytest.mark.parametrize("use_fp16", [False, True])
def test_gradient_correctness_conv1d(use_fp16):
    """Compare ORTModule against PyTorch for two stacked Conv1d layers.

    Runs ten forward/backward steps on CUDA, optionally under autocast, and
    checks that predictions and parameter gradients agree within tolerance.
    """

    class NeuralNetConv1D(torch.nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, padding=0, groups=1):
            super(NeuralNetConv1D, self).__init__()
            self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=groups)
            self.conv2 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=groups)

        def forward(self, input):
            # The test feeds (N, seq_len, C_in); swap the last two axes for Conv1d
            # and swap them back on the way out.
            channels_first = input.permute(0, 2, 1).contiguous()
            hidden = self.conv1(channels_first)
            return self.conv2(hidden).permute(0, 2, 1).contiguous()

    device = 'cuda'
    N, seq_len, C_in, C_out, kernel_size = 32, 128, 1536, 1536, 3
    pt_model = NeuralNetConv1D(C_in, C_out, kernel_size, padding=1).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    # Half precision gets looser bounds than full precision.
    if use_fp16:
        prediction_tol = dict(atol=1e-3, rtol=1e-3)
        gradient_tol = dict(rtol=1e-2, atol=1e-2)
    else:
        prediction_tol = dict(atol=1e-5)
        gradient_tol = dict(rtol=5e-3, atol=1e-3)

    def run_step(model, x):
        # Only the forward pass runs under autocast; backward runs outside it.
        with amp.autocast(use_fp16):
            prediction = model(x)
            loss = prediction.sum()
        loss.backward()
        return prediction

    for _ in range(10):
        x = torch.randn(N, seq_len, C_in, device=device, requires_grad=True)
        pt_prediction = run_step(pt_model, x)
        ort_prediction = run_step(ort_model, x)

        assert torch.allclose(ort_prediction, pt_prediction, **prediction_tol)
        _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, **gradient_tol)
|
||||
|
||||
def test_module_with_non_differential_output():
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 128, 64, 10
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad);
|
||||
|
|
@ -128,7 +128,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad)>,
|
||||
|
|
|
|||
|
|
@ -250,9 +250,11 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
ConvGrad,
|
||||
9,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ConvGrad<float>);
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LogSoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DropoutGrad);
|
||||
|
||||
|
|
@ -236,6 +239,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DivGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, DivGrad)>,
|
||||
|
|
|
|||
229
orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc
Normal file
229
orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/nn/conv_grad.h"
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_GRADIENT_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
ConvGrad, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
ConvGrad<T>);
|
||||
|
||||
REGISTER_GRADIENT_KERNEL_TYPED(float)
|
||||
REGISTER_GRADIENT_KERNEL_TYPED(double)
|
||||
REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
// Backward-data overload: asks cuDNN how much scratch space `algo` needs for
// the convolution described by `args` and stores the answer in `*sz`.
// Returns the cuDNN status unchanged so the caller can decide how to report it.
cudnnStatus_t getWorkspaceSize(const ConvolutionArgs& args,
                               cudnnConvolutionBwdDataAlgo_t algo,
                               size_t* sz) {
  // Descriptor order for this query: filter, output gradient, convolution, input gradient.
  const cudnnStatus_t status = cudnnGetConvolutionBackwardDataWorkspaceSize(
      args.handle, args.w_desc, args.o_desc, args.c_desc, args.i_desc, algo, sz);
  return status;
}
|
||||
|
||||
// Backward-filter overload: asks cuDNN how much scratch space `algo` needs for
// the convolution described by `args` and stores the answer in `*sz`.
// Returns the cuDNN status unchanged so the caller can decide how to report it.
cudnnStatus_t getWorkspaceSize(const ConvolutionArgs& args,
                               cudnnConvolutionBwdFilterAlgo_t algo,
                               size_t* sz) {
  // Descriptor order for this query: input, output gradient, convolution, filter gradient.
  const cudnnStatus_t status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
      args.handle, args.i_desc, args.o_desc, args.c_desc, args.w_desc, algo, sz);
  return status;
}
|
||||
|
||||
// TODO: we can cache the descriptors, and only update if the input shape changes
|
||||
template <typename T>
|
||||
Status ConvGrad<T>::PrepareArgs(const Tensor& input, const Tensor& output, const Tensor& weight, const Tensor* bias) const {
|
||||
const TensorShape& i_shape = input.Shape();
|
||||
std::vector<int64_t> i_dims = i_shape.GetDims();
|
||||
|
||||
const TensorShape& o_shape = output.Shape();
|
||||
std::vector<int64_t> o_dims = o_shape.GetDims();
|
||||
|
||||
const TensorShape& w_shape = weight.Shape();
|
||||
std::vector<int64_t> w_dims = w_shape.GetDims();
|
||||
|
||||
// Update Attributes
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&input, &weight));
|
||||
|
||||
std::vector<int64_t> kernel_shape;
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape));
|
||||
auto rank = kernel_shape.size();
|
||||
|
||||
std::vector<int64_t> pads(conv_attrs_.pads);
|
||||
if (pads.empty()) {
|
||||
pads.resize(rank * 2, 0);
|
||||
}
|
||||
|
||||
std::vector<int64_t> dilations(conv_attrs_.dilations);
|
||||
if (dilations.empty()) {
|
||||
dilations.resize(rank, 1);
|
||||
}
|
||||
|
||||
std::vector<int64_t> strides(conv_attrs_.strides);
|
||||
if (strides.empty()) {
|
||||
strides.resize(rank, 1);
|
||||
}
|
||||
|
||||
// cudnn only takes 4D or 5D input, so pad dimensions if needed
|
||||
if (rank < 2) {
|
||||
i_dims.push_back(1);
|
||||
o_dims.push_back(1);
|
||||
w_dims.push_back(1);
|
||||
|
||||
pads.insert(pads.begin() + rank, 0);
|
||||
pads.insert(pads.end(), 0);
|
||||
kernel_shape.push_back(1);
|
||||
strides.push_back(1);
|
||||
dilations.push_back(1);
|
||||
}
|
||||
|
||||
args_.handle = CudnnHandle();
|
||||
args_.data_type = CudnnTensor::GetDataType<CudaT>();
|
||||
ORT_RETURN_IF_ERROR(args_.i_desc.Set(i_dims, args_.data_type));
|
||||
ORT_RETURN_IF_ERROR(args_.o_desc.Set(o_dims, args_.data_type));
|
||||
ORT_RETURN_IF_ERROR(args_.w_desc.Set(w_dims, args_.data_type));
|
||||
ORT_RETURN_IF_ERROR(args_.c_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group),
|
||||
CUDNN_CROSS_CORRELATION, args_.data_type));
|
||||
|
||||
if (bias) {
|
||||
const TensorShape& b_shape = bias->Shape();
|
||||
ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
|
||||
std::vector<int64_t> b_dims(2 + kernel_shape.size(), 1);
|
||||
b_dims[1] = b_shape[0];
|
||||
ORT_RETURN_IF_ERROR(args_.b_desc.Set(b_dims, args_.data_type));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ConvGrad<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* dY = context->Input<Tensor>(0);
|
||||
const Tensor* X = context->Input<Tensor>(1);
|
||||
const Tensor* W = context->Input<Tensor>(2);
|
||||
|
||||
const int64_t M = W->Shape()[0];
|
||||
|
||||
Tensor* dX = context->Output(0, X->Shape());
|
||||
Tensor* dW = context->Output(1, W->Shape());
|
||||
Tensor* dB = context->Output(2, {M});
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareArgs(*dX, *dY, *dW, dB));
|
||||
|
||||
ORT_RETURN_IF_ERROR(ComputeWeightGradient(dW, dY, X));
|
||||
ORT_RETURN_IF_ERROR(ComputeInputGradient(dX, dY, W));
|
||||
ORT_RETURN_IF_ERROR(ComputeBiasGradient(dB, dY));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ConvGrad<T>::ComputeWeightGradient(Tensor* dW, const Tensor* dY, const Tensor* X) const {
|
||||
if (dW == nullptr) return Status::OK();
|
||||
|
||||
// TODO: implement the algoritm search
|
||||
cudnnConvolutionBwdFilterAlgoPerf_t perf;
|
||||
perf.algo = kDefaultConvBwdFilterAlgo;
|
||||
if (args_.data_type == CUDNN_DATA_HALF) {
|
||||
perf.mathType = CUDNN_TENSOR_OP_MATH;
|
||||
} else {
|
||||
perf.mathType = CUDNN_DEFAULT_MATH;
|
||||
}
|
||||
CUDNN_RETURN_IF_ERROR(getWorkspaceSize(args_, perf.algo, &perf.memory));
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.c_desc, perf.mathType));
|
||||
|
||||
void* dw_data = dW->template MutableData<T>();
|
||||
const void* dy_data = dY->template Data<T>();
|
||||
const void* x_data = X->template Data<T>();
|
||||
IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(perf.memory);
|
||||
|
||||
const auto one = Consts<CudaT>::One;
|
||||
const auto zero = Consts<CudaT>::Zero;
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(
|
||||
cudnnConvolutionBackwardFilter(
|
||||
args_.handle,
|
||||
&one, args_.i_desc, x_data,
|
||||
args_.o_desc, dy_data,
|
||||
args_.c_desc, perf.algo, workspace.get(), perf.memory,
|
||||
&zero, args_.w_desc, dw_data));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ConvGrad<T>::ComputeInputGradient(Tensor* dX, const Tensor* dY, const Tensor* W) const {
|
||||
if (dX == nullptr) return Status::OK();
|
||||
|
||||
// TODO: implement the algoritm search
|
||||
cudnnConvolutionBwdDataAlgoPerf_t perf;
|
||||
perf.algo = kDefaultConvBwdDataAlgo;
|
||||
if (args_.data_type == CUDNN_DATA_HALF) {
|
||||
perf.mathType = CUDNN_TENSOR_OP_MATH;
|
||||
} else {
|
||||
perf.mathType = CUDNN_DEFAULT_MATH;
|
||||
}
|
||||
CUDNN_RETURN_IF_ERROR(getWorkspaceSize(args_, perf.algo, &perf.memory));
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.c_desc, perf.mathType));
|
||||
|
||||
void* dx_data = dX->template MutableData<T>();
|
||||
const void* dy_data = dY->template Data<T>();
|
||||
const void* w_data = W->template Data<T>();
|
||||
IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(perf.memory);
|
||||
|
||||
const auto one = Consts<CudaT>::One;
|
||||
const auto zero = Consts<CudaT>::Zero;
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(
|
||||
cudnnConvolutionBackwardData(
|
||||
args_.handle,
|
||||
&one, args_.w_desc, w_data,
|
||||
args_.o_desc, dy_data,
|
||||
args_.c_desc, perf.algo, workspace.get(), perf.memory,
|
||||
&zero, args_.i_desc, dx_data));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ConvGrad<T>::ComputeBiasGradient(Tensor* dB, const Tensor* dY) const {
|
||||
if (dB == nullptr) return Status::OK();
|
||||
|
||||
const auto one = Consts<CudaT>::One;
|
||||
const auto zero = Consts<CudaT>::Zero;
|
||||
|
||||
void* db_data = dB->template MutableData<T>();
|
||||
const void* dy_data = dY->template Data<T>();
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(
|
||||
cudnnConvolutionBackwardBias(
|
||||
args_.handle,
|
||||
&one, args_.o_desc, dy_data,
|
||||
&zero, args_.b_desc, db_data));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
55
orttraining/orttraining/training_ops/cuda/nn/conv_grad.h
Normal file
55
orttraining/orttraining/training_ops/cuda/nn/conv_grad.h
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cpu/nn/conv_attributes.h"
|
||||
#include "core/providers/cuda/nn/conv.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
struct ConvolutionArgs {
|
||||
cudnnHandle_t handle;
|
||||
cudnnDataType_t data_type;
|
||||
|
||||
CudnnTensor i_desc, o_desc, b_desc;
|
||||
CudnnFilterDescriptor w_desc;
|
||||
CudnnConvolutionDescriptor c_desc;
|
||||
|
||||
ConvolutionArgs() {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ConvGrad final : public CudaKernel {
|
||||
public:
|
||||
using CudaT = typename ToCudaType<T>::MappedType;
|
||||
|
||||
ConvGrad(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) {
|
||||
auto pads_size = conv_attrs_.pads.size();
|
||||
ORT_ENFORCE(pads_size % 2 == 0);
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
protected:
|
||||
mutable ConvolutionArgs args_;
|
||||
Status PrepareArgs(const Tensor& input, const Tensor& output, const Tensor& weight, const Tensor* bias) const;
|
||||
|
||||
ConvAttributes conv_attrs_;
|
||||
|
||||
// https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn_742/cudnn-developer-guide/index.html#tensor_ops
|
||||
static constexpr auto kDefaultConvBwdDataAlgo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
||||
static constexpr auto kDefaultConvBwdFilterAlgo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
||||
|
||||
private:
|
||||
Status ComputeWeightGradient(Tensor* dW, const Tensor* dY, const Tensor* X) const;
|
||||
Status ComputeInputGradient(Tensor* dX, const Tensor* dY, const Tensor* W) const;
|
||||
Status ComputeBiasGradient(Tensor* dB, const Tensor* dY) const;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -198,6 +198,8 @@ training_ops_excluded_files = [
|
|||
'math/softmax_grad.cc',
|
||||
'nn/batch_norm_grad.cc',
|
||||
'nn/batch_norm_grad.h',
|
||||
'nn/conv_grad.cc',
|
||||
'nn/conv_grad.h',
|
||||
'optimizer/adam.cc',
|
||||
'optimizer/adam.cu',
|
||||
'reduction/reduction_all.cc',
|
||||
|
|
|
|||
Loading…
Reference in a new issue