diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 4f40a494d1..e17ffe676e 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -175,8 +175,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType())); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, gsl::narrow_cast(conv_attrs_.group))); if (context->InputCount() >= 3) { const Tensor* B = context->Input(2); @@ -330,6 +330,7 @@ Status CudnnConvolutionDescriptor::Set( const std::vector& pads, const std::vector& strides, const std::vector& dilations, + int groups, cudnnConvolutionMode_t mode, cudnnDataType_t data_type) { if (!desc_) @@ -344,6 +345,10 @@ Status CudnnConvolutionDescriptor::Set( dilation_dims[i] = gsl::narrow_cast(dilations[i]); } + // This piece of code is copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h + // Setting math_type to CUDNN_DATA_FLOAT for half input + cudnnDataType_t math_type = data_type; + if (data_type == CUDNN_DATA_HALF) math_type = CUDNN_DATA_FLOAT; CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionNdDescriptor( desc_, gsl::narrow_cast(rank), @@ -351,7 +356,16 @@ Status CudnnConvolutionDescriptor::Set( stride_dims.data(), dilation_dims.data(), mode, - data_type)); + math_type)); + + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(desc_, groups)); + + // Copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h + // See Note [behavior of cudnnFind and cudnnGet] at /pytorch/aten/src/ATen/native/cudnn/Conv_v7.cpp + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH)); + if (data_type == CUDNN_DATA_HALF) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH)); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index e562048eea..8f27265e37 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -22,6 +22,7 @@ class CudnnConvolutionDescriptor final { const std::vector& pads, const std::vector& strides, const std::vector& dilations, + int groups, cudnnConvolutionMode_t mode, cudnnDataType_t data_type); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index c71a9b4453..574fb0acd9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -107,10 +107,9 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, - p.dilations, mode, CudnnTensor::GetDataType())); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, - gsl::narrow_cast(conv_transpose_attrs_.group))); + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, + gsl::narrow_cast(conv_transpose_attrs_.group), + mode, CudnnTensor::GetDataType())); if (has_bias) { const auto& b_shape = p.B->Shape(); diff --git a/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json index 265a231e5d..2b2b1efeba 100644 --- a/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json +++ b/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json @@ -24,8 +24,8 @@ 407435603592769928 ], [ - "ConvGrad ai.onnx CPUExecutionProvider", - 551027277226613536 + "ConvGrad com.microsoft CPUExecutionProvider", + 6051867985469399832 ], [ "DropoutGrad com.microsoft CPUExecutionProvider", diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 40b81a6bc6..011890f6ad 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -706,7 +706,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetConvGradient) { } return std::vector{ - NodeDef("ConvGrad", + NodeDef(OpDef{"ConvGrad", kMSDomain, 1}, {GO(0), I(0), I(1)}, outputs, SrcNodeAttributes())}; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 8c7794cba3..6b6e3a03aa 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -468,11 +468,12 @@ void RegisterTrainingOpSchemas() { "Constrain index tensor to int64"); ONNX_CONTRIB_OPERATOR_SCHEMA(ConvGrad) - .SinceVersion(9) + .SetDomain(kMSDomain) + .SinceVersion(1) .Input(0, "dY", "Gradient of output Y", "T") .Input(1, "X", "Input tensor", "T") .Input(2, "W", "Weight tensor", "T") - .Output(0, "dX", "Gradient of input X", "T", OpSchema::Optional) + .Output(0, "dX", "Gradient of X", "T", OpSchema::Optional) .Output(1, "dW", "Gradient of W", "T", OpSchema::Optional) .Output(2, "dB", "Gradient of B", "T", OpSchema::Optional) .AllowUncheckedAttributes() diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 1057df0d03..29b5ae4e52 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -477,8 +477,19 @@ TEST(GradientCheckerTest, FlattenGrad) { GradientChecker gradient_checker; OpDef op_def{"Flatten", kOnnxDomain, 11}; - for (int axis = -3; axis < 3; ++axis) { - gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, {MakeAttribute("axis", int64_t(axis))}); + const std::vector> axis_to_shape = { + {-3, {1, 24}}, + {-2, {2, 12}}, + {-1, {6, 4}}, + {0, {1, 24}}, + {1, {2, 12}}, + {2, {6, 4}}, + {3, {24, 1}}}; + + for (auto& pair : axis_to_shape) { + int axis = pair.first; + const TensorShape& output_shape = pair.second; + gradient_checker.ComputeGradientError(op_def, {shape}, {output_shape}, &max_error, {MakeAttribute("axis", int64_t(axis))}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } } @@ -653,7 +664,6 @@ TEST(GradientCheckerTest, ReluGradDnnl) { } #endif // USE_DNNL -#ifndef USE_CUDA TEST(GradientCheckerTest, CastGrad) { // A dummy test that cast float to float // TODO: add more test here @@ -767,7 +777,6 @@ void MaxpoolGradientCheckerTest(std::vector> } } - TEST(GradientCheckerTest, MaxPoolGrad) { MaxpoolGradientCheckerTest(nullptr); @@ -805,10 +814,10 @@ void ConvGradientCheckerTest(std::vector>* e // 1D convolution { - TensorShape x_shape({2, 1, 5}); - TensorShape w_shape({1, 1, 3}); - TensorShape b_shape({1}); - TensorShape y_shape({2, 1, 5}); + TensorShape x_shape({2, 2, 5}); + TensorShape w_shape({2, 2, 3}); + TensorShape b_shape({2}); + TensorShape y_shape({2, 2, 5}); gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, {MakeAttribute("kernel_shape", std::vector{3}), MakeAttribute("pads", std::vector{1, 1})}, @@ -1201,7 +1210,6 @@ TEST(GradientCheckerTest, AveragePoolGrad) { EXPECT_IS_TINY(max_error); } } -#endif TEST(GradientCheckerTest, TransposeGrad) { float max_error; @@ -1569,7 +1577,7 @@ void TestSoftmaxCrossEntropyGrad(const TensorShape& input_shape, const std::stri GenerateRandomDataWithOneHot(x_datas, {input_shape, input_shape}, {1}); gradient_checker.ComputeGradientError(op_def, {input_shape, {input_shape, false}}, - {{1}, {input_shape, false}}, &max_error, x_datas, + {{}, {input_shape, false}}, &max_error, x_datas, {MakeAttribute("reduction", reduction)}); EXPECT_IS_TINY(max_error); } @@ -1598,7 +1606,7 @@ void TestSparseSoftmaxCrossEntropyGrad(const TensorShape& index_shape, const std TensorInfo index_info(index_shape, false, &transformer_index, DataTypeImpl::GetTensorType()); gradient_checker.ComputeGradientError(op_def, {x_info, index_info}, - {{1}, {logit_shape, false}}, &max_error, + {{}, {logit_shape, false}}, &max_error, {MakeAttribute("reduction", reduction)}); EXPECT_IS_TINY(max_error); } @@ -1613,7 +1621,7 @@ void TestSparseSoftmaxCrossEntropyGrad(const TensorShape& index_shape, const std TensorInfo weight_info(index_shape, false, &transformer_weight); gradient_checker.ComputeGradientError(op_def, {x_info, index_info, weight_info}, - {{1}, {logit_shape, false}}, &max_error, + {{}, {logit_shape, false}}, &max_error, {MakeAttribute("reduction", reduction)}); EXPECT_IS_TINY(max_error); } @@ -2069,7 +2077,7 @@ TEST(GradientCheckerTest, SimplifiedLayerNormGrad) { EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } } -#endif +#endif //USE_CUDA TEST(GradientUtilsTest, InPlaceAccumulatorFloat32) { OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain); @@ -2100,7 +2108,7 @@ TEST(GradientUtilsTest, InPlaceAccumulatorFloat16) { // Didn't implement mixed precision InPlaceAccumulator in CPU test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider}); } -#endif +#endif //defined(USE_CUDA) || defined(USE_ROCM) TEST(GradientUtilsTest, ZeroGradientFloat32) { OpTester test("ZeroGradient", 1, onnxruntime::kMSDomain); @@ -2133,7 +2141,7 @@ TEST(GradientUtilsTest, ZeroGradientFloat16) { test.Run(); } -#endif +#endif // defined(USE_CUDA) || defined(USE_ROCM) TEST(GradientCheckerTest, WhereGrad) { float max_error; diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 76111a047d..ef03ad62e7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -19,6 +19,8 @@ from inspect import signature from onnxruntime.training import _utils, ORTModule import _test_helpers +# Import autocasting libs +from torch.cuda import amp # PyTorch model definitions for tests @@ -507,6 +509,43 @@ def test_gradient_correctness(): assert torch.allclose(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) +@pytest.mark.parametrize("use_fp16", [False, True]) +def test_gradient_correctness_conv1d(use_fp16): + class NeuralNetConv1D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding=0, groups=1): + super(NeuralNetConv1D, self).__init__() + self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=groups) + self.conv2 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=groups) + + def forward(self, input): + out = self.conv1(input.permute(0, 2, 1).contiguous()) + out = self.conv2(out).permute(0, 2, 1).contiguous() + return out + + device = 'cuda' + N, seq_len, C_in, C_out, kernel_size = 32, 128, 1536, 1536, 3 + pt_model = NeuralNetConv1D(C_in, C_out, kernel_size, padding=1).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, x): + with amp.autocast(use_fp16): + prediction = model(x) + loss = prediction.sum() + loss.backward() + return prediction + + for step in range(10): + x = torch.randn(N, seq_len, C_in, device=device, requires_grad=True) + pt_prediction = run_step(pt_model, x) + ort_prediction = run_step(ort_model, x) + + if use_fp16: + assert torch.allclose(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-3) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=1e-2, atol=1e-2) + else: + assert torch.allclose(ort_prediction, pt_prediction, atol=1e-5) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-3, atol=1e-3) + def test_module_with_non_differential_output(): device = 'cuda' N, D_in, H, D_out = 32, 128, 64, 10 diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 7e7a2b484f..e411def4c6 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -32,7 +32,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad); @@ -128,7 +128,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc index fa58bf0def..f9d1652e89 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc @@ -250,9 +250,11 @@ Status ConvGrad::Compute(OpKernelContext* context) const { return Status::OK(); } -ONNX_CPU_OPERATOR_KERNEL( +ONNX_OPERATOR_KERNEL_EX( ConvGrad, - 9, + kMSDomain, + 1, + kCpuExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), ConvGrad); diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 21fde3adcf..d047f0c4ef 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -64,6 +64,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LogSoftmaxGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DropoutGrad); @@ -236,6 +239,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc new file mode 100644 index 0000000000..b1e9f7a6df --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/nn/conv_grad.h" + +#include "core/providers/common.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace cuda { + +#define REGISTER_GRADIENT_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ConvGrad, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ConvGrad); + +REGISTER_GRADIENT_KERNEL_TYPED(float) +REGISTER_GRADIENT_KERNEL_TYPED(double) +REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16) + +cudnnStatus_t getWorkspaceSize( + const ConvolutionArgs& args, + cudnnConvolutionBwdDataAlgo_t algo, size_t* sz) { + return cudnnGetConvolutionBackwardDataWorkspaceSize( + args.handle, + args.w_desc, + args.o_desc, + args.c_desc, + args.i_desc, + algo, + sz); +} + +cudnnStatus_t getWorkspaceSize( + const ConvolutionArgs& args, + cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz) { + return cudnnGetConvolutionBackwardFilterWorkspaceSize( + args.handle, + args.i_desc, + args.o_desc, + args.c_desc, + args.w_desc, + algo, + sz); +} + +// TODO: we can cache the descriptors, and only update if the input shape changes +template +Status ConvGrad::PrepareArgs(const Tensor& input, const Tensor& output, const Tensor& weight, const Tensor* bias) const { + const TensorShape& i_shape = input.Shape(); + std::vector i_dims = i_shape.GetDims(); + + const TensorShape& o_shape = output.Shape(); + std::vector o_dims = o_shape.GetDims(); + + const TensorShape& w_shape = weight.Shape(); + std::vector w_dims = w_shape.GetDims(); + + // Update Attributes + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&input, &weight)); + + std::vector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); + auto rank = kernel_shape.size(); + + std::vector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(rank * 2, 0); + } + + std::vector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(rank, 1); + } + + std::vector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(rank, 1); + } + + // cudnn only takes 4D or 5D input, so pad dimensions if needed + if (rank < 2) { + i_dims.push_back(1); + o_dims.push_back(1); + w_dims.push_back(1); + + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + + args_.handle = CudnnHandle(); + args_.data_type = CudnnTensor::GetDataType(); + ORT_RETURN_IF_ERROR(args_.i_desc.Set(i_dims, args_.data_type)); + ORT_RETURN_IF_ERROR(args_.o_desc.Set(o_dims, args_.data_type)); + ORT_RETURN_IF_ERROR(args_.w_desc.Set(w_dims, args_.data_type)); + ORT_RETURN_IF_ERROR(args_.c_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), + CUDNN_CROSS_CORRELATION, args_.data_type)); + + if (bias) { + const TensorShape& b_shape = bias->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + std::vector b_dims(2 + kernel_shape.size(), 1); + b_dims[1] = b_shape[0]; + ORT_RETURN_IF_ERROR(args_.b_desc.Set(b_dims, args_.data_type)); + } + + return Status::OK(); +} + +template +Status ConvGrad::ComputeInternal(OpKernelContext* context) const { + const Tensor* dY = context->Input(0); + const Tensor* X = context->Input(1); + const Tensor* W = context->Input(2); + + const int64_t M = W->Shape()[0]; + + Tensor* dX = context->Output(0, X->Shape()); + Tensor* dW = context->Output(1, W->Shape()); + Tensor* dB = context->Output(2, {M}); + + ORT_RETURN_IF_ERROR(PrepareArgs(*dX, *dY, *dW, dB)); + + ORT_RETURN_IF_ERROR(ComputeWeightGradient(dW, dY, X)); + ORT_RETURN_IF_ERROR(ComputeInputGradient(dX, dY, W)); + ORT_RETURN_IF_ERROR(ComputeBiasGradient(dB, dY)); + + return Status::OK(); +} + +template +Status ConvGrad::ComputeWeightGradient(Tensor* dW, const Tensor* dY, const Tensor* X) const { + if (dW == nullptr) return Status::OK(); + + // TODO: implement the algoritm search + cudnnConvolutionBwdFilterAlgoPerf_t perf; + perf.algo = kDefaultConvBwdFilterAlgo; + if (args_.data_type == CUDNN_DATA_HALF) { + perf.mathType = CUDNN_TENSOR_OP_MATH; + } else { + perf.mathType = CUDNN_DEFAULT_MATH; + } + CUDNN_RETURN_IF_ERROR(getWorkspaceSize(args_, perf.algo, &perf.memory)); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.c_desc, perf.mathType)); + + void* dw_data = dW->template MutableData(); + const void* dy_data = dY->template Data(); + const void* x_data = X->template Data(); + IAllocatorUniquePtr workspace = GetScratchBuffer(perf.memory); + + const auto one = Consts::One; + const auto zero = Consts::Zero; + + CUDNN_RETURN_IF_ERROR( + cudnnConvolutionBackwardFilter( + args_.handle, + &one, args_.i_desc, x_data, + args_.o_desc, dy_data, + args_.c_desc, perf.algo, workspace.get(), perf.memory, + &zero, args_.w_desc, dw_data)); + + return Status::OK(); +} + +template +Status ConvGrad::ComputeInputGradient(Tensor* dX, const Tensor* dY, const Tensor* W) const { + if (dX == nullptr) return Status::OK(); + + // TODO: implement the algoritm search + cudnnConvolutionBwdDataAlgoPerf_t perf; + perf.algo = kDefaultConvBwdDataAlgo; + if (args_.data_type == CUDNN_DATA_HALF) { + perf.mathType = CUDNN_TENSOR_OP_MATH; + } else { + perf.mathType = CUDNN_DEFAULT_MATH; + } + CUDNN_RETURN_IF_ERROR(getWorkspaceSize(args_, perf.algo, &perf.memory)); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.c_desc, perf.mathType)); + + void* dx_data = dX->template MutableData(); + const void* dy_data = dY->template Data(); + const void* w_data = W->template Data(); + IAllocatorUniquePtr workspace = GetScratchBuffer(perf.memory); + + const auto one = Consts::One; + const auto zero = Consts::Zero; + + CUDNN_RETURN_IF_ERROR( + cudnnConvolutionBackwardData( + args_.handle, + &one, args_.w_desc, w_data, + args_.o_desc, dy_data, + args_.c_desc, perf.algo, workspace.get(), perf.memory, + &zero, args_.i_desc, dx_data)); + + return Status::OK(); +} + +template +Status ConvGrad::ComputeBiasGradient(Tensor* dB, const Tensor* dY) const { + if (dB == nullptr) return Status::OK(); + + const auto one = Consts::One; + const auto zero = Consts::Zero; + + void* db_data = dB->template MutableData(); + const void* dy_data = dY->template Data(); + + CUDNN_RETURN_IF_ERROR( + cudnnConvolutionBackwardBias( + args_.handle, + &one, args_.o_desc, dy_data, + &zero, args_.b_desc, db_data)); + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h new file mode 100644 index 0000000000..5024a056c7 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cpu/nn/conv_attributes.h" +#include "core/providers/cuda/nn/conv.h" + +namespace onnxruntime { +namespace cuda { + +struct ConvolutionArgs { + cudnnHandle_t handle; + cudnnDataType_t data_type; + + CudnnTensor i_desc, o_desc, b_desc; + CudnnFilterDescriptor w_desc; + CudnnConvolutionDescriptor c_desc; + + ConvolutionArgs() {} +}; + +template +class ConvGrad final : public CudaKernel { + public: + using CudaT = typename ToCudaType::MappedType; + + ConvGrad(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { + auto pads_size = conv_attrs_.pads.size(); + ORT_ENFORCE(pads_size % 2 == 0); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + mutable ConvolutionArgs args_; + Status PrepareArgs(const Tensor& input, const Tensor& output, const Tensor& weight, const Tensor* bias) const; + + ConvAttributes conv_attrs_; + + // https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn_742/cudnn-developer-guide/index.html#tensor_ops + static constexpr auto kDefaultConvBwdDataAlgo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + static constexpr auto kDefaultConvBwdFilterAlgo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + + private: + Status ComputeWeightGradient(Tensor* dW, const Tensor* dY, const Tensor* X) const; + Status ComputeInputGradient(Tensor* dX, const Tensor* dY, const Tensor* W) const; + Status ComputeBiasGradient(Tensor* dB, const Tensor* dY) const; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 06f09b17c0..4d3bea4c81 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -198,6 +198,8 @@ training_ops_excluded_files = [ 'math/softmax_grad.cc', 'nn/batch_norm_grad.cc', 'nn/batch_norm_grad.h', + 'nn/conv_grad.cc', + 'nn/conv_grad.h', 'optimizer/adam.cc', 'optimizer/adam.cu', 'reduction/reduction_all.cc',