[QAT] FakeQuantGrad and gradient building for FakeQuant (#13825)

This commit is contained in:
Baiju Meswani 2022-12-14 11:54:02 -08:00 committed by GitHub
parent 6090d8cd6e
commit 8c249cc8f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 266 additions and 5 deletions

View file

@ -1878,5 +1878,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetScatterElementsGradient) {
return result;
}
IMPLEMENT_GRADIENT_BUILDER(GetFakeQuantGradient) {
return {NodeDef(OpDef{"FakeQuantGrad", kMSDomain, 1}, {GO(0), O(1)}, {GI(0)})};
}
} // namespace training
} // namespace onnxruntime

View file

@ -81,6 +81,7 @@ DECLARE_GRADIENT_BUILDER(GetPythonOpGradient)
DECLARE_GRADIENT_BUILDER(GetScatterNDGradient)
DECLARE_GRADIENT_BUILDER(GetScatterElementsGradient)
DECLARE_GRADIENT_BUILDER(GetTriluGradient)
DECLARE_GRADIENT_BUILDER(GetFakeQuantGradient)
DECLARE_GRADIENT_BUILDER(GetExternalGradient)

View file

@ -113,6 +113,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("ScatterND", GetScatterNDGradient);
REGISTER_GRADIENT_BUILDER("ScatterElements", GetScatterElementsGradient);
REGISTER_GRADIENT_BUILDER("Trilu", GetTriluGradient);
REGISTER_GRADIENT_BUILDER("FakeQuant", GetFakeQuantGradient);
REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};

View file

@ -4015,6 +4015,29 @@ Return true if all elements are true and false otherwise.
updateOutputElemType(ctx, 1, ONNX_NAMESPACE::TensorProto::BOOL);
propagateShapeFromInputToOutput(ctx, 0, 1);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(FakeQuantGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(
"FakeQuantGrad op that computes the partial derivative of the loss with respect to the input tensor to "
"the FakeQuant op.")
.Input(0, "dY", "Gradient of loss with respect to the output Y of the FakeQuant op (fake quantized output)", "T")
.Input(1, "gradient_mask",
"Gradient mask that indicates whether the quantized value is within the quantization range.",
"T_BOOL")
.Output(0, "dX", "Gradient of loss with respect to the input X (of the FakeQuant node).", "T")
.TypeConstraint(
"T",
{"tensor(float)"},
"Constrain the gradient input and output types to float tensors.")
.TypeConstraint(
"T_BOOL",
{"tensor(bool)"},
"Constrain the gradient mask input to bool tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateShapeAndTypeFromFirstInput(ctx);
});
}
} // namespace training

View file

