From a43c57f59db35d8cceef1ff8f44985d745eb94b4 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Fri, 20 Oct 2023 11:39:57 -0700 Subject: [PATCH] ResizeGrad CUDA/ROCM kernel implementation (#17772) --- .../python/tools/symbolic_shape_infer.py | 1 - .../core/graph/gradient_builder.cc | 8 + .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../core/graph/training_op_defs.cc | 20 ++ .../ortmodule/_custom_gradient_registry.py | 5 - .../ortmodule/_custom_op_symbolic_registry.py | 13 - .../test/gradient/gradient_ops_test.cc | 35 +++ .../python/orttraining_test_ortmodule_api.py | 8 +- .../training_ops/cuda/resize_grad_test.cc | 227 ++++++++++++++++++ .../cuda/cuda_training_kernels.cc | 12 +- .../training_ops/cuda/tensor/resize_grad.cc | 81 +++++++ .../training_ops/cuda/tensor/resize_grad.h | 41 ++++ .../cuda/tensor/resize_grad_impl.cu | 151 ++++++++++++ .../cuda/tensor/resize_grad_impl.h | 20 ++ .../rocm/rocm_training_kernels.cc | 6 + 16 files changed, 605 insertions(+), 25 deletions(-) create mode 100644 orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 6d954bd540..67e9f1b55e 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -230,7 +230,6 @@ class SymbolicShapeInference: "upsample_nearest1d": self._infer_aten_upsample, "upsample_nearest2d": self._infer_aten_upsample, "upsample_nearest3d": self._infer_aten_upsample, - "upsample_bilinear2d": self._infer_aten_upsample, } self.run_ = True 
self.suggested_merge_ = {} diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 133cab71f2..6547f53a3c 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -2147,5 +2147,13 @@ IMPLEMENT_GRADIENT_BUILDER(GetScaledSumGradient) { ORT_THROW("ScaledSum gradient builder does not support ", input_count, " inputs"); } +IMPLEMENT_GRADIENT_BUILDER(GetResizeGradient) { + return std::vector{ + NodeDef(OpDef{"ResizeGrad", kMSDomain, 1}, + {GO(0), I(0), I(1), I(2)}, + {GI(0)}, + SrcNodeAttributes())}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index a517e8af13..28a316261e 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -90,6 +90,7 @@ DECLARE_GRADIENT_BUILDER(GetGRUGradient) DECLARE_GRADIENT_BUILDER(GetReciprocalGradient) DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient) DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient) +DECLARE_GRADIENT_BUILDER(GetResizeGradient) DECLARE_GRADIENT_BUILDER(GetExternalGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 4062b5d097..4b8c68aef0 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -122,6 +122,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient); REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient); REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient); + REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient); REGISTER_GRADIENT_BUILDER("ExternalGradient", 
GetExternalGradient); }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index cfc79455c4..c90acfdb7b 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -5001,6 +5001,26 @@ Return true if all elements are true and false otherwise. "T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(ResizeGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "dY", "Gradient of output Y.", "T") + .Input(1, "X", "Input tensor to the Resize operator.", "T") + .Input(2, "roi", "The roi input to the Resize operator.", "T", OpSchema::Optional) + .Input(3, "scales", "The scales input to the Resize operator.", "tensor(float)", OpSchema::Optional) + .Output(0, "dX", "Gradient of the input X.", "T") + .AllowUncheckedAttributes() + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); + if (hasInputShape(ctx, 1)) { + propagateShapeFromInputToOutput(ctx, 1, 0); + } + }); } } // namespace training diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 156c3e001d..7731724272 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -271,8 +271,3 @@ def upsample_nearest2d_gradient(): @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec") def upsample_nearest3d_gradient(): return _upsample_gradient("upsample_nearest3d_backward", 3) - - 
-@register_gradient("org.pytorch.aten", "ATen", "upsample_bilinear2d", "vec") -def upsample_bilinear2d_gradient(): - return _upsample_gradient("upsample_bilinear2d_backward", 2) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 64c7abe1c9..6e694dcdf2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -808,16 +808,3 @@ def upsample_nearest2d(g, input, output_size, scale_factors): @register_symbolic("upsample_nearest3d") def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") - - -@register_symbolic("upsample_bilinear2d") -def upsample_bilinear2d(g, input, output_size, align_corners, scale_factors): - return g.op( - "org.pytorch.aten::ATen", - input, - output_size, - align_corners, - scale_factors, - operator_s="upsample_bilinear2d", - overload_name_s="vec", - ) diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 597801f403..890a1bbccb 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3298,6 +3298,41 @@ TEST(GradientCheckerTest, ConvTransposeGrad) { execution_providers.push_back(DefaultCudaExecutionProvider()); ConvTransposeGradientCheckerTest(&execution_providers); } + +// TODO: Enable test for ROCM +TEST(GradientCheckerTest, ResizeGrad) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + const std::vector attributes = { + MakeAttribute("coordinate_transformation_mode", "half_pixel"), + MakeAttribute("cubic_coeff_a", -0.75f), + MakeAttribute("exclude_outside", static_cast(0)), + 
MakeAttribute("extrapolation_value", 0.0f), + MakeAttribute("mode", "linear"), + MakeAttribute("nearest_mode", "floor")}; + + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"Resize", kOnnxDomain, 18}; + + TensorInfo x_info({1, 2, 4, 4}, true); + TensorInfo roi_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo scales_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + + TensorInfo y_info({1, 2, 8, 8}, true); + + std::vector> x_datas = {{0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 2.0f, 2.0f}}; + + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, roi_info, scales_info}, + {y_info}, &max_error, x_datas, attributes, true, false, &execution_providers)); + EXPECT_IS_TINY(max_error); +} + #endif // USE_CUDA } // namespace test diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 643d47b0d0..c8ec2e52f3 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1773,13 +1773,17 @@ def test_aten_upsample_nearest(input_rank, use_factor): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) -def test_aten_upsample_bilinear(): +@pytest.mark.parametrize("interpolate_size_scale", ({"size": (8, 12)}, {"scale_factor": 4.7})) +@pytest.mark.parametrize("align_corners", (True, False)) +def test_resize_grad_correctness_bilinear_2d(interpolate_size_scale, align_corners): class _NeuralNetUpsampleBilinear(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input): - return torch.nn.functional.interpolate(input, size=(8, 12), mode="bilinear") + return 
torch.nn.functional.interpolate( + input, align_corners=align_corners, mode="bilinear", **interpolate_size_scale + ) device = "cuda" pt_model = _NeuralNetUpsampleBilinear().to(device) diff --git a/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc new file mode 100644 index 0000000000..8fc13af881 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/providers/compare_provider_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime::test { + +#if defined(USE_CUDA) || defined(USE_ROCM) + +namespace { + +void AddResizeGradAttributes(OpTester& test, const std::string& coordinate_transformation_mode) { + test.AddAttribute("mode", "linear"); + test.AddAttribute("coordinate_transformation_mode", coordinate_transformation_mode); +} + +} // namespace + +TEST(ResizeGradTest, ResizeGradWithSizes) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(128, 1.0f); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX(32, 4.0f); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithSizesHalf) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + 
providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(128, 1.0f); + std::vector dY_half(dY.size()); + ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast(dY.size())); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_half(X.size()); + ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast(X.size())); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX(32, 4.0f); + std::vector dX_half(dX.size()); + ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast(dX.size())); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY_half); + test.AddInput("X", X_shape, X_half); + + test.AddOutput("dX", dX_shape, dX_half); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithSizesAndAlignCorners) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "align_corners"); + + std::vector dY(128, 1.0f); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f, + 3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f, + 2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f, + 3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, 
ResizeGradWithScales) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(72, 1.0f); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f, + 2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScalesHalf) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(72, 1.0f); + std::vector dY_half(dY.size()); + ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast(dY.size())); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_half(X.size()); + ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast(X.size())); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f, + 2.7128f, 2.9550f, 2.7612f, 1.4533f, 
2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f}); + std::vector dX_half(dX.size()); + ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast(dX.size())); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY_half); + test.AddInput("X", X_shape, X_half); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX_half); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScalesAndAlignCorners) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "align_corners"); + + std::vector dY(72, 1.0f); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f, + 2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f, + 1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f, + 2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +#endif // defined(USE_CUDA) || defined(USE_ROCM) + +} // namespace onnxruntime::test diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 8e61dbee50..ae4f48b6b4 100644 --- 
a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -207,6 +207,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ResizeGrad); // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training @@ -453,13 +456,14 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training #ifdef ENABLE_TRAINING diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc new file mode 100644 index 0000000000..a5e8f7cd35 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include <optional>
+#include <utility>
+
+#include "orttraining/training_ops/cuda/tensor/resize_grad.h"
+#include "orttraining/training_ops/cuda/tensor/resize_grad_impl.h"
+
+namespace onnxruntime::cuda {
+
+#define REGISTER_RESIZEGRAD_KERNEL_TYPED(T)                                \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                           \
+      ResizeGrad,                                                          \
+      kMSDomain,                                                           \
+      1,                                                                   \
+      T,                                                                   \
+      kCudaExecutionProvider,                                              \
+      (*KernelDefBuilder::Create())                                        \
+          .InputMemoryType(OrtMemTypeCPUInput, 2) /* Keep roi on CPU */    \
+          .InputMemoryType(OrtMemTypeCPUInput, 3) /* Keep scales on CPU */ \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),          \
+      ResizeGrad<T>);
+
+REGISTER_RESIZEGRAD_KERNEL_TYPED(MLFloat16)
+REGISTER_RESIZEGRAD_KERNEL_TYPED(float)
+REGISTER_RESIZEGRAD_KERNEL_TYPED(double)
+// Computes dX, the gradient w.r.t. the 4-D Resize input X, from the upstream gradient dY.
+template <typename T>
+Status ResizeGrad<T>::ComputeInternal(OpKernelContext* context) const {
+  using CudaT = typename ToCudaType<T>::MappedType;
+
+  const Tensor* dY = context->Input<Tensor>(0);
+  const Tensor* X = context->Input<Tensor>(1);
+  const Tensor* scales = context->Input<Tensor>(3);  // optional input, may be null
+
+  ORT_ENFORCE(X->Shape().NumDimensions() == 4, "Expected input tensor to have 4 dimensions. Actual: ",
+              X->Shape().NumDimensions());
+
+  const auto get_scales_from_input = [](const Tensor* scales_tensor) {
+    if (nullptr == scales_tensor) {
+      return std::make_pair(std::optional<float>{}, std::optional<float>{});
+    }
+
+    ORT_ENFORCE(scales_tensor->Shape().Size() == 4, "There must be a scale for each dimension.");
+
+    const auto* scales_data = scales_tensor->Data<float>();
+    return std::make_pair(std::optional<float>{scales_data[2]}, std::optional<float>{scales_data[3]});
+  };
+
+  std::pair<std::optional<float>, std::optional<float>> scale_factors = get_scales_from_input(scales);
+
+  Tensor* dX = context->Output(0, X->Shape());
+
+  const int64_t batch_size = X->Shape()[0];
+  const int64_t num_channels = X->Shape()[1];
+  const int64_t output_height = dY->Shape()[2];
+  const int64_t output_width = dY->Shape()[3];
+  const int64_t input_height = X->Shape()[2];
+  const int64_t input_width = X->Shape()[3];
+
+  if (dX->Shape() == dY->Shape()) {  // same shape: dY is copied to dX, on the kernel's own stream
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dX->MutableDataRaw(), dY->DataRaw(), dY->SizeInBytes(),
+                                         cudaMemcpyDeviceToDevice, Stream(context)));
+    return Status::OK();
+  }
+  // ResizeGradImpl only accumulates into dX (atomic adds), so dX has to start out zeroed.
+  CUDA_RETURN_IF_ERROR(cudaMemsetAsync(dX->MutableDataRaw(), 0, dX->SizeInBytes(), Stream(context)));
+  const bool align_corners = coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS;
+  const CudaT* dy_data = reinterpret_cast<const CudaT*>(dY->Data<T>());
+  CudaT* dx_data = reinterpret_cast<CudaT*>(dX->MutableData<T>());
+
+  ResizeGradImpl(Stream(context), input_height, input_width, output_height,
+                 output_width, batch_size, num_channels, align_corners,
+                 scale_factors.first, scale_factors.second,
+                 dy_data, dx_data);
+
+  return Status::OK();
+}
+
+}  // namespace onnxruntime::cuda
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h
new file mode 100644
index 0000000000..53f8d5f0d7
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cpu/tensor/upsamplebase.h"
+
+namespace onnxruntime::cuda {
+// CUDA kernel for ResizeGrad. The constructor rejects the Resize configurations listed in its checks below.
+template <typename T>
+class ResizeGrad final : public UpsampleBase, public CudaKernel {
+ public:
+  explicit ResizeGrad(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) {
+    ORT_ENFORCE(!antialias_, "Antialiasing is not supported in ResizeGrad yet.");
+
+    ORT_ENFORCE(axes_.empty(), "ResizeGrad does not support the `axes` attribute yet.");
+
+    std::string coordinate_transform_mode =
+        info.GetAttrOrDefault<std::string>("coordinate_transformation_mode", "half_pixel");
+    coordinate_transform_mode_ = StringToCoordinateTransformationMode(coordinate_transform_mode);
+    ORT_ENFORCE(coordinate_transform_mode_ == ResizeCoordinateTransformationMode::HALF_PIXEL ||
+                    coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS,
+                "ResizeGrad only supports the `HALF_PIXEL` and `ALIGN_CORNERS` coordinate_transform_mode. ",
+                coordinate_transform_mode, " is not supported yet.");
+
+    ORT_ENFORCE(keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH,
+                "ResizeGrad only supports the `STRETCH` policy.");
+
+    std::string mode;  // fetched as text only so the error message below can echo it
+    ORT_ENFORCE(info.GetAttr<std::string>("mode", &mode).IsOK(), "ResizeGrad requires the `mode` attribute.");
+    ORT_ENFORCE((UpsampleMode::LINEAR == mode_),
+                "ResizeGrad only supports the `LINEAR` mode. ", mode, " mode is not supported yet.");
+  }
+
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+}  // namespace onnxruntime::cuda
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu
new file mode 100644
index 0000000000..0507cda623
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu
@@ -0,0 +1,151 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Contents of this file are derived from the pytorch cuda implementation of
+// the upsample_bilinear2d_backward implementation at:
+// https://github.com/pytorch/pytorch/blob/ce50132748f652ed6079c3db8008a6817594dbae/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
+
+#include "orttraining/training_ops/cuda/tensor/resize_grad_impl.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/atomic/common.cuh"
+
+namespace onnxruntime::cuda {
+
+namespace {
+
+constexpr int NumThreadsPerBlock = GridDim::maxThreadsPerBlock;
+
+}  // namespace
+// Flat offset of element (nc, h, w) in a contiguous [NC, height, width] buffer.
+__device__ __forceinline__ size_t
+idx(const size_t nc,
+    const size_t height,
+    const size_t width,
+    const size_t h,
+    const size_t w) {
+  return (nc * height + h) * width + w;
+}
+// Maps a destination index to its fractional source index; negative results clamp to 0 unless `cubic`.
+template <typename T>
+__device__ __forceinline__ static T AreaPixelComputeSourceIndex(
+    T scale,
+    int dst_index,
+    bool align_corners,
+    bool cubic) {
+  if (align_corners) {
+    return scale * dst_index;
+  } else {
+    T src_idx = scale * (dst_index + static_cast<T>(0.5)) -
+                static_cast<T>(0.5);
+    return (!cubic && src_idx < static_cast<T>(0))
+               ? static_cast<T>(0)
+               : src_idx;
+  }
+}
+// Adds every dY element, bilinearly weighted, onto the (up to) four dX cells it maps back to.
+template <typename T, typename AccT>
+__global__ void UpsampleGrad(const int64_t nc, const int64_t input_height,
+                             const int64_t input_width, const int64_t output_height,
+                             const int64_t output_width, const AccT rheight,
+                             const AccT rwidth, const bool align_corners,
+                             const T* dY_data, T* dX_data) {
+  const size_t dy_numel = nc * output_width * output_height;
+  const size_t dx_numel = nc * input_width * input_height;
+  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x;
+       index < dy_numel;
+       index += blockDim.x * gridDim.x) {
+    size_t index_temp = index;
+    const int w2 = index_temp % output_width;  // 0:width2-1
+    index_temp /= output_width;
+    const int h2 = index_temp % output_height;  // 0:height2-1
+    const size_t nc_idx = index_temp / output_height;  // fused batch*channel index of this element
+
+    const AccT h1r = AreaPixelComputeSourceIndex<AccT>(
+        rheight, h2, align_corners, /*cubic=*/false);
+    const int h1 = h1r;
+    const int h1p = (h1 < input_height - 1) ? 1 : 0;
+    const AccT h1lambda = h1r - h1;
+    const AccT h0lambda = static_cast<AccT>(1) - h1lambda;
+
+    const AccT w1r = AreaPixelComputeSourceIndex<AccT>(
+        rwidth, w2, align_corners, /*cubic=*/false);
+    const int w1 = w1r;
+    const int w1p = (w1 < input_width - 1) ? 1 : 0;
+    const AccT w1lambda = w1r - w1;
+    const AccT w0lambda = static_cast<AccT>(1) - w1lambda;
+
+    const T d2val = dY_data[index];
+    AtomicAdd(
+        dX_data,
+        idx(nc_idx, input_height, input_width, h1, w1),
+        dx_numel,
+        static_cast<T>(h0lambda * w0lambda) * d2val);
+    AtomicAdd(
+        dX_data,
+        idx(nc_idx, input_height, input_width, h1, w1 + w1p),
+        dx_numel,
+        static_cast<T>(h0lambda * w1lambda) * d2val);
+    AtomicAdd(
+        dX_data,
+        idx(nc_idx, input_height, input_width, h1 + h1p, w1),
+        dx_numel,
+        static_cast<T>(h1lambda * w0lambda) * d2val);
+    AtomicAdd(
+        dX_data,
+        idx(nc_idx, input_height, input_width, h1 + h1p, w1 + w1p),
+        dx_numel,
+        static_cast<T>(h1lambda * w1lambda) * d2val);
+  }
+}
+// Source step per destination step: from the sizes when align_corners, else 1/scale if given, else in/out.
+template <typename T>
+T AreaPixelComputeScale(int64_t input_size, int64_t output_size, bool align_corners,
+                        const std::optional<float>& scale) {
+  if (align_corners) {
+    if (output_size <= 1) {
+      return T{0};
+    }
+    return static_cast<T>(input_size - 1) / static_cast<T>(output_size - 1);
+  } else {
+    if (scale.has_value()) {
+      return static_cast<T>(T{1.0} / *scale);
+    } else {
+      return static_cast<T>(input_size) / static_cast<T>(output_size);
+    }
+  }
+}
+// Launches UpsampleGrad on `stream`. The kernel only accumulates, so dX_data must already be zero-filled.
+template <typename T>
+void ResizeGradImpl(cudaStream_t stream, int64_t input_height,
+                    int64_t input_width, int64_t output_height,
+                    int64_t output_width, int64_t batch_size,
+                    int64_t channels, bool align_corners,
+                    const std::optional<float>& scale_height,
+                    const std::optional<float>& scale_width,
+                    const T* dY_data, T* dX_data) {
+  const size_t output_numel = batch_size * channels * output_height * output_width;
+  if (output_numel == 0) return;  // nothing to accumulate; a zero-block launch is an invalid configuration
+  const float rheight = AreaPixelComputeScale<float>(input_height, output_height, align_corners, scale_height);
+  const float rwidth = AreaPixelComputeScale<float>(input_width, output_width, align_corners, scale_width);
+  const int blocks_per_grid = static_cast<int>((output_numel + NumThreadsPerBlock - 1) / NumThreadsPerBlock);
+  UpsampleGrad<T, float><<<blocks_per_grid, NumThreadsPerBlock, 0, stream>>>(
+      batch_size * channels, input_height, input_width, output_height, output_width,
+      rheight, rwidth, align_corners, dY_data, dX_data);
+}
+
+#define SPECIALIZED_RESIZEGRAD_IMPL(T)                                        \
+  template void ResizeGradImpl<T>(cudaStream_t stream, int64_t input_height,  \
+                                  int64_t input_width, int64_t output_height, \
+                                  int64_t output_width, int64_t batch_size,   \
+                                  int64_t channels, bool align_corners,       \
+                                  const std::optional<float>& scale_height,   \
+                                  const std::optional<float>& scale_width,    \
+                                  const T* dY_data, T* dX_data);
+
+SPECIALIZED_RESIZEGRAD_IMPL(half)
+SPECIALIZED_RESIZEGRAD_IMPL(float)
+SPECIALIZED_RESIZEGRAD_IMPL(double)
+
+#undef SPECIALIZED_RESIZEGRAD_IMPL
+
+}  // namespace onnxruntime::cuda
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h
new file mode 100644
index 0000000000..3e917f9071
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <cstdint>
+#include <optional>
+
+namespace onnxruntime::cuda {
+
+template <typename T>
+void ResizeGradImpl(cudaStream_t stream, int64_t input_height,
+                    int64_t input_width, int64_t output_height,
+                    int64_t output_width, int64_t batch_size,
+                    int64_t channels, bool align_corners,
+                    const std::optional<float>& scale_height,
+                    const std::optional<float>& scale_width,
+                    const T* dY_data, T* dX_data);
+
+}  // namespace onnxruntime::cuda
diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
index 2321aa23dd..e0749c2fb4 100644
--- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
@@ -187,6 +187,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2);
 class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad); #if defined(ORT_USE_NCCL) || defined(USE_MPI) // P2P communication operators. @@ -387,6 +390,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // P2P communication operators. #if defined(ORT_USE_NCCL) || defined(USE_MPI)