From fca81cc5d5672dca039bece842f785b38bbd8f4f Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 24 Aug 2023 09:08:06 -0700 Subject: [PATCH] ConvTransposeGrad CUDA Kernel (#17201) --- cmake/onnxruntime_rocm_hipify.cmake | 4 + .../core/graph/gradient_builder.cc | 17 + .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../core/graph/training_op_defs.cc | 15 + .../test/gradient/gradient_ops_test.cc | 198 ++++++++++ .../python/orttraining_test_ortmodule_api.py | 165 +++++++- .../cuda/conv_transpose_grad_test.cc | 360 ++++++++++++++++++ .../cuda/cuda_training_kernels.cc | 6 + .../training_ops/cuda/nn/conv_grad.cc | 230 ----------- .../training_ops/cuda/nn/conv_grad.h | 38 +- .../training_ops/cuda/nn/conv_shared.cc | 275 +++++++++++++ .../training_ops/cuda/nn/conv_shared.h | 84 ++++ .../cuda/nn/conv_transpose_grad.cc | 308 +++++++++++++++ .../cuda/nn/conv_transpose_grad.h | 41 ++ 15 files changed, 1475 insertions(+), 268 deletions(-) create mode 100644 orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc create mode 100644 orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc create mode 100644 orttraining/orttraining/training_ops/cuda/nn/conv_shared.h create mode 100644 orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc create mode 100644 orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index c8592a4019..ecee52f642 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -201,6 +201,10 @@ set(training_ops_excluded_files "reduction/reduction_ops.cc" # no double type support "cuda_training_kernels.cc" "cuda_training_kernels.h" + "nn/conv_shared.cc" + "nn/conv_shared.h" + "nn/conv_transpose_grad.cc" + "nn/conv_transpose_grad.h" ) function(auto_set_source_files_hip_language) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index f8e0545574..429ce6d968 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -2070,5 +2070,22 @@ IMPLEMENT_GRADIENT_BUILDER(GetLeakyReluGradient) { {GO(0), O(0)}, {GI(0)}, SrcNodeAttributes())}; } +IMPLEMENT_GRADIENT_BUILDER(GetConvTransposeGradient) { + std::vector outputs; + for (int i = 0; i < GetSrcNodeInputSize(); i++) { + if (IsGradientRequiredForSrcNodeInput(i)) { + outputs.push_back(GI(i)); + } else { + outputs.push_back(ArgDef("", nullptr)); + } + } + + return std::vector{ + NodeDef(OpDef{"ConvTransposeGrad", kMSDomain, 1}, + {GO(0), I(0), I(1)}, + outputs, + SrcNodeAttributes())}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index ca86777d36..84880b8850 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -88,6 +88,7 @@ DECLARE_GRADIENT_BUILDER(GetLSTMGradient) DECLARE_GRADIENT_BUILDER(GetGRUGradient) DECLARE_GRADIENT_BUILDER(GetReciprocalGradient) DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient) +DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient) DECLARE_GRADIENT_BUILDER(GetExternalGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index cc9a762ff8..c84fc0d360 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -120,6 +120,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("GRUTraining", GetGRUGradient); REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient); REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient); + REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient); REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient); }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 60867accb8..eb84865fd7 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4908,6 +4908,21 @@ Return true if all elements are true and false otherwise. } } }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(ConvTransposeGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "dY", "Gradient of output Y", "T") + .Input(1, "X", "Input tensor", "T") + .Input(2, "W", "Weight tensor", "T") + .Output(0, "dX", "Gradient of X", "T", OpSchema::Optional) + .Output(1, "dW", "Gradient of W", "T", OpSchema::Optional) + .Output(2, "dB", "Gradient of B", "T", OpSchema::Optional) + .AllowUncheckedAttributes() + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors."); } } // namespace training diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 39cc6bdd11..d4e18dbfd2 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3039,6 +3039,204 @@ TEST(GradientCheckerTest, LeakyReluGrad) { UnaryOpGradientTest("LeakyRelu", kOnnxDomain, 16, nullptr, &transformer); } +#ifdef USE_CUDA +void ConvTransposeGradientCheckerTest(std::vector>* execution_providers) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"ConvTranspose"}; + + float error_tolerance = 1e-1f; + + // 1D convolution + { + TensorShape x_shape({2, 2, 5}); + TensorShape w_shape({2, 2, 3}); + TensorShape b_shape({2}); + TensorShape y_shape({2, 2, 5}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3}), MakeAttribute("pads", std::vector{1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 1D strided convolution + { + TensorShape x_shape({2, 1, 7}); + TensorShape w_shape({1, 1, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 13}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3}), MakeAttribute("pads", std::vector{1, 1}), + MakeAttribute("strides", std::vector{2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 1D pointwise convolution (with padding) + { + TensorShape x_shape({2, 1, 5}); + TensorShape w_shape({1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 3}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1}), MakeAttribute("pads", std::vector{1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 1D pointwise convolution (no padding) + { + TensorShape x_shape({2, 1, 5}); + TensorShape w_shape({1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 5}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1}), MakeAttribute("pads", std::vector{0, 0})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D convolution + { + TensorShape x_shape({1, 1, 3, 3}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({1, 1, 3, 3}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D convolution + { + TensorShape x_shape({2, 1, 5, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 5, 5}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D pointwise convolution (with padding) + { + TensorShape x_shape({1, 1, 3, 3}); + TensorShape w_shape({1, 1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({1, 1, 1, 1}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1, 1}), + MakeAttribute("pads", std::vector{1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D pointwise convolution (no padding) + { + TensorShape x_shape({1, 1, 3, 3}); + TensorShape w_shape({1, 1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({1, 1, 3, 3}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1, 1}), + MakeAttribute("pads", std::vector{0, 0, 0, 0})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D strided convolution + { + TensorShape x_shape({2, 1, 7, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 13, 9}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1}), MakeAttribute("strides", std::vector{2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D dilated convolution (no padding) + { + TensorShape x_shape({2, 1, 5, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 9, 9}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{0, 0, 0, 0}), + MakeAttribute("dilations", std::vector{2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D dilated convolution (with padding) + { + TensorShape x_shape({2, 1, 7, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 9, 7}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1}), + MakeAttribute("dilations", std::vector{2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 3D convolution + { + TensorShape x_shape({2, 1, 5, 5, 5}); + TensorShape w_shape({1, 1, 3, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 5, 5, 5}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 3D strided convolution + { + TensorShape x_shape({2, 1, 7, 5, 5}); + TensorShape w_shape({1, 1, 3, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 13, 9, 9}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1, 1, 1}), + MakeAttribute("strides", std::vector{2, 2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } +} + +TEST(GradientCheckerTest, ConvTransposeGrad) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + ConvTransposeGradientCheckerTest(&execution_providers); +} +#endif // USE_CUDA + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index b62e959556..0a398bd7b4 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6096,7 +6096,6 @@ def test_ortmodule_log_level_control(log_level, caplog): found_missing_inference_log = False for record in caplog.records: msg = record.getMessage() - print(msg) if "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing" in msg: found_missing_inference_log = True break @@ -6205,3 +6204,167 @@ def test_leakyrelu_gradient(): _test_helpers.assert_values_are_close(pt_prediction, ort_prediction) _test_helpers.assert_values_are_close(pt_loss, ort_loss) _test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad) + + +@pytest.mark.skipif( + os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm" +) +@pytest.mark.parametrize("use_fp16", [False, True]) +@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"]) +def test_conv_transpose_gradient(use_fp16, conv_algo_search): + class ChainedTransposedConv(nn.Module): + def __init__(self): + super().__init__() + + # Transposed Convolution 1D + self.conv1d_transpose = nn.ConvTranspose1d( + in_channels=4, out_channels=2, kernel_size=3, stride=2, padding=1 + ) + self.relu1 = nn.ReLU() + + # Transposed Convolution 2D + self.conv2d_transpose = nn.ConvTranspose2d( + in_channels=2, out_channels=3, kernel_size=3, stride=2, padding=1 + ) + self.relu2 = nn.ReLU() + + # Transposed Convolution 3D + self.conv3d_transpose = nn.ConvTranspose3d( + in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1 + ) + self.relu3 = nn.ReLU() + + def forward(self, x): + out1d = self.relu1(self.conv1d_transpose(x)) + out2d = self.relu2(self.conv2d_transpose(out1d.unsqueeze(2))) + out3d = self.relu3(self.conv3d_transpose(out2d.unsqueeze(2))) + return out3d.squeeze(2) + + if conv_algo_search is not None: + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search + + def run_step(model, x): + with amp.autocast(use_fp16): + loss = model(x).sum() + loss.backward() + + return ( + x.grad, + model.conv1d_transpose.weight.grad, + model.conv1d_transpose.bias.grad, + model.conv2d_transpose.weight.grad, + model.conv2d_transpose.bias.grad, + model.conv3d_transpose.weight.grad, + model.conv3d_transpose.bias.grad, + ) + + device = "cuda" + pt_model = ChainedTransposedConv().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(1, 4, 8, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_grads = run_step(pt_model, pt_x) + ort_grads = run_step(ort_model, ort_x) + + for pt_grad, ort_grad in zip(pt_grads, ort_grads): + if use_fp16: + assert torch.allclose(pt_grad, ort_grad, atol=1e-3, rtol=1e-3) + else: + assert torch.allclose(pt_grad, ort_grad) + + if conv_algo_search is not None: + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] + + +@pytest.mark.skipif( + os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm" +) +@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"]) +def test_conv_transpose_gradient_with_groups(conv_algo_search): + class TransposedConv3DWithGroups(nn.Module): + def __init__(self): + super().__init__() + # in_channels, out_channels, kernel_size, stride, padding + self.conv_transpose = nn.ConvTranspose3d( + in_channels=6, out_channels=4, kernel_size=3, stride=2, padding=1, groups=2 + ) + + def forward(self, x): + return self.conv_transpose(x) + + if conv_algo_search is not None: + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search + + def run_step(model, x): + loss = model(x).sum() + loss.backward() + + return ( + x.grad, + model.conv_transpose.weight.grad, + model.conv_transpose.bias.grad, + ) + + device = "cuda" + pt_model = TransposedConv3DWithGroups().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(1, 6, 8, 16, 16, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_grads = run_step(pt_model, pt_x) + ort_grads = run_step(ort_model, ort_x) + + for pt_grad, ort_grad in zip(pt_grads, ort_grads): + assert torch.allclose(pt_grad, ort_grad) + + if conv_algo_search is not None: + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] + + +@pytest.mark.skipif( + os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm" +) +@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"]) +def test_conv_transpose_gradient_with_strides_padding_and_dilation(conv_algo_search): + class ConvTransposeComplexModel(nn.Module): + def __init__(self): + super().__init__() + self.conv_transpose = nn.ConvTranspose3d( + 16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2), dilation=(1, 2, 1) + ) + self.param = nn.Parameter(torch.randn(20, 33, 21, 50, 97)) + + def forward(self, x): + return self.conv_transpose(x) * self.param + + if conv_algo_search is not None: + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search + + def run_step(model, x): + loss = model(x).sum() + loss.backward() + + return ( + x.grad, + model.conv_transpose.weight.grad, + model.conv_transpose.bias.grad, + ) + + device = "cuda" + pt_model = ConvTransposeComplexModel().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)).to(device) + + pt_x = torch.randn(20, 16, 10, 50, 100, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_grads = run_step(pt_model, pt_x) + ort_grads = run_step(ort_model, ort_x) + + for pt_grad, ort_grad in zip(pt_grads, ort_grads): + assert torch.allclose(pt_grad, ort_grad, atol=1e-2, rtol=1e-2) + + if conv_algo_search is not None: + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] diff --git a/orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc new file mode 100644 index 0000000000..18c5ff9437 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc @@ -0,0 +1,360 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime::contrib::test { + +using namespace onnxruntime::test; + +#if USE_CUDA +namespace { + +struct ConvTransposeGradOpAttributes { + std::vector dilations; + int64_t group; + std::vector kernel_shape; + std::vector pads; + std::vector strides; +}; + +void TestConvTransposeGradOp(const ConvTransposeGradOpAttributes& attributes, + const std::vector>& inputs, + const std::vector>& input_shapes, + const std::vector>& outputs, + const std::vector>& output_shapes, + bool is_half = false) { + OpTester test("ConvTransposeGrad", 1, kMSDomain); + test.AddAttribute("group", attributes.group); + test.AddAttribute("kernel_shape", attributes.kernel_shape); + test.AddAttribute("pads", attributes.pads); + + if (!attributes.dilations.empty()) { + test.AddAttribute("dilations", attributes.dilations); + } + + if (!attributes.strides.empty()) { + test.AddAttribute("strides", attributes.strides); + } + + if (is_half) { + std::vector dY_half(inputs[0].size()); + ConvertFloatToMLFloat16(inputs[0].data(), dY_half.data(), static_cast(inputs[0].size())); + test.AddInput("dY", input_shapes[0], dY_half); + + std::vector X_half(inputs[1].size()); + ConvertFloatToMLFloat16(inputs[1].data(), X_half.data(), static_cast(inputs[1].size())); + test.AddInput("X", input_shapes[1], X_half); + + std::vector W_half(inputs[2].size()); + ConvertFloatToMLFloat16(inputs[2].data(), W_half.data(), static_cast(inputs[2].size())); + test.AddInput("W", input_shapes[2], W_half); + + std::vector dX_half(outputs[0].size()); + ConvertFloatToMLFloat16(outputs[0].data(), dX_half.data(), static_cast(outputs[0].size())); + test.AddOutput("dX", output_shapes[0], dX_half); + + std::vector dW_half(outputs[1].size()); + ConvertFloatToMLFloat16(outputs[1].data(), dW_half.data(), static_cast(outputs[1].size())); + test.AddOutput("dW", output_shapes[1], dW_half); + + if (outputs.size() >= 3) { + std::vector dB_half(outputs[2].size()); + ConvertFloatToMLFloat16(outputs[2].data(), dB_half.data(), static_cast(outputs[2].size())); + test.AddOutput("dB", output_shapes[2], dB_half); + } + } else { + test.AddInput("dY", input_shapes[0], inputs[0]); + test.AddInput("X", input_shapes[1], inputs[1]); + test.AddInput("W", input_shapes[2], inputs[2]); + + test.AddOutput("dX", output_shapes[0], outputs[0]); + test.AddOutput("dW", output_shapes[1], outputs[1]); + + if (outputs.size() >= 3) { + test.AddOutput("dB", output_shapes[2], outputs[2]); + } + } + + test.Run(); +} + +} // namespace + +TEST(ConvTransposeGradTest, ConvTranspose1DDefaultAttributes) { + ConvTransposeGradOpAttributes attrs = { + std::vector{1}, // dilations + 1, // group + std::vector{2}, // kernel_shape + std::vector{0, 0}, // pads + std::vector{1}, // strides + }; + + std::vector dY(12, 1.0f); + std::vector dY_shape = {1, 2, 6}; + std::vector X = {0.1868f, -0.1679f, 1.2677f, 2.1288f, -0.0331f, + 1.0454f, 0.7722f, 0.2963f, -0.8684f, -0.0547f}; + std::vector X_shape = {1, 2, 5}; + std::vector W = {0.0847f, -0.0066f, + 0.1212f, 0.2317f, + -0.4975f, 0.2762f, + -0.2644f, 0.3210f}; + std::vector W_shape = {2, 2, 2}; + std::vector dX = {0.4309f, 0.4309f, 0.4309f, 0.4309f, 0.4309f, + -0.1647f, -0.1647f, -0.1647f, -0.1647f, -0.1647f}; + std::vector dX_shape = X_shape; + std::vector dW = {3.3823f, 3.3823f, + 3.3823f, 3.3823f, + 1.1908f, 1.1908f, + 1.1908f, 1.1908f}; + std::vector dW_shape = W_shape; + std::vector dB = {6.f, 6.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose1DStrideAndPadding) { + ConvTransposeGradOpAttributes attrs = { + std::vector{1}, // dilations + 1, // group + std::vector{2}, // kernel_shape + std::vector{2, 2}, // pads + std::vector{2}, // strides + }; + + std::vector dY(12, 1.0f); + std::vector dY_shape = {1, 2, 6}; + std::vector X = {-0.0254f, -1.4303f, -0.1568f, 1.2318f, -0.8365f, + 2.0836f, -1.0181f, -0.7539f, 0.4484f, -0.5799f}; + std::vector X_shape = {1, 2, 5}; + std::vector W = {-0.1438f, 0.2386f, + -0.3085f, 0.1149f, + -0.1653f, -0.0707f, + -0.1479f, -0.0918f}; + std::vector W_shape = {2, 2, 2}; + std::vector dX = {0.0000f, -0.0988f, -0.0988f, -0.0988f, 0.0000f, + 0.0000f, -0.4757f, -0.4757f, -0.4757f, 0.0000f}; + std::vector dX_shape = X_shape; + std::vector dW = {-0.3553f, -0.3553f, + -0.3553f, -0.3553f, + -1.3236f, -1.3236f, + -1.3236f, -1.3236f}; + std::vector dW_shape = W_shape; + std::vector dB = {6.f, 6.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose1D) { + ConvTransposeGradOpAttributes attrs = { + std::vector{2}, // dilations + 2, // group + std::vector{3}, // kernel_shape + std::vector{2, 2}, // pads + std::vector{2}, // strides + }; + + std::vector dY(38, 1.0f); + std::vector dY_shape = {1, 2, 19}; + std::vector X = {0.2816f, 1.4660f, 0.1002f, -0.2460f, -0.1027f, 0.1228f, -0.8516f, -1.0246f, -0.6576f, -1.0280f, + 0.1093f, 0.1447f, 1.1279f, 0.1085f, -0.3438f, -0.6224f, -0.0902f, 2.2791f, -2.1910f, 1.9736f}; + std::vector X_shape = {1, 2, 10}; + std::vector W = {-0.1050f, -0.0622f, -0.3632f, + -0.3861f, -0.0134f, -0.0277f}; + std::vector W_shape = {2, 1, 3}; + std::vector dX = {-0.4254f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.1672f, + -0.0411f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.3995f}; + std::vector dX_shape = X_shape; + std::vector dW = {-2.2215f, -1.9400f, -0.9120f, + 2.3863f, 2.4956f, 0.5220f}; + std::vector dW_shape = W_shape; + std::vector dB = {19.f, 19.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose2DDefaultAttributes) { + ConvTransposeGradOpAttributes attrs = { + std::vector{1, 1}, // dilations + 1, // group + std::vector{3, 3}, // kernel_shape + std::vector{0, 0, 0, 0}, // pads + std::vector{1, 1}, // strides + }; + + std::vector dY(98, 1.0f); + std::vector dY_shape = {1, 2, 7, 7}; + std::vector X = {1.1371f, -0.1498f, -1.7541f, -0.7585f, 1.6009f, -0.7496f, 0.1535f, -0.2533f, -1.0811f, 0.9760f, + -0.2528f, 0.1820f, -1.7450f, 0.1632f, -0.3469f, 1.1150f, -2.6888f, -0.1632f, -0.3269f, 0.6904f, + 1.3036f, 0.7883f, 0.4459f, 0.1223f, 0.1576f, -0.8187f, 0.2281f, 1.5320f, 1.2643f, -0.5163f, + 1.0677f, -0.2141f, 1.2992f, -2.1865f, -0.6346f, 0.8938f, 0.8346f, -2.7397f, 0.9223f, 0.8166f, + 1.1736f, -1.3644f, 0.0316f, -1.2904f, 0.7062f, 0.2470f, 0.4559f, 0.8493f, 1.0519f, 0.9915f}; + std::vector X_shape = {1, 2, 5, 5}; + std::vector W = {0.0761f, 0.0270f, -0.1677f, 0.1803f, -0.0824f, -0.0285f, + 0.2098f, -0.0569f, -0.1514f, 0.0338f, -0.1962f, -0.2169f, + 0.0432f, -0.1977f, -0.0814f, -0.1866f, -0.1574f, -0.0198f, + 0.0097f, 0.0019f, -0.1204f, 0.2018f, -0.1750f, -0.0549f, + -0.0687f, -0.1269f, 0.1913f, 0.1331f, -0.0632f, 0.0821f, + 0.0127f, 0.1761f, -0.0883f, -0.1370f, 0.1472f, 0.0690f}; + std::vector W_shape = {2, 2, 3, 3}; + std::vector dX = {-0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, + -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, + -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, + 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, + 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f}; + std::vector dX_shape = X_shape; + std::vector dW = {-1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, + -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, + -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, + 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, + 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, + 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f}; + std::vector dW_shape = W_shape; + std::vector dB = {49.f, 49.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose2D) { + ConvTransposeGradOpAttributes attrs = { + std::vector{2, 2}, // dilations + 2, // group + std::vector{3, 3}, // kernel_shape + std::vector{2, 2, 2, 2}, // pads + std::vector{2, 2}, // strides + }; + + std::vector dY(162U, 1.0f); + std::vector dY_shape = {1, 2, 9, 9}; + std::vector X = {-1.0158f, 0.1709f, -0.1660f, 0.3881f, 0.4017f, 1.5497f, 1.1205f, 0.2553f, -0.4359f, -0.0467f, + 1.1374f, -0.0713f, 0.2248f, 0.8915f, -0.7239f, 0.1679f, -1.5604f, -0.8521f, 0.8966f, 3.3743f, + -0.5516f, 0.2516f, -0.4091f, -0.9868f, 0.3008f, 1.1066f, -0.7039f, -1.5273f, -0.3666f, 0.9392f, + 0.1264f, -1.6604f, -1.4810f, 0.6654f, -0.2007f, -1.0660f, -0.5420f, -0.7030f, 0.0411f, 2.1082f, + -0.7995f, 0.2422f, 1.2848f, -0.1747f, 1.7935f, -0.1123f, -0.6668f, -2.2383f, 1.5419f, -2.7614f}; + std::vector X_shape = {1, 2, 5, 5}; + std::vector W = {-0.2057f, -0.0411f, 0.0277f, 0.2221f, 0.1901f, 0.1435f, + -0.2249f, 0.3299f, -0.2203f, -0.1013f, -0.3326f, 0.1005f, + -0.0536f, 0.3067f, 0.3297f, 0.2728f, 0.1649f, -0.2548f}; + std::vector W_shape = {2, 1, 3, 3}; + std::vector dX = {0.4431f, 0.4403f, 0.4403f, 0.4403f, 0.5171f, 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, + 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, + 0.3202f, 0.3366f, 0.3366f, 0.3366f, 0.1654f, 0.5465f, 0.7658f, 0.7658f, 0.7658f, 0.6908f, + 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, + 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, 0.4043f, 0.2494f, 0.2494f, 0.2494f, -0.1808f}; + std::vector dX_shape = X_shape; + std::vector dW = {2.2293f, 4.5327f, 1.6281f, 3.0240f, 4.3115f, 1.0052f, + 3.8675f, 5.7067f, 2.7011f, -2.7512f, -4.6026f, -5.5423f, + -4.4098f, -5.1546f, -7.0335f, -0.2852f, -0.9177f, -5.5580f}; + std::vector dW_shape = W_shape; + std::vector dB = {81.f, 81.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose3D) { + ConvTransposeGradOpAttributes attrs = { + std::vector{2, 2, 2}, // dilations + 2, // group + std::vector{2, 2, 2}, // kernel_shape + std::vector{2, 2, 2, 2, 2, 2}, // pads + std::vector{2, 2, 2}, // strides + }; + + std::vector dY(250U, 1.0f); + std::vector dY_shape = {1, 2, 5, 5, 5}; + std::vector X = {-0.2396f, 0.4280f, -1.3505f, -0.4366f, -1.3296f, 0.3531f, 0.0645f, -1.5480f, + -1.7464f, -0.9160f, 1.5065f, -0.0788f, 0.0487f, 2.4641f, 0.3855f, 2.0499f, + 0.7068f, -0.8076f, -0.4442f, 0.1003f, -0.5056f, -0.1430f, -0.3744f, -0.2637f, + -1.1012f, 1.0213f, 0.0503f, 0.0147f, -0.3664f, 0.8834f, -1.1478f, -0.8221f, + -0.5649f, -0.4224f, -0.6779f, -0.9363f, 1.1972f, 0.2094f, 0.5676f, -0.2718f, + -0.1678f, -0.4178f, -0.4672f, 0.2777f, -0.7953f, -0.5603f, -2.8694f, 1.5743f, + -0.5057f, -0.2529f, 0.5894f, -0.3980f, -0.6719f, -0.3425f, 0.0821f, 0.8672f, + 0.7218f, 1.5519f, 1.6513f, -1.1956f, 0.8471f, 0.4295f, -1.3917f, -1.2202f, + 0.1054f, -2.2191f, -0.9546f, 1.1750f, -2.3637f, 1.6297f, -0.5796f, 0.3850f, + 0.9287f, -0.3492f, -0.7284f, 0.2987f, -0.7534f, 0.7747f, -1.3198f, -0.3633f, + 1.8635f, -0.3187f, 0.9032f, -0.6083f, -0.4236f, -0.1929f, -1.1715f, -0.5591f, + -1.8290f, -1.1503f, 0.1430f, 0.6048f, -0.3148f, 1.0638f, -0.2946f, -0.4990f, + -1.4443f, -0.7757f, -1.5374f, -0.4567f, -0.2998f, 0.0521f, 1.6293f, -0.6720f, + -0.0102f, -0.6598f, 0.5005f, 0.4203f, 1.3911f, 1.5988f, 0.3991f, 1.4931f, + 0.9741f, 0.3557f, 0.1088f, -1.1806f, 1.1115f, -1.3283f, 1.7235f, 0.4177f, + 0.7992f, -1.7248f, -0.5339f, -0.3153f, 0.1379f, 0.7493f, 0.3028f, -0.9473f}; + std::vector X_shape = {1, 2, 4, 4, 4}; + std::vector W = {-0.1093f, -0.0511f, 0.1132f, 0.3369f, -0.3531f, -0.1766f, 0.0628f, 0.2118f, + 0.3068f, 0.3217f, -0.2903f, -0.1633f, -0.3261f, -0.0990f, 0.2497f, -0.1553f}; + std::vector W_shape = {2, 1, 2, 2, 2}; + std::vector dX = {0.2118f, 0.2746f, 0.2746f, 0.0628f, 0.0352f, -0.2550f, -0.2550f, -0.2902f, + 0.0352f, -0.2550f, -0.2550f, -0.2902f, -0.1766f, -0.5297f, -0.5297f, -0.3531f, + 0.5487f, 0.7247f, 0.7247f, 0.1760f, 0.3210f, 0.0346f, 0.0346f, -0.2864f, + 0.3210f, 0.0346f, 0.0346f, -0.2864f, -0.2277f, -0.6901f, -0.6901f, -0.4624f, + 0.5487f, 0.7247f, 0.7247f, 0.1760f, 0.3210f, 0.0346f, 0.0346f, -0.2864f, + 0.3210f, 0.0346f, 0.0346f, -0.2864f, -0.2277f, -0.6901f, -0.6901f, -0.4624f, + 0.3369f, 0.4501f, 0.4501f, 0.1132f, 0.2858f, 0.2897f, 0.2897f, 0.0038f, + 0.2858f, 0.2897f, 0.2897f, 0.0038f, -0.0511f, -0.1604f, -0.1604f, -0.1093f, + -0.1553f, 0.0944f, 0.0944f, 0.2497f, -0.2542f, -0.3307f, -0.3307f, -0.0765f, + -0.2542f, -0.3307f, -0.3307f, -0.0765f, -0.0990f, -0.4251f, -0.4251f, -0.3261f, + -0.3185f, -0.3592f, -0.3592f, -0.0407f, -0.0958f, -0.1557f, -0.1557f, -0.0600f, + -0.0958f, -0.1557f, -0.1557f, -0.0600f, 0.2227f, 0.2035f, 0.2035f, -0.0193f, + -0.3185f, -0.3592f, -0.3592f, -0.0407f, -0.0958f, -0.1557f, -0.1557f, -0.0600f, + -0.0958f, -0.1557f, -0.1557f, -0.0600f, 0.2227f, 0.2035f, 0.2035f, -0.0193f, + -0.1633f, -0.4536f, -0.4536f, -0.2903f, 0.1584f, 0.1749f, 0.1749f, 0.0165f, + 0.1584f, 0.1749f, 0.1749f, 0.0165f, 0.3217f, 0.6285f, 0.6285f, 0.3068f}; + std::vector dX_shape = X_shape; + std::vector dW = {-2.3068f, -2.1096f, -0.4322f, 0.4820f, 1.5420f, -4.1569f, -4.9628f, -5.5716f, + 1.0492f, 1.6683f, -6.3262f, -3.2359f, 2.4532f, -2.3299f, -5.1917f, -9.2525f}; + std::vector dW_shape = W_shape; + std::vector dB = {125.f, 125.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} +#endif // USE_CUDA + +} // namespace onnxruntime::contrib::test diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 6aac9ad7ec..8ec884382c 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -85,6 +85,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvTransposeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvTransposeGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DropoutGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropoutGrad); @@ -346,6 +349,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index f69da000be..f6c58445c0 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -3,13 +3,6 @@ #include "orttraining/training_ops/cuda/nn/conv_grad.h" -#include "core/providers/common.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "core/platform/ort_mutex.h" - -// The AlgoPerfCache and AlgoSearch here for Conv/ConvGrad is referenced on PyTorch's implementation -// from aten/src/ATen/native/cudnn/Conv_v7.cpp. - namespace onnxruntime { namespace cuda { @@ -22,229 +15,6 @@ REGISTER_GRADIENT_KERNEL_TYPED(float) REGISTER_GRADIENT_KERNEL_TYPED(double) REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16) -using T_BwdDataPerf = cudnnConvolutionBwdDataAlgoPerf_t; -using T_BwdDataAlgo = cudnnConvolutionBwdDataAlgo_t; -using T_BwdFilterPerf = cudnnConvolutionBwdFilterAlgoPerf_t; -using T_BwdFilterAlgo = cudnnConvolutionBwdFilterAlgo_t; - -cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) { - return cudnnGetConvolutionBackwardDataWorkspaceSize(args.handle, args.w_desc, args.y_tensor, args.conv_desc, - args.x_tensor, algo, workspace_size); -} - -cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) { - return cudnnGetConvolutionBackwardFilterWorkspaceSize(args.handle, args.x_tensor, args.y_tensor, args.conv_desc, - args.w_desc, algo, workspace_size); -} - -template -size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) { - // Calling cudaMemGetInfo is not ideal, but our cuda allocator doesn't have a way to get this info. - size_t free, total; - CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); - // Assuming 10% of fragmentation. - free = static_cast(static_cast(free) * 0.9); - size_t max_workspace_size = 0; - for (int i = 0; i < n_algo; i++) { - cudnnStatus_t status; - size_t workspace_size; - status = GetWorkspaceSize(args, algo[i], &workspace_size); - if (CUDNN_STATUS_SUCCESS != status || workspace_size == 0 || workspace_size < max_workspace_size || - workspace_size > free) - continue; - max_workspace_size = workspace_size; - } - - return max_workspace_size; -} - -template -std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { - std::vector result; - result.reserve(n_algo); - for (int i = 0; i < n_algo; i++) { - T_Perf perf = perf_results[i]; - if (perf.status == CUDNN_STATUS_SUCCESS) { - result.emplace_back(perf); - } - } - ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in CuDNN"); - // TODO: This is a cuDNN bug that gave wrong results in certain strided convolution gradient setups - // when cuDNN version < 7.5. Need to add handling for such special case. - return result; -} - -struct ConvParamsHash { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); - size_t operator()(const ConvParams& conv_params) const { - auto ptr = reinterpret_cast(&conv_params); - uint32_t value = 0x811C9DC5; - for (int i = 0; i < static_cast(sizeof(ConvParams)); ++i) { - value ^= ptr[i]; - value *= 0x01000193; - } - return static_cast(value); - } -}; - -struct ConvParamsEqual { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); - bool operator()(const ConvParams& a, const ConvParams& b) const { - auto ptr1 = reinterpret_cast(&a); - auto ptr2 = reinterpret_cast(&b); - return memcmp(ptr1, ptr2, sizeof(ConvParams)) == 0; - } -}; - -template -struct AlgoPerfCache { - mutable OrtMutex mutex; - std::unordered_map map; - - bool Find(const ConvParams& params, T_Perf* result) { - std::lock_guard guard(mutex); - auto it = map.find(params); - if (it == map.end()) { - return false; - } - *result = it->second; - return true; - } - - void Insert(const ConvParams& params, const T_Perf& algo_perf) { - std::lock_guard guard(mutex); - map[params] = algo_perf; - } -}; - -// TODO: Currently we use global AlgoPerfCache for ConvGrad only. Conv's perf cache is till per node. -// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc. -AlgoPerfCache bwd_data_algos; -AlgoPerfCache bwd_filter_algos; - -template -struct AlgoSearch {}; - -template <> -struct AlgoSearch { - static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - static AlgoPerfCache& Cache() { return bwd_data_algos; } - static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, - std::vector& perf_results) { - static const T_BwdDataAlgo algos[] = { - CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, - CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, - CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; - static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); - int perf_count; - std::unique_ptr candidates = std::make_unique(num_algos); - if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { - CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(args.handle, args.w_desc, args.y_tensor, - args.conv_desc, args.x_tensor, num_algos, - &perf_count, candidates.get())); - } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { - size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. - IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( - args.handle, args.w_desc, args.w_data, args.y_tensor, args.dy_data, args.conv_desc, args.x_tensor, - args.dx_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); - } else { - ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); - } - perf_results = GetValidAlgorithms(candidates.get(), perf_count); - return Status::OK(); - } -}; - -template <> -struct AlgoSearch { - static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; - static AlgoPerfCache& Cache() { return bwd_filter_algos; } - static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, - std::vector& perf_results) { - static const T_BwdFilterAlgo algos[] = { - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, - }; - - // NOTE: - 1 because ALGO_WINOGRAD is not implemented. - static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); - std::unique_ptr candidates = std::make_unique(num_algos); - int perf_count; - if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { - CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(args.handle, args.x_tensor, args.y_tensor, - args.conv_desc, args.w_desc, num_algos, - &perf_count, candidates.get())); - } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { - size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. - IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardFilterAlgorithmEx( - args.handle, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, args.w_desc, - args.dw_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); - } else { - ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); - } - perf_results = GetValidAlgorithms(candidates.get(), perf_count); - return Status::OK(); - } -}; - -template -class AlgoIterator { - public: - AlgoIterator(const ConvArgs& args) : args_(args) {} - - static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { - perf_results.resize(1); - perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; - if (args.params.data_type == CUDNN_DATA_HALF) { - perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; - } else { - perf_results[0].mathType = CUDNN_DEFAULT_MATH; - } - CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].algo, &(perf_results[0].memory))); - return Status::OK(); - } - - Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, std::function f) { - auto& cache = AlgoSearch::Cache(); - - if (T_Perf algo_perf; cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) { - return Status::OK(); - } - - std::vector perf_results; - ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault - ? OnlyDefaultAlgorithm(args_, perf_results) - : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); - for (auto& algo_perf : perf_results) { - if (f(algo_perf) == Status::OK()) { - cache.Insert(args_.params, algo_perf); - return Status::OK(); - } - } - ORT_ENFORCE(false, "Unable to find a valid cuDNN algorithm to run convolution."); - return Status::OK(); - } - - private: - const ConvArgs& args_; -}; - template Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX, Tensor* dW, cudnnHandle_t cudnn_handle) const { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h index 5d0c123fd9..9bbcd5b30d 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h @@ -3,47 +3,11 @@ #pragma once -#include "core/providers/cuda/cudnn_common.h" -#include "core/providers/cpu/nn/conv_attributes.h" -#include "core/providers/cuda/nn/conv.h" +#include "orttraining/training_ops/cuda/nn/conv_shared.h" namespace onnxruntime { namespace cuda { -// cuDNN only takes 4D or 5D x tensor. -static constexpr int MAX_DIM = 3; - -struct ConvParams { - int8_t device_id; - cudnnDataType_t data_type; - int input_size[2 + MAX_DIM]; - uint8_t input_dim; - int weight_size[2 + MAX_DIM]; - int padding[MAX_DIM * 2]; - int stride[MAX_DIM]; - int dilation[MAX_DIM]; - int64_t groups; - int algo_mode; -}; - -struct ConvArgs { - // Update needed if x or w's dims changed. - TensorShapeVector last_x_dims; - TensorShapeVector last_w_dims; - - cudnnHandle_t handle; - ConvParams params; - CudnnTensor x_tensor, y_tensor, b_tensor; - CudnnFilterDescriptor w_desc; - CudnnConvolutionDescriptor conv_desc; - const void* x_data; - const void* w_data; - const void* dy_data; - void* dx_data; - void* dw_data; - void* db_data; -}; - template class ConvGrad final : public CudaKernel { public: diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc new file mode 100644 index 0000000000..5dc16c68f6 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/nn/conv_shared.h" + +#include "core/platform/ort_mutex.h" +#include "core/providers/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime::cuda { + +namespace { + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionBackwardDataWorkspaceSize(args.handle, args.w_desc, args.y_tensor, args.conv_desc, + args.x_tensor, algo, workspace_size); +} + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionBackwardFilterWorkspaceSize(args.handle, args.x_tensor, args.y_tensor, args.conv_desc, + args.w_desc, algo, workspace_size); +} + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_FwdAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionForwardWorkspaceSize(args.handle, args.x_tensor, args.w_desc, args.conv_desc, + args.y_tensor, algo, workspace_size); +} + +template +size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) { + // Calling cudaMemGetInfo is not ideal, but our cuda allocator doesn't have a way to get this info. + size_t free, total; + CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); + // Assuming 10% of fragmentation. + free = static_cast(static_cast(free) * 0.9); + size_t max_workspace_size = 0; + for (int i = 0; i < n_algo; i++) { + cudnnStatus_t status; + size_t workspace_size; + status = GetWorkspaceSize(args, algo[i], &workspace_size); + if (CUDNN_STATUS_SUCCESS != status || workspace_size == 0 || workspace_size < max_workspace_size || + workspace_size > free) + continue; + max_workspace_size = workspace_size; + } + + return max_workspace_size; +} + +template +std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { + std::vector result; + result.reserve(n_algo); + for (int i = 0; i < n_algo; i++) { + T_Perf perf = perf_results[i]; + if (perf.status == CUDNN_STATUS_SUCCESS) { + result.emplace_back(perf); + } + } + ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in CuDNN"); + // TODO: This is a cuDNN bug that gave wrong results in certain strided convolution gradient setups + // when cuDNN version < 7.5. Need to add handling for such special case. + return result; +} + +template +struct AlgoPerfCache { + mutable OrtMutex mutex; + std::unordered_map map; + + bool Find(const ConvParams& params, T_Perf* result) { + std::lock_guard guard(mutex); + auto it = map.find(params); + if (it == map.end()) { + return false; + } + *result = it->second; + return true; + } + + void Insert(const ConvParams& params, const T_Perf& algo_perf) { + std::lock_guard guard(mutex); + map[params] = algo_perf; + } +}; + +// TODO: Currently we use global AlgoPerfCache for ConvGrad and ConvTransposeGrad only. +// Conv's perf cache is still per node. +// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc. +AlgoPerfCache bwd_data_algos; +AlgoPerfCache bwd_filter_algos; +AlgoPerfCache fwd_algos; + +template +struct AlgoSearch {}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + static AlgoPerfCache& Cache() { return bwd_data_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::vector& perf_results) { + static const T_BwdDataAlgo algos[] = { + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; + static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); + int perf_count; + std::unique_ptr candidates = std::make_unique(num_algos); + if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(args.handle, args.w_desc, args.y_tensor, + args.conv_desc, args.x_tensor, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( + args.handle, args.w_desc, args.w_data, args.y_tensor, args.dy_data, args.conv_desc, args.x_tensor, + args.dx_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + static AlgoPerfCache& Cache() { return bwd_filter_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::vector& perf_results) { + static const T_BwdFilterAlgo algos[] = { + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, + }; + + // NOTE: - 1 because ALGO_WINOGRAD is not implemented. + static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); + int perf_count; + if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(args.handle, args.x_tensor, args.y_tensor, + args.conv_desc, args.w_desc, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardFilterAlgorithmEx( + args.handle, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, args.w_desc, + args.dw_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + static AlgoPerfCache& Cache() { return fwd_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::vector& perf_results) { + static const T_FwdAlgo algos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + }; + + static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); + int perf_count; + if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(args.handle, args.x_tensor, args.w_desc, + args.conv_desc, args.y_tensor, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = max_workspace_size == 0 + ? nullptr + : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( + args.handle, args.x_tensor, args.x_data, args.w_desc, args.w_data, args.conv_desc, args.y_tensor, + args.y_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +} // namespace + +size_t ConvParamsHash::operator()(const ConvParams& conv_params) const { + auto ptr = reinterpret_cast(&conv_params); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < static_cast(sizeof(ConvParams)); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return static_cast(value); +} + +bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(ConvParams)) == 0; +} + +template +Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { + perf_results.resize(1); + perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; + if (args.params.data_type == CUDNN_DATA_HALF) { + perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; + } else { + perf_results[0].mathType = CUDNN_DEFAULT_MATH; + } + CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].algo, &(perf_results[0].memory))); + return Status::OK(); +} + +template +Status AlgoIterator::TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::function f) { + auto& cache = AlgoSearch::Cache(); + + if (T_Perf algo_perf; cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) { + return Status::OK(); + } + + std::vector perf_results; + ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault + ? OnlyDefaultAlgorithm(args_, perf_results) + : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); + for (auto& algo_perf : perf_results) { + if (f(algo_perf) == Status::OK()) { + cache.Insert(args_.params, algo_perf); + return Status::OK(); + } + } + ORT_ENFORCE(false, "Unable to find a valid cuDNN algorithm to run convolution."); + return Status::OK(); +} + +template class AlgoIterator; +template class AlgoIterator; +template class AlgoIterator; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h new file mode 100644 index 0000000000..a2d4bf3bdc --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/nn/conv.h" + +// The AlgoPerfCache and AlgoSearch here for Conv/ConvGrad/ConvTransposeGrad is adapted from PyTorch's implementation +// in aten/src/ATen/native/cudnn/Conv_v7.cpp. + +namespace onnxruntime::cuda { + +using T_BwdDataPerf = cudnnConvolutionBwdDataAlgoPerf_t; +using T_BwdDataAlgo = cudnnConvolutionBwdDataAlgo_t; +using T_BwdFilterPerf = cudnnConvolutionBwdFilterAlgoPerf_t; +using T_BwdFilterAlgo = cudnnConvolutionBwdFilterAlgo_t; +using T_FwdAlgo = cudnnConvolutionFwdAlgo_t; +using T_FwdPerf = cudnnConvolutionFwdAlgoPerf_t; + +// cuDNN only takes 4D or 5D x tensor. +static constexpr int MAX_DIM = 3; + +struct ConvParams { + int8_t device_id; + cudnnDataType_t data_type; + int input_size[2 + MAX_DIM]; + uint8_t input_dim; + int weight_size[2 + MAX_DIM]; + int padding[MAX_DIM * 2]; + int stride[MAX_DIM]; + int dilation[MAX_DIM]; + int64_t groups; + int algo_mode; +}; + +struct ConvArgs { + // Update needed if x or w's dims changed. + TensorShapeVector last_x_dims; // Input to the convolution + TensorShapeVector last_w_dims; // Weights of the convolution + + cudnnHandle_t handle; + ConvParams params; + CudnnTensor x_tensor, y_tensor, b_tensor; + CudnnFilterDescriptor w_desc; + CudnnConvolutionDescriptor conv_desc; + const void* x_data; + const void* w_data; + const void* dy_data; + void* y_data; + void* dx_data; + void* dw_data; + void* db_data; +}; + +struct ConvParamsHash { + // ConvParams must be a POD because we read out its memory constant as char* when hashing. + static_assert(std::is_pod::value, "ConvParams is not POD"); + + size_t operator()(const ConvParams& conv_params) const; +}; + +struct ConvParamsEqual { + // ConvParams must be a POD because we read out its memory constant as char* when hashing. + static_assert(std::is_pod::value, "ConvParams is not POD"); + + bool operator()(const ConvParams& a, const ConvParams& b) const; +}; + +template +class AlgoIterator { + public: + AlgoIterator(const ConvArgs& args) : args_(args) {} + + Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::function f); + + static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results); + + private: + const ConvArgs& args_; +}; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc new file mode 100644 index 0000000000..5f7206fc12 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/nn/conv_transpose_grad.h" + +namespace onnxruntime::cuda { + +#define REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTransposeGrad, kMSDomain, 1, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ConvTransposeGrad); + +REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(float) +REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(double) +REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(MLFloat16) + +template +Status ConvTransposeGrad::ComputeInternal(OpKernelContext* context) const { + const Tensor* dY = context->Input(0); + const Tensor* X = context->Input(1); + const Tensor* W = context->Input(2); + Tensor* dX = context->Output(0, X->Shape()); + Tensor* dW = context->Output(1, W->Shape()); + Tensor* dB = context->Output(2, {W->Shape()[1] * conv_attrs_.group}); + + if (dX) { + ORT_RETURN_IF_ERROR(PrepareConvForwardArgs(*dY, *W, *dX, GetCudnnHandle(context), args_dx_)); + ORT_RETURN_IF_ERROR(ComputeInputGradient(context->GetComputeStream(), args_dx_)); + } + + if (dW || dB) { + ORT_RETURN_IF_ERROR(PrepareConvBackwardFilterArgs(*dY, *W, *X, dW, dB, GetCudnnHandle(context), args_dw_)); + if (dW) ORT_RETURN_IF_ERROR(ComputeWeightGradient(context->GetComputeStream(), args_dw_)); + if (dB) ORT_RETURN_IF_ERROR(ComputeBiasGradient(args_dw_)); + } + + return Status::OK(); +} + +template +Status ConvTransposeGrad::ComputeInputGradient(onnxruntime::Stream* stream, const ConvArgs& args) const { + return AlgoIterator(args).TryAll( + static_cast(Info().GetExecutionProvider()), + Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + [&](const T_FwdPerf& algo_perf) -> Status { + const auto one = Consts::One; + const auto zero = Consts::Zero; + IAllocatorUniquePtr workspace = GetScratchBuffer(algo_perf.memory, stream); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args.conv_desc, algo_perf.mathType)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward( + args.handle, &one, args.x_tensor, args.x_data, args.w_desc, args.w_data, args.conv_desc, + algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.y_tensor, args.y_data)); + return Status::OK(); + }); + return Status::OK(); +} + +template +Status ConvTransposeGrad::ComputeWeightGradient(onnxruntime::Stream* stream, const ConvArgs& args) const { + return AlgoIterator(args).TryAll( + static_cast(Info().GetExecutionProvider()), + Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + [&](const T_BwdFilterPerf& algo_perf) -> Status { + const auto one = Consts::One; + const auto zero = Consts::Zero; + IAllocatorUniquePtr workspace = GetScratchBuffer(algo_perf.memory, stream); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args.conv_desc, algo_perf.mathType)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardFilter( + args.handle, &one, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, + algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.w_desc, args.dw_data)); + return Status::OK(); + }); + return Status::OK(); +} + +template +Status ConvTransposeGrad::ComputeBiasGradient(const ConvArgs& args) const { + const auto one = Consts::One; + const auto zero = Consts::Zero; + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardBias(args.handle, &one, args.x_tensor, args.x_data, &zero, + args.b_tensor, args.db_data)); + return Status::OK(); +} + +template +Status ConvTransposeGrad::PrepareConvForwardArgs(const Tensor& X, const Tensor& W, + Tensor& Y, cudnnHandle_t cudnn_handle, + ConvArgs& args) const { + const TensorShape& x_shape = X.Shape(); + auto x_dims = x_shape.AsShapeVector(); + args.x_data = reinterpret_cast(X.template Data()); + + const TensorShape& w_shape = W.Shape(); + auto w_dims = w_shape.AsShapeVector(); + args.w_data = reinterpret_cast(W.template Data()); + + const TensorShape& y_shape = Y.Shape(); + auto y_dims = y_shape.AsShapeVector(); + args.y_data = reinterpret_cast(Y.template MutableData()); + + args.dy_data = nullptr; + args.db_data = nullptr; + args.dx_data = nullptr; + args.dw_data = nullptr; + + bool x_dims_changed = (args.last_x_dims != x_dims); + bool w_dims_changed = (args.last_w_dims != w_dims); + if (x_dims_changed || w_dims_changed) { + if (x_dims_changed) args.last_x_dims = x_dims; + if (w_dims_changed) args.last_w_dims = w_dims; + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&X, &W)); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); + auto rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(rank * 2, 0); + } + + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(rank, 1); + } + + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(rank, 1); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + if (rank < 2) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims.insert(x_dims.begin() + 2, 1); + y_dims.insert(y_dims.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims.push_back(1); + y_dims.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + memset(&args.params, 0, sizeof(ConvParams)); + args.params.device_id = static_cast(cuda_ep->GetDeviceId()); + args.params.data_type = CudnnTensor::GetDataType(); + args.params.input_dim = static_cast(x_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + args.params.input_size[i] = static_cast(x_dims[i]); + args.params.weight_size[i] = static_cast(w_dims[i]); + } + for (size_t i = 0; i < rank; i++) { + args.params.padding[i] = static_cast(pads[i]); + args.params.padding[i + rank] = static_cast(pads[i + rank]); + args.params.stride[i] = static_cast(strides[i]); + args.params.dilation[i] = static_cast(dilations[i]); + } + args.params.groups = conv_attrs_.group; + int algo_mode = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(algo_mode > -1 && algo_mode < 3, + "Algo mode should be EXHAUSTIVE (0), HEURISTIC (1) or DEFAULT (2), but got ", algo_mode); + args.params.algo_mode = algo_mode; + + args.handle = cudnn_handle; + ORT_RETURN_IF_ERROR(args.w_desc.Set(w_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.x_tensor.Set(x_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, + args.params.data_type)); + } + + return Status::OK(); +} + +template +Status ConvTransposeGrad::PrepareConvBackwardFilterArgs(const Tensor& X, const Tensor& W, const Tensor& dY, + Tensor* dW, Tensor* dB, cudnnHandle_t cudnn_handle, + ConvArgs& args) const { + const TensorShape& x_shape = X.Shape(); + auto x_dims = x_shape.AsShapeVector(); + args.x_data = reinterpret_cast(X.template Data()); + + const TensorShape& y_shape = dY.Shape(); + auto y_dims = y_shape.AsShapeVector(); + args.dy_data = reinterpret_cast(dY.template Data()); + + const TensorShape& w_shape = W.Shape(); + auto w_dims = w_shape.AsShapeVector(); + + args.y_data = nullptr; + args.dw_data = dW ? reinterpret_cast(dW->template MutableData()) : nullptr; + args.db_data = dB ? reinterpret_cast(dB->template MutableData()) : nullptr; + args.dx_data = nullptr; + args.w_data = nullptr; + + bool x_dims_changed = (args.last_x_dims != x_dims); + bool w_dims_changed = (args.last_w_dims != w_dims); + if (x_dims_changed || w_dims_changed) { + if (x_dims_changed) args.last_x_dims = x_dims; + if (w_dims_changed) args.last_w_dims = w_dims; + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&X, &W)); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); + auto rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(rank * 2, 0); + } + + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(rank, 1); + } + + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(rank, 1); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + if (rank < 2) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims.insert(x_dims.begin() + 2, 1); + y_dims.insert(y_dims.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims.push_back(1); + y_dims.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + memset(&args.params, 0, sizeof(ConvParams)); + args.params.device_id = static_cast(cuda_ep->GetDeviceId()); + args.params.data_type = CudnnTensor::GetDataType(); + args.params.input_dim = static_cast(x_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + args.params.input_size[i] = static_cast(x_dims[i]); + args.params.weight_size[i] = static_cast(w_dims[i]); + } + for (size_t i = 0; i < rank; i++) { + args.params.padding[i] = static_cast(pads[i]); + args.params.padding[i + rank] = static_cast(pads[i + rank]); + args.params.stride[i] = static_cast(strides[i]); + args.params.dilation[i] = static_cast(dilations[i]); + } + args.params.groups = conv_attrs_.group; + int algo_mode = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(algo_mode > -1 && algo_mode < 3, + "Algo mode should be EXHAUSTIVE (0), HEURISTIC (1) or DEFAULT (2), but got ", algo_mode); + args.params.algo_mode = algo_mode; + + args.handle = cudnn_handle; + ORT_RETURN_IF_ERROR(args.w_desc.Set(w_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.x_tensor.Set(x_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, + args.params.data_type)); + + if (dB) { + const auto& b_shape = dB->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + kernel_shape.size()); + b_dims[0] = 1; // N + b_dims[1] = b_shape[0]; // C + for (size_t i = 0; i < kernel_shape.size(); i++) + b_dims[2 + i] = 1; + + ORT_RETURN_IF_ERROR(args.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); + } + } + + return Status::OK(); +} + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h new file mode 100644 index 0000000000..72426323fe --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" + +#include "core/providers/cpu/nn/conv_attributes.h" +#include "orttraining/training_ops/cuda/nn/conv_shared.h" + +namespace onnxruntime::cuda { + +template +class ConvTransposeGrad final : public CudaKernel { + public: + using CudaT = typename ToCudaType::MappedType; + + ConvTransposeGrad(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + Status ComputeWeightGradient(onnxruntime::Stream* stream, const ConvArgs& args) const; + Status ComputeInputGradient(onnxruntime::Stream* stream, const ConvArgs& args) const; + Status ComputeBiasGradient(const ConvArgs& args) const; + + Status PrepareConvForwardArgs(const Tensor& X, const Tensor& W, + Tensor& Y, cudnnHandle_t cudnn_handle, + ConvArgs& args) const; + + Status PrepareConvBackwardFilterArgs(const Tensor& X, const Tensor& W, const Tensor& dY, + Tensor* dW, Tensor* dB, cudnnHandle_t cudnn_handle, + ConvArgs& args) const; + + ConvAttributes conv_attrs_; + mutable ConvArgs args_dx_; + mutable ConvArgs args_dw_; +}; + +} // namespace onnxruntime::cuda