mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[QAT] FakeQuantGrad and gradient building for FakeQuant (#13825)
This commit is contained in:
parent
6090d8cd6e
commit
8c249cc8f7
13 changed files with 266 additions and 5 deletions
|
|
@ -1878,5 +1878,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetScatterElementsGradient) {
|
|||
return result;
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetFakeQuantGradient) {
|
||||
return {NodeDef(OpDef{"FakeQuantGrad", kMSDomain, 1}, {GO(0), O(1)}, {GI(0)})};
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ DECLARE_GRADIENT_BUILDER(GetPythonOpGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetScatterNDGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetScatterElementsGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetTriluGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetFakeQuantGradient)
|
||||
|
||||
DECLARE_GRADIENT_BUILDER(GetExternalGradient)
|
||||
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("ScatterND", GetScatterNDGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ScatterElements", GetScatterElementsGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Trilu", GetTriluGradient);
|
||||
REGISTER_GRADIENT_BUILDER("FakeQuant", GetFakeQuantGradient);
|
||||
|
||||
REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -4015,6 +4015,29 @@ Return true if all elements are true and false otherwise.
|
|||
updateOutputElemType(ctx, 1, ONNX_NAMESPACE::TensorProto::BOOL);
|
||||
propagateShapeFromInputToOutput(ctx, 0, 1);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(FakeQuantGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(
|
||||
"FakeQuantGrad op that computes the partial derivative of the loss with respect to the input tensor to "
|
||||
"the FakeQuant op.")
|
||||
.Input(0, "dY", "Gradient of loss with respect to the output Y of the FakeQuant op (fake quantized output)", "T")
|
||||
.Input(1, "gradient_mask",
|
||||
"Gradient mask that indicates whether the quantized value is within the quantization range.",
|
||||
"T_BOOL")
|
||||
.Output(0, "dX", "Gradient of loss with respect to the input X (of the FakeQuant node).", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float)"},
|
||||
"Constrain the gradient input and output types to float tensors.")
|
||||
.TypeConstraint(
|
||||
"T_BOOL",
|
||||
{"tensor(bool)"},
|
||||
"Constrain the gradient mask input to bool tensors.")
|
||||
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
|
||||
propagateShapeAndTypeFromFirstInput(ctx);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ void CompareFakeQuantKernels(const std::vector<int64_t>& tensor_dim,
|
|||
std::vector<float> scale = random.Uniform<float>(std::vector<int64_t>({1}), 0.04f, 0.1f);
|
||||
test.AddInput<float>("scale", {1}, scale);
|
||||
std::vector<float> zero_point = random.Uniform<float>(std::vector<int64_t>({1}), 0.f, 255.0f);
|
||||
test.AddInput<float>("zero_scale", {1}, std::vector<float>({std::nearbyint(zero_point.front())}));
|
||||
test.AddInput<float>("zero_point", {1}, std::vector<float>({std::nearbyint(zero_point.front())}));
|
||||
|
||||
// Create output tensors
|
||||
std::vector<float> fake_quantized_data = FillZeros<float>(tensor_dim);
|
||||
|
|
@ -36,6 +36,7 @@ void CompareFakeQuantKernels(const std::vector<int64_t>& tensor_dim,
|
|||
|
||||
test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
|
|
@ -54,7 +55,7 @@ TEST(FakeQuantTest, FakeQuantComputation) {
|
|||
|
||||
test.AddInput<float>("input_tensor", {10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
|
||||
test.AddInput<float>("scale", {1}, {0.075f});
|
||||
test.AddInput<float>("zero_scale", {1}, {128.0f});
|
||||
test.AddInput<float>("zero_point", {1}, {128.0f});
|
||||
// quantized values = nearby_int(value / scale + zero_point)
|
||||
// = {13.33+128, 26.66+128, 40.00+128, 53.33+128, 66.66+128, ...}
|
||||
// = {141.33, 154.66, 168.00, 171.33, 184.66, ...}
|
||||
|
|
@ -79,5 +80,107 @@ TEST(CudaKernelTest, FakeQuant) {
|
|||
}
|
||||
#endif
|
||||
|
||||
class FakeQuantGradParameterizedTest : public ::testing::TestWithParam<TensorShape> {
|
||||
};
|
||||
|
||||
TEST_P(FakeQuantGradParameterizedTest, FakeQuantGradComputation) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> providers;
|
||||
providers.emplace_back(DefaultCpuExecutionProvider());
|
||||
#ifdef USE_CUDA
|
||||
providers.emplace_back(DefaultCudaExecutionProvider());
|
||||
#endif
|
||||
OpTester test("FakeQuantGrad", 1, kMSDomain, true);
|
||||
|
||||
auto x_shape = GetParam();
|
||||
RandomValueGenerator random{};
|
||||
|
||||
// Randomly generate the gradient w.r.t. the output Y.
|
||||
std::vector<float> dY_data = random.Uniform<float>(x_shape.GetDims(), -1000.0f, 1000.0f);
|
||||
|
||||
// Randomly generate the mask data
|
||||
std::unique_ptr<bool[]> mask_data = [&random, &x_shape]() {
|
||||
auto data_int = random.Uniform<int>(x_shape.GetDims(), 0, 2);
|
||||
auto data_bool = std::make_unique<bool[]>(detail::SizeFromDims(x_shape.GetDims()));
|
||||
for (size_t i = 0; i < data_int.size(); ++i) {
|
||||
data_bool.get()[i] = data_int[i] == 0;
|
||||
}
|
||||
return data_bool;
|
||||
}();
|
||||
|
||||
// Calculate the gradient w.r.t. the input X.
|
||||
std::vector<float> dX_data = [](const auto& dY_data, const auto& mask_data) {
|
||||
auto dX_data = dY_data;
|
||||
for (size_t i = 0; i < dY_data.size(); ++i) {
|
||||
if (!mask_data.get()[i])
|
||||
dX_data[i] = 0.0f;
|
||||
}
|
||||
return dX_data;
|
||||
}(dY_data, mask_data);
|
||||
|
||||
test.AddInput<float>("dY", x_shape.AsShapeVector(), dY_data);
|
||||
test.AddInput<bool>("gradient_mask", x_shape.AsShapeVector(), mask_data.get(), x_shape.Size());
|
||||
test.AddOutput<float>("dX", x_shape.AsShapeVector(), dX_data);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
FakeQuantGradTests,
|
||||
FakeQuantGradParameterizedTest,
|
||||
::testing::Values(
|
||||
TensorShape({4}),
|
||||
TensorShape({8, 4}),
|
||||
TensorShape({4, 7, 13}),
|
||||
TensorShape({4, 8, 16, 32}),
|
||||
TensorShape({4, 16, 32, 4096})));
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
||||
class FakeQuantGradKernelComparisonParameterizedTest : public ::testing::TestWithParam<std::vector<int64_t>> {
|
||||
};
|
||||
|
||||
TEST_P(FakeQuantGradKernelComparisonParameterizedTest, FakeQuantGradKernels) {
|
||||
CompareOpTester test("FakeQuantGrad", 1, onnxruntime::kMSDomain);
|
||||
auto tensor_dim = GetParam();
|
||||
|
||||
// Randomly generate the gradient w.r.t. the output Y.
|
||||
RandomValueGenerator random{};
|
||||
std::vector<float> dY_data = random.Uniform<float>(tensor_dim, -1000.0f, 1000.0f);
|
||||
|
||||
// Randomly generate the mask data
|
||||
std::unique_ptr<bool[]> mask_data = [&random, &tensor_dim]() {
|
||||
auto data_int = random.Uniform<int>(tensor_dim, 0, 2);
|
||||
auto data_bool = std::make_unique<bool[]>(detail::SizeFromDims(tensor_dim));
|
||||
for (size_t i = 0; i < data_int.size(); ++i) {
|
||||
data_bool.get()[i] = data_int[i] == 0;
|
||||
}
|
||||
return data_bool;
|
||||
}();
|
||||
|
||||
// Initialize the gradient w.r.t. the input X with 0s.
|
||||
std::vector<float> dX_data = FillZeros<float>(tensor_dim);
|
||||
|
||||
test.AddInput<float>("dY", tensor_dim, dY_data);
|
||||
test.AddInput<bool>("gradient_mask", tensor_dim, mask_data.get(), detail::SizeFromDims(tensor_dim));
|
||||
test.AddOutput<float>("dX", tensor_dim, dX_data);
|
||||
|
||||
// Compare the outputs from the two kernels
|
||||
const double per_sample_tolerance = 2e-4;
|
||||
const double relative_per_sample_tolerance = 2e-4;
|
||||
test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
FakeQuantGradTests,
|
||||
FakeQuantGradKernelComparisonParameterizedTest,
|
||||
::testing::Values(
|
||||
std::vector<int64_t>({4}),
|
||||
std::vector<int64_t>({8, 4}),
|
||||
std::vector<int64_t>({4, 7, 13}),
|
||||
std::vector<int64_t>({4, 8, 16, 32}),
|
||||
std::vector<int64_t>({4, 16, 32, 4096})));
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -89,6 +89,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InplaceClipGradNorm);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FakeQuant);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
|
||||
|
||||
// the kernels within the following ifdef are not included in a build with
|
||||
// --enable_training_ops but without --enable_training
|
||||
|
|
@ -206,6 +207,8 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCpuExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCpuExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
|
||||
|
||||
// the kernels within the following ifdef are not included in a build with
|
||||
// --enable_training_ops but without --enable_training
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "orttraining/training_ops/cpu/quantization/fake_quant.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -35,6 +36,13 @@ void FakeQuantPerTensor(OpKernelContext* ctx, const int64_t num_elements, const
|
|||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void FakeQuantGradImpl(const Tensor& dY, const Tensor& gradient_mask, Tensor& dX) {
|
||||
// If gradient_mask is true (i.e. quantization was in range), return dY, else return 0
|
||||
MakeEigenArrayMap<T>(dX) = MakeEigenArrayMap<T>(dY) * MakeEigenArrayMap<bool>(gradient_mask).template cast<T>();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#define REGISTER_FAKEQUANT_KERNEL_TYPED(T) \
|
||||
|
|
@ -69,7 +77,7 @@ Status FakeQuant<T>::Compute(OpKernelContext* ctx) const {
|
|||
T* fake_quantized_data = fake_quantized_tensor->MutableData<T>();
|
||||
bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData<bool>();
|
||||
|
||||
// Copmute
|
||||
// Compute
|
||||
// TODO(bmeswani): Add support for FakeQuantPerChannel
|
||||
FakeQuantPerTensor(ctx, input_tensor->Shape().Size(), input_data, *quant_scale, *quant_zero_point, quant_min_,
|
||||
quant_max_, fake_quantized_data, quantization_mask_data);
|
||||
|
|
@ -77,5 +85,33 @@ Status FakeQuant<T>::Compute(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#define REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
FakeQuantGrad, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCpuExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
FakeQuantGrad<T>);
|
||||
|
||||
REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(float)
|
||||
|
||||
template <typename T>
|
||||
Status FakeQuantGrad<T>::Compute(OpKernelContext* ctx) const {
|
||||
// Prepare the gradient wrt the output and gradient mask input
|
||||
const auto* dY = ctx->Input<Tensor>(0);
|
||||
const auto* gradient_mask = ctx->Input<Tensor>(1);
|
||||
|
||||
// Prepare the output
|
||||
auto* dX = ctx->Output(0, dY->Shape());
|
||||
|
||||
// Compute
|
||||
FakeQuantGradImpl<T>(*dY, *gradient_mask, *dX);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -24,5 +24,14 @@ class FakeQuant final : public OpKernel {
|
|||
int64_t quant_max_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class FakeQuantGrad final : public OpKernel {
|
||||
public:
|
||||
FakeQuantGrad(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -193,6 +193,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, InplaceClipGradNorm);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
|
||||
|
||||
// the kernels within the following ifdef are not included in a build with
|
||||
// --enable_training_ops but without --enable_training
|
||||
|
|
@ -426,6 +427,8 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
|
||||
|
||||
// the kernels within the following ifdef are not included in a build with
|
||||
// --enable_training_ops but without --enable_training
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ Status FakeQuant<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
const CudaT* quant_zero_point = reinterpret_cast<const CudaT*>(zero_point->Data<T>());
|
||||
|
||||
// Prepare the output, mask for gradient computation
|
||||
auto& fake_quantized_tensor = *ctx->Output(0, input_tensor->Shape());
|
||||
CudaT* fake_quantized_data = reinterpret_cast<CudaT*>(fake_quantized_tensor.MutableData<T>());
|
||||
auto* fake_quantized_tensor = ctx->Output(0, input_tensor->Shape());
|
||||
CudaT* fake_quantized_data = reinterpret_cast<CudaT*>(fake_quantized_tensor->MutableData<T>());
|
||||
bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData<bool>();
|
||||
|
||||
// Fake quantize the input tensor
|
||||
|
|
@ -54,5 +54,38 @@ Status FakeQuant<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#define REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
FakeQuantGrad, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
FakeQuantGrad<T>);
|
||||
|
||||
REGISTER_FAKEQUANTGRAD_KERNEL_TYPED(float)
|
||||
|
||||
template <typename T>
|
||||
Status FakeQuantGrad<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
// Prepare the gradient wrt the output and gradient mask input
|
||||
const auto* dY = ctx->Input<Tensor>(0);
|
||||
const CudaT* dY_data = reinterpret_cast<const CudaT*>(dY->Data<T>());
|
||||
const auto* gradient_mask = ctx->Input<Tensor>(1);
|
||||
const bool* gradient_mask_data = gradient_mask->Data<bool>();
|
||||
|
||||
// Prepare the output
|
||||
auto* dX = ctx->Output(0, dY->Shape());
|
||||
CudaT* dX_data = reinterpret_cast<CudaT*>(dX->MutableData<T>());
|
||||
|
||||
// Compute
|
||||
FakeQuantGradImpl(Stream(), dY->Shape().Size(), dY_data, gradient_mask_data, dX_data);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -25,5 +25,14 @@ class FakeQuant final : public CudaKernel {
|
|||
int64_t quant_max_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class FakeQuantGrad final : public CudaKernel {
|
||||
public:
|
||||
FakeQuantGrad(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "orttraining/training_ops/cuda/quantization/fake_quant_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cu_inc/elementwise_impl.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
@ -78,5 +79,36 @@ SPECIALIZED_FAKEQUANT_IMPL(float)
|
|||
|
||||
#undef SPECIALIZED_FAKEQUANT_IMPL
|
||||
|
||||
template <typename T>
|
||||
struct FakeQuantGradFunctor {
|
||||
FakeQuantGradFunctor(const T* dY_data, const bool* gradient_mask_data)
|
||||
: dY_data_(dY_data),
|
||||
gradient_mask_data_(gradient_mask_data) {}
|
||||
|
||||
__device__ __inline__ T operator()(CUDA_LONG idx) const {
|
||||
// If gradient_mask is true (i.e. quantization was in range), return dY, else return 0
|
||||
return gradient_mask_data_[idx] ? dY_data_[idx] : static_cast<T>(0);
|
||||
}
|
||||
|
||||
const T* dY_data_;
|
||||
const bool* gradient_mask_data_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void FakeQuantGradImpl(cudaStream_t stream, const int64_t num_elements, const T* dY_data,
|
||||
const bool* gradient_mask_data, T* dX_data) {
|
||||
FakeQuantGradFunctor<T> fake_quant_grad_functor(dY_data, gradient_mask_data);
|
||||
LaunchElementwiseKernel<T, decltype(fake_quant_grad_functor)>(
|
||||
stream, dX_data, fake_quant_grad_functor, num_elements);
|
||||
}
|
||||
|
||||
#define SPECIALIZED_FAKEQUANTGRAD_IMPL(T) \
|
||||
template void FakeQuantGradImpl<T>(cudaStream_t stream, const int64_t num_elements, \
|
||||
const T* dY_data, const bool* gradient_mask_data, T* dX_data);
|
||||
|
||||
SPECIALIZED_FAKEQUANTGRAD_IMPL(float)
|
||||
|
||||
#undef SPECIALIZED_FAKEQUANTGRAD_IMPL
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -13,5 +13,9 @@ void FakeQuantPerTensor(cudaStream_t stream, const int64_t num_elements, const T
|
|||
const T quant_zero_point, const int64_t quant_min, const int64_t quant_max,
|
||||
T* fake_quantized_data, bool* quantization_mask_data);
|
||||
|
||||
template <typename T>
|
||||
void FakeQuantGradImpl(cudaStream_t stream, const int64_t num_elements, const T* dY_data,
|
||||
const bool* gradient_mask_data, T* dX_data);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue