From 8c249cc8f7933fa031e608bf2e10a03ae93dfa31 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 14 Dec 2022 11:54:02 -0800 Subject: [PATCH] [QAT] FakeQuantGrad and gradient building for FakeQuant (#13825) --- .../core/graph/gradient_builder.cc | 4 + .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../core/graph/training_op_defs.cc | 23 ++++ .../test/training_ops/cuda/fake_quant_test.cc | 107 +++++++++++++++++- .../training_ops/cpu/cpu_training_kernels.cc | 3 + .../cpu/quantization/fake_quant.cc | 38 ++++++- .../cpu/quantization/fake_quant.h | 9 ++ .../cuda/cuda_training_kernels.cc | 3 + .../cuda/quantization/fake_quant.cc | 37 +++++- .../cuda/quantization/fake_quant.h | 9 ++ .../cuda/quantization/fake_quant_impl.cu | 32 ++++++ .../cuda/quantization/fake_quant_impl.h | 4 + 13 files changed, 266 insertions(+), 5 deletions(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index c8e0f2bd5e..207175f861 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1878,5 +1878,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetScatterElementsGradient) { return result; } +IMPLEMENT_GRADIENT_BUILDER(GetFakeQuantGradient) { + return {NodeDef(OpDef{"FakeQuantGrad", kMSDomain, 1}, {GO(0), O(1)}, {GI(0)})}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index cecf37d415..f4cb76a032 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -81,6 +81,7 @@ DECLARE_GRADIENT_BUILDER(GetPythonOpGradient) DECLARE_GRADIENT_BUILDER(GetScatterNDGradient) DECLARE_GRADIENT_BUILDER(GetScatterElementsGradient) DECLARE_GRADIENT_BUILDER(GetTriluGradient) +DECLARE_GRADIENT_BUILDER(GetFakeQuantGradient) DECLARE_GRADIENT_BUILDER(GetExternalGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index c5693fb616..fc9250fe69 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -113,6 +113,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("ScatterND", GetScatterNDGradient); REGISTER_GRADIENT_BUILDER("ScatterElements", GetScatterElementsGradient); REGISTER_GRADIENT_BUILDER("Trilu", GetTriluGradient); + REGISTER_GRADIENT_BUILDER("FakeQuant", GetFakeQuantGradient); REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient); }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index b8d8d620d1..a2dddc66b6 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4015,6 +4015,29 @@ Return true if all elements are true and false otherwise. updateOutputElemType(ctx, 1, ONNX_NAMESPACE::TensorProto::BOOL); propagateShapeFromInputToOutput(ctx, 0, 1); }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(FakeQuantGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc( + "FakeQuantGrad op that computes the partial derivative of the loss with respect to the input tensor to " + "the FakeQuant op.") + .Input(0, "dY", "Gradient of loss with respect to the output Y of the FakeQuant op (fake quantized output)", "T") + .Input(1, "gradient_mask", + "Gradient mask that indicates whether the quantized value is within the quantization range.", + "T_BOOL") + .Output(0, "dX", "Gradient of loss with respect to the input X (of the FakeQuant node).", "T") + .TypeConstraint( + "T", + {"tensor(float)"}, + "Constrain the gradient input and output types to float tensors.") + .TypeConstraint( + "T_BOOL", + {"tensor(bool)"}, + "Constrain the gradient mask input to bool tensors.") + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + propagateShapeAndTypeFromFirstInput(ctx); + }); } } // namespace training diff --git a/orttraining/orttraining/test/training_ops/cuda/fake_quant_test.cc b/orttraining/orttraining/test/training_ops/cuda/fake_quant_test.cc index c62584d04b..1067156122 100644 --- a/orttraining/orttraining/test/training_ops/cuda/fake_quant_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/fake_quant_test.cc @@ -26,7 +26,7 @@ void CompareFakeQuantKernels(const std::vector& tensor_dim, std::vector scale = random.Uniform(std::vector({1}), 0.04f, 0.1f); test.AddInput("scale", {1}, scale); std::vector zero_point = random.Uniform(std::vector({1}), 0.f, 255.0f); - test.AddInput("zero_scale", {1}, std::vector({std::nearbyint(zero_point.front())})); + test.AddInput("zero_point", {1}, std::vector({std::nearbyint(zero_point.front())})); // Create output tensors std::vector fake_quantized_data = FillZeros(tensor_dim); @@ -36,6 +36,7 @@ void CompareFakeQuantKernels(const std::vector& tensor_dim, test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance); } + #endif } // namespace @@ -54,7 +55,7 @@ TEST(FakeQuantTest, FakeQuantComputation) { test.AddInput("input_tensor", {10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}); test.AddInput("scale", {1}, {0.075f}); - test.AddInput("zero_scale", {1}, {128.0f}); + test.AddInput("zero_point", {1}, {128.0f}); // quantized values = nearby_int(value / scale + zero_point) // = {13.33+128, 26.66+128, 40.00+128, 53.33+128, 66.66+128, ...} // = {141.33, 154.66, 168.00, 171.33, 184.66, ...} @@ -79,5 +80,107 @@ TEST(CudaKernelTest, FakeQuant) { } #endif +class FakeQuantGradParameterizedTest : public ::testing::TestWithParam { +}; + +TEST_P(FakeQuantGradParameterizedTest, FakeQuantGradComputation) { + std::vector> providers; + providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#endif + OpTester test("FakeQuantGrad", 1, kMSDomain, true); + + auto x_shape = GetParam(); + RandomValueGenerator random{}; + + // Randomly generate the gradient w.r.t. the output Y. + std::vector dY_data = random.Uniform(x_shape.GetDims(), -1000.0f, 1000.0f); + + // Randomly generate the mask data + std::unique_ptr mask_data = [&random, &x_shape]() { + auto data_int = random.Uniform(x_shape.GetDims(), 0, 2); + auto data_bool = std::make_unique(detail::SizeFromDims(x_shape.GetDims())); + for (size_t i = 0; i < data_int.size(); ++i) { + data_bool.get()[i] = data_int[i] == 0; + } + return data_bool; + }(); + + // Calculate the gradient w.r.t. the input X. + std::vector dX_data = [](const auto& dY_data, const auto& mask_data) { + auto dX_data = dY_data; + for (size_t i = 0; i < dY_data.size(); ++i) { + if (!mask_data.get()[i]) + dX_data[i] = 0.0f; + } + return dX_data; + }(dY_data, mask_data); + + test.AddInput("dY", x_shape.AsShapeVector(), dY_data); + test.AddInput("gradient_mask", x_shape.AsShapeVector(), mask_data.get(), x_shape.Size()); + test.AddOutput("dX", x_shape.AsShapeVector(), dX_data); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +INSTANTIATE_TEST_SUITE_P( + FakeQuantGradTests, + FakeQuantGradParameterizedTest, + ::testing::Values( + TensorShape({4}), + TensorShape({8, 4}), + TensorShape({4, 7, 13}), + TensorShape({4, 8, 16, 32}), + TensorShape({4, 16, 32, 4096}))); + +#ifdef USE_CUDA + +class FakeQuantGradKernelComparisonParameterizedTest : public ::testing::TestWithParam> { +}; + +TEST_P(FakeQuantGradKernelComparisonParameterizedTest, FakeQuantGradKernels) { + CompareOpTester test("FakeQuantGrad", 1, onnxruntime::kMSDomain); + auto tensor_dim = GetParam(); + + // Randomly generate the gradient w.r.t. the output Y. + RandomValueGenerator random{}; + std::vector dY_data = random.Uniform(tensor_dim, -1000.0f, 1000.0f); + + // Randomly generate the mask data + std::unique_ptr mask_data = [&random, &tensor_dim]() { + auto data_int = random.Uniform(tensor_dim, 0, 2); + auto data_bool = std::make_unique(detail::SizeFromDims(tensor_dim)); + for (size_t i = 0; i < data_int.size(); ++i) { + data_bool.get()[i] = data_int[i] == 0; + } + return data_bool; + }(); + + // Initialize the gradient w.r.t. the input X with 0s. + std::vector dX_data = FillZeros(tensor_dim); + + test.AddInput("dY", tensor_dim, dY_data); + test.AddInput("gradient_mask", tensor_dim, mask_data.get(), detail::SizeFromDims(tensor_dim)); + test.AddOutput("dX", tensor_dim, dX_data); + + // Compare the outputs from the two kernels + const double per_sample_tolerance = 2e-4; + const double relative_per_sample_tolerance = 2e-4; + test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance); +} + +INSTANTIATE_TEST_SUITE_P( + FakeQuantGradTests, + FakeQuantGradKernelComparisonParameterizedTest, + ::testing::Values( + std::vector({4}), + std::vector({8, 4}), + std::vector({4, 7, 13}), + std::vector({4, 8, 16, 32}), + std::vector({4, 16, 32, 4096}))); + +#endif + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 82a2e31485..bf79ce6c0d 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -89,6 +89,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InplaceClipGradNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FakeQuant); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FakeQuantGrad); // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training @@ -206,6 +207,8 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, + BuildKernelCreateInfo, // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training diff --git a/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc b/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc index b49d27e8b4..fec39da5b1 100644 --- a/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc +++ b/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc @@ -3,6 +3,7 @@ #include "orttraining/training_ops/cpu/quantization/fake_quant.h" #include "core/providers/common.h" +#include "core/providers/cpu/math/element_wise_ops.h" namespace onnxruntime { namespace contrib { @@ -35,6 +36,13 @@ void FakeQuantPerTensor(OpKernelContext* ctx, const int64_t num_elements, const } }); } + +template +void FakeQuantGradImpl(const Tensor& dY, const Tensor& gradient_mask, Tensor& dX) { + // If gradient_mask is true (i.e. quantization was in range), return dY, else return 0 + MakeEigenArrayMap(dX) = MakeEigenArrayMap(dY) * MakeEigenArrayMap(gradient_mask).template cast(); +} + } // namespace #define REGISTER_FAKEQUANT_KERNEL_TYPED(T) \ @@ -69,7 +77,7 @@ Status FakeQuant::Compute(OpKernelContext* ctx) const { T* fake_quantized_data = fake_quantized_tensor->MutableData(); bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData(); - // Copmute + // Compute // TODO(bmeswani): Add support for FakeQuantPerChannel FakeQuantPerTensor(ctx, input_tensor->Shape().Size(), input_data, *quant_scale, *quant_zero_point, quant_min_, quant_max_, fake_quantized_data, quantization_mask_data); @@ -77,5 +85,33 @@ Status FakeQuant::Compute(OpKernelContext* ctx) const { return Status::OK(); } +#define REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + FakeQuantGrad, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + FakeQuantGrad); + +REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(float) + +template +Status FakeQuantGrad::Compute(OpKernelContext* ctx) const { + // Prepare the gradient wrt the output and gradient mask input + const auto* dY = ctx->Input(0); + const auto* gradient_mask = ctx->Input(1); + + // Prepare the output + auto* dX = ctx->Output(0, dY->Shape()); + + // Compute + FakeQuantGradImpl(*dY, *gradient_mask, *dX); + + return Status::OK(); +} + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.h b/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.h index fe1134024d..854ee244ca 100644 --- a/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.h +++ b/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.h @@ -24,5 +24,14 @@ class FakeQuant final : public OpKernel { int64_t quant_max_; }; +template +class FakeQuantGrad final : public OpKernel { + public: + FakeQuantGrad(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; +}; + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 642a71fe4c..54f4ff27d6 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -193,6 +193,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, InplaceClipGradNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad); // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training @@ -426,6 +427,8 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, + BuildKernelCreateInfo, // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training diff --git a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.cc b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.cc index b8699bbd5a..1e5c3f55b7 100644 --- a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.cc +++ b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.cc @@ -42,8 +42,8 @@ Status FakeQuant::ComputeInternal(OpKernelContext* ctx) const { const CudaT* quant_zero_point = reinterpret_cast(zero_point->Data()); // Prepare the output, mask for gradient computation - auto& fake_quantized_tensor = *ctx->Output(0, input_tensor->Shape()); - CudaT* fake_quantized_data = reinterpret_cast(fake_quantized_tensor.MutableData()); + auto* fake_quantized_tensor = ctx->Output(0, input_tensor->Shape()); + CudaT* fake_quantized_data = reinterpret_cast(fake_quantized_tensor->MutableData()); bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData(); // Fake quantize the input tensor @@ -54,5 +54,38 @@ Status FakeQuant::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } +#define REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + FakeQuantGrad, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + FakeQuantGrad); + +REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(float) + +template +Status FakeQuantGrad::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToCudaType::MappedType CudaT; + + // Prepare the gradient wrt the output and gradient mask input + const auto* dY = ctx->Input(0); + const CudaT* dY_data = reinterpret_cast(dY->Data()); + const auto* gradient_mask = ctx->Input(1); + const bool* gradient_mask_data = gradient_mask->Data(); + + // Prepare the output + auto* dX = ctx->Output(0, dY->Shape()); + CudaT* dX_data = reinterpret_cast(dX->MutableData()); + + // Compute + FakeQuantGradImpl(Stream(), dY->Shape().Size(), dY_data, gradient_mask_data, dX_data); + + return Status::OK(); +} + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.h b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.h index 2e9c7fa18a..0bc016ebf5 100644 --- a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.h +++ b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.h @@ -25,5 +25,14 @@ class FakeQuant final : public CudaKernel { int64_t quant_max_; }; +template +class FakeQuantGrad final : public CudaKernel { + public: + FakeQuantGrad(const OpKernelInfo& info) : CudaKernel(info) { + } + + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.cu b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.cu index 677088e253..d3d479664b 100644 --- a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.cu @@ -3,6 +3,7 @@ #include "orttraining/training_ops/cuda/quantization/fake_quant_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cu_inc/elementwise_impl.cuh" namespace onnxruntime { namespace cuda { @@ -78,5 +79,36 @@ SPECIALIZED_FAKEQUANT_IMPL(float) #undef SPECIALIZED_FAKEQUANT_IMPL +template +struct FakeQuantGradFunctor { + FakeQuantGradFunctor(const T* dY_data, const bool* gradient_mask_data) + : dY_data_(dY_data), + gradient_mask_data_(gradient_mask_data) {} + + __device__ __inline__ T operator()(CUDA_LONG idx) const { + // If gradient_mask is true (i.e. quantization was in range), return dY, else return 0 + return gradient_mask_data_[idx] ? dY_data_[idx] : static_cast(0); + } + + const T* dY_data_; + const bool* gradient_mask_data_; +}; + +template +void FakeQuantGradImpl(cudaStream_t stream, const int64_t num_elements, const T* dY_data, + const bool* gradient_mask_data, T* dX_data) { + FakeQuantGradFunctor fake_quant_grad_functor(dY_data, gradient_mask_data); + LaunchElementwiseKernel( + stream, dX_data, fake_quant_grad_functor, num_elements); +} + +#define SPECIALIZED_FAKEQUANTGRAD_IMPL(T) \ + template void FakeQuantGradImpl(cudaStream_t stream, const int64_t num_elements, \ + const T* dY_data, const bool* gradient_mask_data, T* dX_data); + +SPECIALIZED_FAKEQUANTGRAD_IMPL(float) + +#undef SPECIALIZED_FAKEQUANTGRAD_IMPL + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.h b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.h index a2f3cf9774..6d22231a21 100644 --- a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.h +++ b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.h @@ -13,5 +13,9 @@ void FakeQuantPerTensor(cudaStream_t stream, const int64_t num_elements, const T const T quant_zero_point, const int64_t quant_min, const int64_t quant_max, T* fake_quantized_data, bool* quantization_mask_data); +template +void FakeQuantGradImpl(cudaStream_t stream, const int64_t num_elements, const T* dY_data, + const bool* gradient_mask_data, T* dX_data); + } // namespace cuda } // namespace onnxruntime