@ -26,7 +26,7 @@ void CompareFakeQuantKernels(const std::vector<int64_t>& tensor_dim,
std::vector<float> scale = random.Uniform<float>(std::vector<int64_t>({1}), 0.04f, 0.1f);
test.AddInput<float>("scale", {1}, scale);
std::vector<float> zero_point = random.Uniform<float>(std::vector<int64_t>({1}), 0.f, 255.0f);
test.AddInput<float>("zero_scale", {1}, std::vector<float>({std::nearbyint(zero_point.front())}));
test.AddInput<float>("zero_point", {1}, std::vector<float>({std::nearbyint(zero_point.front())}));
// Create output tensors
std::vector<float> fake_quantized_data = FillZeros<float>(tensor_dim);
@ -36,6 +36,7 @@ void CompareFakeQuantKernels(const std::vector<int64_t>& tensor_dim,
test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
}
#endif
} // namespace
@ -54,7 +55,7 @@ TEST(FakeQuantTest, FakeQuantComputation) {
test.AddInput<float>("input_tensor", {10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
test.AddInput<float>("scale", {1}, {0.075f});
test.AddInput<float>("zero_scale", {1}, {128.0f});
test.AddInput<float>("zero_point", {1}, {128.0f});
// quantized values = nearby_int(value / scale + zero_point)
// = {13.33+128, 26.66+128, 40.00+128, 53.33+128, 66.66+128, ...}
// = {141.33, 154.66, 168.00, 171.33, 184.66, ...}
@ -79,5 +80,107 @@ TEST(CudaKernelTest, FakeQuant) {
}
#endif
class FakeQuantGradParameterizedTest : public ::testing::TestWithParam<TensorShape> {
};
TEST_P(FakeQuantGradParameterizedTest, FakeQuantGradComputation) {
std::vector<std::unique_ptr<IExecutionProvider>> providers;
providers.emplace_back(DefaultCpuExecutionProvider());
#ifdef USE_CUDA
providers.emplace_back(DefaultCudaExecutionProvider());
#endif
OpTester test("FakeQuantGrad", 1, kMSDomain, true);
auto x_shape = GetParam();
RandomValueGenerator random{};
// Randomly generate the gradient w.r.t. the output Y.
std::vector<float> dY_data = random.Uniform<float>(x_shape.GetDims(), -1000.0f, 1000.0f);
// Randomly generate the mask data
std::unique_ptr<bool[]> mask_data = [&random, &x_shape]() {
auto data_int = random.Uniform<int>(x_shape.GetDims(), 0, 2);
auto data_bool = std::make_unique<bool[]>(detail::SizeFromDims(x_shape.GetDims()));
for (size_t i = 0; i < data_int.size(); ++i) {
data_bool.get()[i] = data_int[i] == 0;
}
return data_bool;
}();
// Calculate the gradient w.r.t. the input X.
std::vector<float> dX_data = [](const auto& dY_data, const auto& mask_data) {
auto dX_data = dY_data;
for (size_t i = 0; i < dY_data.size(); ++i) {
if (!mask_data.get()[i])
dX_data[i] = 0.0f;
}
return dX_data;
}(dY_data, mask_data);
test.AddInput<float>("dY", x_shape.AsShapeVector(), dY_data);
test.AddInput<bool>("gradient_mask", x_shape.AsShapeVector(), mask_data.get(), x_shape.Size());
test.AddOutput<float>("dX", x_shape.AsShapeVector(), dX_data);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}
INSTANTIATE_TEST_SUITE_P(
FakeQuantGradTests,
FakeQuantGradParameterizedTest,
::testing::Values(
TensorShape({4}),
TensorShape({8, 4}),
TensorShape({4, 7, 13}),
TensorShape({4, 8, 16, 32}),
TensorShape({4, 16, 32, 4096})));
#ifdef USE_CUDA
class FakeQuantGradKernelComparisonParameterizedTest : public ::testing::TestWithParam<std::vector<int64_t>> {
};
TEST_P(FakeQuantGradKernelComparisonParameterizedTest, FakeQuantGradKernels) {
CompareOpTester test("FakeQuantGrad", 1, onnxruntime::kMSDomain);
auto tensor_dim = GetParam();
// Randomly generate the gradient w.r.t. the output Y.
RandomValueGenerator random{};
std::vector<float> dY_data = random.Uniform<float>(tensor_dim, -1000.0f, 1000.0f);
// Randomly generate the mask data
std::unique_ptr<bool[]> mask_data = [&random, &tensor_dim]() {
auto data_int = random.Uniform<int>(tensor_dim, 0, 2);
auto data_bool = std::make_unique<bool[]>(detail::SizeFromDims(tensor_dim));
for (size_t i = 0; i < data_int.size(); ++i) {
data_bool.get()[i] = data_int[i] == 0;
}
return data_bool;
}();
// Initialize the gradient w.r.t. the input X with 0s.
std::vector<float> dX_data = FillZeros<float>(tensor_dim);
test.AddInput<float>("dY", tensor_dim, dY_data);
test.AddInput<bool>("gradient_mask", tensor_dim, mask_data.get(), detail::SizeFromDims(tensor_dim));
test.AddOutput<float>("dX", tensor_dim, dX_data);
// Compare the outputs from the two kernels
const double per_sample_tolerance = 2e-4;
const double relative_per_sample_tolerance = 2e-4;
test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
}
INSTANTIATE_TEST_SUITE_P(
FakeQuantGradTests,
FakeQuantGradKernelComparisonParameterizedTest,
::testing::Values(
std::vector<int64_t>({4}),
std::vector<int64_t>({8, 4}),
std::vector<int64_t>({4, 7, 13}),
std::vector<int64_t>({4, 8, 16, 32}),
std::vector<int64_t>({4, 16, 32, 4096})));
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -89,6 +89,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InplaceClipGradNorm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FakeQuant);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
// the kernels within the following ifdef are not included in a build with
// --enable_training_ops but without --enable_training
@ -206,6 +207,8 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
kCpuExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
kCpuExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
// the kernels within the following ifdef are not included in a build with
// --enable_training_ops but without --enable_training

View file

@ -3,6 +3,7 @@
#include "orttraining/training_ops/cpu/quantization/fake_quant.h"
#include "core/providers/common.h"
#include "core/providers/cpu/math/element_wise_ops.h"
namespace onnxruntime {
namespace contrib {
@ -35,6 +36,13 @@ void FakeQuantPerTensor(OpKernelContext* ctx, const int64_t num_elements, const
}
});
}
template <typename T>
void FakeQuantGradImpl(const Tensor& dY, const Tensor& gradient_mask, Tensor& dX) {
// If gradient_mask is true (i.e. quantization was in range), return dY, else return 0
MakeEigenArrayMap<T>(dX) = MakeEigenArrayMap<T>(dY) * MakeEigenArrayMap<bool>(gradient_mask).template cast<T>();
}
} // namespace
#define REGISTER_FAKEQUANT_KERNEL_TYPED(T) \
@ -69,7 +77,7 @@ Status FakeQuant<T>::Compute(OpKernelContext* ctx) const {
T* fake_quantized_data = fake_quantized_tensor->MutableData<T>();
bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData<bool>();
// Copmute
// Compute
// TODO(bmeswani): Add support for FakeQuantPerChannel
FakeQuantPerTensor(ctx, input_tensor->Shape().Size(), input_data, *quant_scale, *quant_zero_point, quant_min_,
quant_max_, fake_quantized_data, quantization_mask_data);
@ -77,5 +85,33 @@ Status FakeQuant<T>::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
#define REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
FakeQuantGrad, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
FakeQuantGrad<T>);
REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(float)
template <typename T>
Status FakeQuantGrad<T>::Compute(OpKernelContext* ctx) const {
// Prepare the gradient wrt the output and gradient mask input
const auto* dY = ctx->Input<Tensor>(0);
const auto* gradient_mask = ctx->Input<Tensor>(1);
// Prepare the output
auto* dX = ctx->Output(0, dY->Shape());
// Compute
FakeQuantGradImpl<T>(*dY, *gradient_mask, *dX);
return Status::OK();
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -24,5 +24,14 @@ class FakeQuant final : public OpKernel {
int64_t quant_max_;
};
template <typename T>
class FakeQuantGrad final : public OpKernel {
public:
FakeQuantGrad(const OpKernelInfo& info) : OpKernel(info) {
}
Status Compute(OpKernelContext* context) const override;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -193,6 +193,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, InplaceClipGradNorm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
// the kernels within the following ifdef are not included in a build with
// --enable_training_ops but without --enable_training
@ -426,6 +427,8 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
// the kernels within the following ifdef are not included in a build with
// --enable_training_ops but without --enable_training

View file

@ -42,8 +42,8 @@ Status FakeQuant<T>::ComputeInternal(OpKernelContext* ctx) const {
const CudaT* quant_zero_point = reinterpret_cast<const CudaT*>(zero_point->Data<T>());
// Prepare the output, mask for gradient computation
auto& fake_quantized_tensor = *ctx->Output(0, input_tensor->Shape());
CudaT* fake_quantized_data = reinterpret_cast<CudaT*>(fake_quantized_tensor.MutableData<T>());
auto* fake_quantized_tensor = ctx->Output(0, input_tensor->Shape());
CudaT* fake_quantized_data = reinterpret_cast<CudaT*>(fake_quantized_tensor->MutableData<T>());
bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData<bool>();
// Fake quantize the input tensor
@ -54,5 +54,38 @@ Status FakeQuant<T>::ComputeInternal(OpKernelContext* ctx) const {
return Status::OK();
}
#define REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
FakeQuantGrad, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
FakeQuantGrad<T>);
REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(float)
template <typename T>
Status FakeQuantGrad<T>::ComputeInternal(OpKernelContext* ctx) const {
typedef typename ToCudaType<T>::MappedType CudaT;
// Prepare the gradient wrt the output and gradient mask input
const auto* dY = ctx->Input<Tensor>(0);
const CudaT* dY_data = reinterpret_cast<const CudaT*>(dY->Data<T>());
const auto* gradient_mask = ctx->Input<Tensor>(1);
const bool* gradient_mask_data = gradient_mask->Data<bool>();
// Prepare the output
auto* dX = ctx->Output(0, dY->Shape());
CudaT* dX_data = reinterpret_cast<CudaT*>(dX->MutableData<T>());
// Compute
FakeQuantGradImpl(Stream(), dY->Shape().Size(), dY_data, gradient_mask_data, dX_data);
return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -25,5 +25,14 @@ class FakeQuant final : public CudaKernel {
int64_t quant_max_;
};
template <typename T>
class FakeQuantGrad final : public CudaKernel {
public:
FakeQuantGrad(const OpKernelInfo& info) : CudaKernel(info) {
}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -3,6 +3,7 @@
#include "orttraining/training_ops/cuda/quantization/fake_quant_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cu_inc/elementwise_impl.cuh"
namespace onnxruntime {
namespace cuda {
@ -78,5 +79,36 @@ SPECIALIZED_FAKEQUANT_IMPL(float)
#undef SPECIALIZED_FAKEQUANT_IMPL
template <typename T>
struct FakeQuantGradFunctor {
FakeQuantGradFunctor(const T* dY_data, const bool* gradient_mask_data)
: dY_data_(dY_data),
gradient_mask_data_(gradient_mask_data) {}
__device__ __inline__ T operator()(CUDA_LONG idx) const {
// If gradient_mask is true (i.e. quantization was in range), return dY, else return 0
return gradient_mask_data_[idx] ? dY_data_[idx] : static_cast<T>(0);
}
const T* dY_data_;
const bool* gradient_mask_data_;
};
template <typename T>
void FakeQuantGradImpl(cudaStream_t stream, const int64_t num_elements, const T* dY_data,
const bool* gradient_mask_data, T* dX_data) {
FakeQuantGradFunctor<T> fake_quant_grad_functor(dY_data, gradient_mask_data);
LaunchElementwiseKernel<T, decltype(fake_quant_grad_functor)>(
stream, dX_data, fake_quant_grad_functor, num_elements);
}
#define SPECIALIZED_FAKEQUANTGRAD_IMPL(T) \
template void FakeQuantGradImpl<T>(cudaStream_t stream, const int64_t num_elements, \
const T* dY_data, const bool* gradient_mask_data, T* dX_data);
SPECIALIZED_FAKEQUANTGRAD_IMPL(float)
#undef SPECIALIZED_FAKEQUANTGRAD_IMPL
} // namespace cuda
} // namespace onnxruntime

View file

@ -13,5 +13,9 @@ void FakeQuantPerTensor(cudaStream_t stream, const int64_t num_elements, const T
const T quant_zero_point, const int64_t quant_min, const int64_t quant_max,
T* fake_quantized_data, bool* quantization_mask_data);
template <typename T>
void FakeQuantGradImpl(cudaStream_t stream, const int64_t num_elements, const T* dY_data,
const bool* gradient_mask_data, T* dX_data);
} // namespace cuda
} // namespace onnxruntime