diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index a0e010c419..b8d8d620d1 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3968,6 +3968,53 @@ Return true if all elements are true and false otherwise. .SetContextDependentFunctionBodyBuilder(BuildNllLossInternalFunction<13>) .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); }) .SetDoc(R"DOC(NegativeLogLikelihoodLossInternal)DOC"); + + ONNX_CONTRIB_OPERATOR_SCHEMA(FakeQuant) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc( + "FakeQuant operator that fuses quantization->dequantization pattern into a single node. " + "FakeQuant takes in a non quantized tensor as input and generates a non quantized tensor as output. " + "But internally, it will perform Quantization->Dequantization operation that simulates the effects of " + "quantization within the model. Loss in numerical precision introduced by model quantization is " + "corrected by adjusting the model weights through the FakeQuant op.") + .Input(0, "input", "Tensor to be fake quantized.", "T") + .Input(1, "scale", + "Quantization scale. It must be a scalar, which implies per-tensor quantization. " + "The scalar value must be greater than 0.", + "T") + .Input(2, "zero_point", + "Quantization zero point as non quantized type. It must be a scalar, which implies per-tensor " + "quantization.", + "T") + .Output(0, "output", "Input tensor after it has been fake quantized. It has the same shape as the input.", "T") + .Output(1, "mask", + "Mask where values indicate if the quantized value was in qmin, qmax range. " + "Needed for gradient computation. It has the same shape as the input.", + "T_BOOL") + .Attr( + "quant_min", + "Minimum quantization value.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "quant_max", + "Maximum quantization value.", + AttributeProto::INT, + static_cast(255)) + .TypeConstraint( + "T", + {"tensor(float)"}, + "Constrain the input tensor type to float tensors.") + .TypeConstraint( + "T_BOOL", + {"tensor(bool)"}, + "Constrain the gradient quantization mask type to boolean tensors.") + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + propagateShapeAndTypeFromFirstInput(ctx); + updateOutputElemType(ctx, 1, ONNX_NAMESPACE::TensorProto::BOOL); + propagateShapeFromInputToOutput(ctx, 0, 1); + }); } } // namespace training diff --git a/orttraining/orttraining/test/training_ops/cuda/fake_quant_test.cc b/orttraining/orttraining/test/training_ops/cuda/fake_quant_test.cc new file mode 100644 index 0000000000..c62584d04b --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/fake_quant_test.cc @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/providers/compare_provider_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime { +namespace test { + +namespace { + +#ifdef USE_CUDA +void CompareFakeQuantKernels(const std::vector& tensor_dim, + double per_sample_tolerance = 2e-4, + double relative_per_sample_tolerance = 2e-4) { + CompareOpTester test("FakeQuant", 1, onnxruntime::kMSDomain); + + test.AddAttribute("quant_min", 0); + test.AddAttribute("quant_max", 7500); + + // Create rand inputs for the input tensor, scale and zero point + RandomValueGenerator random{}; + std::vector input_data = random.Uniform(tensor_dim, -1000.0f, 1000.0f); + test.AddInput("input_tensor", tensor_dim, input_data); + std::vector scale = random.Uniform(std::vector({1}), 0.04f, 0.1f); + test.AddInput("scale", {1}, scale); + std::vector zero_point = random.Uniform(std::vector({1}), 0.f, 255.0f); + test.AddInput("zero_scale", {1}, std::vector({std::nearbyint(zero_point.front())})); + + // Create output tensors + std::vector fake_quantized_data = FillZeros(tensor_dim); + test.AddOutput("fake_quantized_tensor", tensor_dim, fake_quantized_data); + std::unique_ptr quantization_mask = std::make_unique(detail::SizeFromDims(tensor_dim)); + test.AddOutput("quantization_mask", tensor_dim, quantization_mask.get(), detail::SizeFromDims(tensor_dim)); + + test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance); +} +#endif + +} // namespace + +TEST(FakeQuantTest, FakeQuantComputation) { + std::vector> providers; + providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#endif + + OpTester test("FakeQuant", 1, onnxruntime::kMSDomain); + + test.AddAttribute("quant_min", 0); + test.AddAttribute("quant_max", 255); + + test.AddInput("input_tensor", {10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}); + test.AddInput("scale", {1}, {0.075f}); + test.AddInput("zero_scale", {1}, {128.0f}); + // quantized values = nearby_int(value / scale + zero_point) + // = {13.33+128, 26.66+128, 40.00+128, 53.33+128, 66.66+128, ...} + // = {141.33, 154.66, 168.00, 171.33, 184.66, ...} + // = {141, 155, 168, 181, 195, 208, 221, 235, 248, 261} + // de-quantized values = (clamp(value) - zero_point) * scale + // = {13*0.075, 27*0.075, 40*0.075, 53*0.075, 67*0.075, ..., 120*0.075, (255-128)*0.075} + // = {0.975, 2.025, 3.0, 3.975, 5.025, 6.0, 6.975, 8.025, 9.0, 9.525} + + test.AddOutput( + "fake_quantized_tensor", {10}, {0.975f, 2.025f, 3.0f, 3.975f, 5.025f, 6.0f, 6.975f, 8.025f, 9.0f, 9.525f}); + test.AddOutput("quantization_mask", {10}, {true, true, true, true, true, true, true, true, true, false}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +#ifdef USE_CUDA +TEST(CudaKernelTest, FakeQuant) { + std::vector> test_dims{{4}, {16, 2}, {8, 2, 128, 128}}; + for (const auto& test_dim : test_dims) { + CompareFakeQuantKernels(test_dim, false); + } +} +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index da74aae921..82a2e31485 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -88,6 +88,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InplaceClipGradNorm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FakeQuant); + // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training #ifdef ENABLE_TRAINING @@ -202,6 +204,9 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, + BuildKernelCreateInfo, + // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training #ifdef ENABLE_TRAINING diff --git a/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc b/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc new file mode 100644 index 0000000000..b49d27e8b4 --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cpu/quantization/fake_quant.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace contrib { + +namespace { + +template +void FakeQuantPerTensor(OpKernelContext* ctx, const int64_t num_elements, const T* input_data, T quant_scale, + T quant_zero_point, int64_t quant_min, int64_t quant_max, T* fake_quantized_data, + bool* quantization_mask_data) { + const auto zero_point_int = static_cast(quant_zero_point); + auto* tp = ctx->GetOperatorThreadPool(); + concurrency::ThreadPool::TryParallelFor( + tp, num_elements, /* 1 Read, 2 Writes, 4 Computes */ TensorOpCost{1.0, 2.0, 4.0}, + [quant_scale, zero_point_int, quant_min, quant_max, &input_data, &fake_quantized_data, &quantization_mask_data]( + std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t index = begin; index != end; ++index) { + size_t idx = static_cast(index); + + // Quantize + const auto quantized_value = static_cast(std::nearbyint(input_data[idx] / quant_scale)) + + zero_point_int; + + // Clamp and De-Quantize + fake_quantized_data[idx] = + (std::min(quant_max, std::max(quant_min, quantized_value)) - zero_point_int) * quant_scale; + + // Compute mask needed for gradient computation + quantization_mask_data[idx] = (quant_min <= quantized_value && quantized_value <= quant_max); + } + }); +} +} // namespace + +#define REGISTER_FAKEQUANT_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + FakeQuant, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + FakeQuant); + +REGISTER_FAKEQUANT_KERNEL_TYPED(float) + +template +Status FakeQuant::Compute(OpKernelContext* ctx) const { + // Prepare the input, scale, zero point + const auto* input_tensor = ctx->Input(0); + const T* input_data = input_tensor->Data(); + const auto* scale = ctx->Input(1); + ORT_ENFORCE(IsScalarOr1ElementVector(scale), "Quantization scale must be a scalar or 1D tensor of size 1."); + const T* quant_scale = scale->Data(); + ORT_ENFORCE(*quant_scale != static_cast(0), + "Quantization scale cannot be 0. It may result in undefined behavior."); + const auto* zero_point = ctx->Input(2); + ORT_ENFORCE(IsScalarOr1ElementVector(zero_point), "Quantization zero point must be a scalar or 1D tensor of size 1."); + const T* quant_zero_point = zero_point->Data(); + + // Prepare the output, mask for gradient computation + auto* fake_quantized_tensor = ctx->Output(0, input_tensor->Shape()); + T* fake_quantized_data = fake_quantized_tensor->MutableData(); + bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData(); + + // Copmute + // TODO(bmeswani): Add support for FakeQuantPerChannel + FakeQuantPerTensor(ctx, input_tensor->Shape().Size(), input_data, *quant_scale, *quant_zero_point, quant_min_, + quant_max_, fake_quantized_data, quantization_mask_data); + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.h b/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.h new file mode 100644 index 0000000000..fe1134024d --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/quantization/fake_quant.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +class FakeQuant final : public OpKernel { + public: + FakeQuant(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("quant_min", &quant_min_, static_cast(0)); + info.GetAttrOrDefault("quant_max", &quant_max_, static_cast(255)); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t quant_min_; + int64_t quant_max_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 1a48aae152..fe54adb680 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -210,6 +210,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, InplaceClipGradNorm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant); + #if defined(ORT_USE_NCCL) || defined(USE_MPI) // P2P communication operators. class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send); @@ -435,6 +437,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, + BuildKernelCreateInfo, + // P2P communication operators. #if defined(ORT_USE_NCCL) || defined(USE_MPI) BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.cc b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.cc new file mode 100644 index 0000000000..b8699bbd5a --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "orttraining/training_ops/cuda/quantization/fake_quant.h" +#include "orttraining/training_ops/cuda/quantization/fake_quant_impl.h" + +namespace onnxruntime { +namespace cuda { + +#define REGISTER_FAKEQUANT_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + FakeQuant, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + FakeQuant); + +REGISTER_FAKEQUANT_KERNEL_TYPED(float) + +template +Status FakeQuant::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToCudaType::MappedType CudaT; + + // Prepare the input, scale, zero point + const auto* input_tensor = ctx->Input(0); + const CudaT* input_data = reinterpret_cast(input_tensor->Data()); + const auto* scale = ctx->Input(1); + ORT_ENFORCE(IsScalarOr1ElementVector(scale), "Quantization scale must be a scalar or 1D tensor of size 1."); + const CudaT* quant_scale = reinterpret_cast(scale->Data()); + ORT_ENFORCE(*quant_scale != static_cast(0), + "Quantization scale cannot be 0. It may result in undefined behavior."); + const auto* zero_point = ctx->Input(2); + ORT_ENFORCE(IsScalarOr1ElementVector(zero_point), "Quantization zero point must be a scalar or 1D tensor of size 1."); + const CudaT* quant_zero_point = reinterpret_cast(zero_point->Data()); + + // Prepare the output, mask for gradient computation + auto& fake_quantized_tensor = *ctx->Output(0, input_tensor->Shape()); + CudaT* fake_quantized_data = reinterpret_cast(fake_quantized_tensor.MutableData()); + bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData(); + + // Fake quantize the input tensor + // TODO(bmeswani): Add support for FakeQuantPerChannel + FakeQuantPerTensor(Stream(), input_tensor->Shape().Size(), input_data, *quant_scale, *quant_zero_point, quant_min_, + quant_max_, fake_quantized_data, quantization_mask_data); + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.h b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.h new file mode 100644 index 0000000000..2e9c7fa18a --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace cuda { + +template +class FakeQuant final : public CudaKernel { + public: + FakeQuant(const OpKernelInfo& info) : CudaKernel(info) { + info.GetAttrOrDefault("quant_min", &quant_min_, static_cast(0)); + info.GetAttrOrDefault("quant_max", &quant_max_, static_cast(255)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t quant_min_; + int64_t quant_max_; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.cu b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.cu new file mode 100644 index 0000000000..677088e253 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.cu @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/quantization/fake_quant_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" + +namespace onnxruntime { +namespace cuda { + +namespace { +constexpr int NumElementsPerThread = GridDim::maxElementsPerThread; +constexpr int NumThreadsPerBlock = GridDim::maxThreadsPerBlock; +} // namespace + +template +__global__ void FakeQuantPerTensorImpl(const int64_t num_elements, const T* input_data, const T quant_scale, + const T quant_zero_point, const int64_t quant_min, const int64_t quant_max, + T* fake_quantized_data, bool* quantization_mask_data) { + CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + + T values[NumElementsPerThread]; + T fake_quantized_values[NumElementsPerThread]; + bool mask_values[NumElementsPerThread]; + + CUDA_LONG idx = start; + // Load +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (idx < num_elements) { + values[i] = input_data[idx]; + idx += NumThreadsPerBlock; + } + } + + // Compute +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + // Quantize + const auto quantized_value = std::nearbyint(values[i] / quant_scale) + quant_zero_point; + // Clamp and De-Quantize + fake_quantized_values[i] = + (fminf(quant_max, fmaxf(quant_min, quantized_value)) - quant_zero_point) * quant_scale; + // Compute mask + mask_values[i] = (quant_min <= quantized_value && quantized_value <= quant_max); + } + + // Write + idx = start; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (idx < num_elements) { + fake_quantized_data[idx] = fake_quantized_values[i]; + quantization_mask_data[idx] = mask_values[i]; + idx += NumThreadsPerBlock; + } + } +} + +template +void FakeQuantPerTensor(cudaStream_t stream, const int64_t num_elements, const T* input_data, const T quant_scale, + const T quant_zero_point, const int64_t quant_min, const int64_t quant_max, + T* fake_quantized_data, bool* quantization_mask_data) { + int blocksPerGrid = + static_cast(CeilDiv(num_elements, NumThreadsPerBlock * NumElementsPerThread)); + FakeQuantPerTensorImpl<<>>( + num_elements, input_data, quant_scale, quant_zero_point, + quant_min, quant_max, fake_quantized_data, quantization_mask_data); +} + +#define SPECIALIZED_FAKEQUANT_IMPL(T) \ + template void FakeQuantPerTensor(cudaStream_t stream, const int64_t num_elements, \ + const T* input_data, const T quant_scale, \ + const T quant_zero_point, const int64_t quant_min, \ + const int64_t quant_max, T* fake_quantized_data, \ + bool* quantization_mask_data); + +SPECIALIZED_FAKEQUANT_IMPL(float) + +#undef SPECIALIZED_FAKEQUANT_IMPL + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.h b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.h new file mode 100644 index 0000000000..a2f3cf9774 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/quantization/fake_quant_impl.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace cuda { + +template +void FakeQuantPerTensor(cudaStream_t stream, const int64_t num_elements, const T* input_data, const T quant_scale, + const T quant_zero_point, const int64_t quant_min, const int64_t quant_max, + T* fake_quantized_data, bool* quantization_mask_data); + +} // namespace cuda +} // namespace onnxruntime