mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
[QAT] Introduce FakeQuant op (#13649)
This commit is contained in:
parent
49c3768985
commit
2c29938846
10 changed files with 435 additions and 0 deletions
|
|
@ -3968,6 +3968,53 @@ Return true if all elements are true and false otherwise.
|
|||
.SetContextDependentFunctionBodyBuilder(BuildNllLossInternalFunction<13>)
|
||||
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); })
|
||||
.SetDoc(R"DOC(NegativeLogLikelihoodLossInternal)DOC");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(FakeQuant)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(
|
||||
"FakeQuant operator that fuses quantization->dequantization pattern into a single node. "
|
||||
"FakeQuant takes in a non quantized tensor as input and generates a non quantized tensor as output. "
|
||||
"But internally, it will perform Quantization->Dequantization operation that simulates the effects of "
|
||||
"quantization within the model. Loss in numerical precision introduced by model quantization is "
|
||||
"corrected by adjusting the model weights through the FakeQuant op.")
|
||||
.Input(0, "input", "Tensor to be fake quantized.", "T")
|
||||
.Input(1, "scale",
|
||||
"Quantization scale. It must be a scalar, which implies per-tensor quantization. "
|
||||
"The scalar value must be greater than 0.",
|
||||
"T")
|
||||
.Input(2, "zero_point",
|
||||
"Quantization zero point as non quantized type. It must be a scalar, which implies per-tensor "
|
||||
"quantization.",
|
||||
"T")
|
||||
.Output(0, "output", "Input tensor after it has been fake quantized. It has the same shape as the input.", "T")
|
||||
.Output(1, "mask",
|
||||
"Mask where values indicate if the quantized value was in qmin, qmax range. "
|
||||
"Needed for gradient computation. It has the same shape as the input.",
|
||||
"T_BOOL")
|
||||
.Attr(
|
||||
"quant_min",
|
||||
"Minimum quantization value.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Attr(
|
||||
"quant_max",
|
||||
"Maximum quantization value.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(255))
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float)"},
|
||||
"Constrain the input tensor type to float tensors.")
|
||||
.TypeConstraint(
|
||||
"T_BOOL",
|
||||
{"tensor(bool)"},
|
||||
"Constrain the gradient quantization mask type to boolean tensors.")
|
||||
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
|
||||
propagateShapeAndTypeFromFirstInput(ctx);
|
||||
updateOutputElemType(ctx, 1, ONNX_NAMESPACE::TensorProto::BOOL);
|
||||
propagateShapeFromInputToOutput(ctx, 0, 1);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -0,0 +1,83 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "test/providers/compare_provider_test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
namespace {
|
||||
|
||||
#ifdef USE_CUDA
|
||||
void CompareFakeQuantKernels(const std::vector<int64_t>& tensor_dim,
|
||||
double per_sample_tolerance = 2e-4,
|
||||
double relative_per_sample_tolerance = 2e-4) {
|
||||
CompareOpTester test("FakeQuant", 1, onnxruntime::kMSDomain);
|
||||
|
||||
test.AddAttribute<int64_t>("quant_min", 0);
|
||||
test.AddAttribute<int64_t>("quant_max", 7500);
|
||||
|
||||
// Create rand inputs for the input tensor, scale and zero point
|
||||
RandomValueGenerator random{};
|
||||
std::vector<float> input_data = random.Uniform<float>(tensor_dim, -1000.0f, 1000.0f);
|
||||
test.AddInput<float>("input_tensor", tensor_dim, input_data);
|
||||
std::vector<float> scale = random.Uniform<float>(std::vector<int64_t>({1}), 0.04f, 0.1f);
|
||||
test.AddInput<float>("scale", {1}, scale);
|
||||
std::vector<float> zero_point = random.Uniform<float>(std::vector<int64_t>({1}), 0.f, 255.0f);
|
||||
test.AddInput<float>("zero_scale", {1}, std::vector<float>({std::nearbyint(zero_point.front())}));
|
||||
|
||||
// Create output tensors
|
||||
std::vector<float> fake_quantized_data = FillZeros<float>(tensor_dim);
|
||||
test.AddOutput<float>("fake_quantized_tensor", tensor_dim, fake_quantized_data);
|
||||
std::unique_ptr<bool[]> quantization_mask = std::make_unique<bool[]>(detail::SizeFromDims(tensor_dim));
|
||||
test.AddOutput<bool>("quantization_mask", tensor_dim, quantization_mask.get(), detail::SizeFromDims(tensor_dim));
|
||||
|
||||
test.CompareWithCPU(kCudaExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
|
||||
// Checks FakeQuant against hand-computed values on the CPU provider (and on the CUDA
// provider as well when built with USE_CUDA).
TEST(FakeQuantTest, FakeQuantComputation) {
  std::vector<std::unique_ptr<IExecutionProvider>> providers;
  providers.emplace_back(DefaultCpuExecutionProvider());
#ifdef USE_CUDA
  providers.emplace_back(DefaultCudaExecutionProvider());
#endif

  OpTester test("FakeQuant", 1, onnxruntime::kMSDomain);

  test.AddAttribute<int64_t>("quant_min", 0);
  test.AddAttribute<int64_t>("quant_max", 255);

  test.AddInput<float>("input_tensor", {10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
  test.AddInput<float>("scale", {1}, {0.075f});
  test.AddInput<float>("zero_scale", {1}, {128.0f});

  // Expected values, with scale = 0.075, zero_point = 128, quant range = [0, 255]:
  //
  // value / scale         = {13.33, 26.67, 40.00, 53.33, 66.67, 80.00, 93.33, 106.67, 120.00, 133.33}
  // nearby_int(...)       = {13, 27, 40, 53, 67, 80, 93, 107, 120, 133}
  // quantized (+ zp)      = {141, 155, 168, 181, 195, 208, 221, 235, 248, 261}
  // clamped to [0, 255]   = {141, 155, 168, 181, 195, 208, 221, 235, 248, 255}
  // minus zero_point      = {13, 27, 40, 53, 67, 80, 93, 107, 120, 127}
  // de-quantized (* 0.075) = {0.975, 2.025, 3.0, 3.975, 5.025, 6.0, 6.975, 8.025, 9.0, 9.525}
  //
  // Only the last element (261) falls outside [0, 255], so only its mask entry is false.

  test.AddOutput<float>(
      "fake_quantized_tensor", {10}, {0.975f, 2.025f, 3.0f, 3.975f, 5.025f, 6.0f, 6.975f, 8.025f, 9.0f, 9.525f});
  test.AddOutput<bool>("quantization_mask", {10}, {true, true, true, true, true, true, true, true, true, false});

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// Compares the CUDA FakeQuant kernel against the CPU kernel for several tensor shapes.
TEST(CudaKernelTest, FakeQuant) {
  std::vector<std::vector<int64_t>> test_dims{{4}, {16, 2}, {8, 2, 128, 128}};
  for (const auto& test_dim : test_dims) {
    // The second parameter of CompareFakeQuantKernels is `double per_sample_tolerance`.
    // This call used to pass `false`, which silently converted to 0.0; spell the value
    // out so the intent is visible. The relative tolerance keeps its default.
    CompareFakeQuantKernels(test_dim, /*per_sample_tolerance*/ 0.0);
  }
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -88,6 +88,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InplaceClipGradNorm);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FakeQuant);
|
||||
|
||||
// the kernels within the following ifdef are not included in a build with
|
||||
// --enable_training_ops but without --enable_training
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
|
@ -202,6 +204,9 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InplaceClipGradNorm)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCpuExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
|
||||
|
||||
// the kernels within the following ifdef are not included in a build with
|
||||
// --enable_training_ops but without --enable_training
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cpu/quantization/fake_quant.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void FakeQuantPerTensor(OpKernelContext* ctx, const int64_t num_elements, const T* input_data, T quant_scale,
|
||||
T quant_zero_point, int64_t quant_min, int64_t quant_max, T* fake_quantized_data,
|
||||
bool* quantization_mask_data) {
|
||||
const auto zero_point_int = static_cast<int64_t>(quant_zero_point);
|
||||
auto* tp = ctx->GetOperatorThreadPool();
|
||||
concurrency::ThreadPool::TryParallelFor(
|
||||
tp, num_elements, /* 1 Read, 2 Writes, 4 Computes */ TensorOpCost{1.0, 2.0, 4.0},
|
||||
[quant_scale, zero_point_int, quant_min, quant_max, &input_data, &fake_quantized_data, &quantization_mask_data](
|
||||
std::ptrdiff_t begin, std::ptrdiff_t end) {
|
||||
for (std::ptrdiff_t index = begin; index != end; ++index) {
|
||||
size_t idx = static_cast<size_t>(index);
|
||||
|
||||
// Quantize
|
||||
const auto quantized_value = static_cast<int64_t>(std::nearbyint(input_data[idx] / quant_scale)) +
|
||||
zero_point_int;
|
||||
|
||||
// Clamp and De-Quantize
|
||||
fake_quantized_data[idx] =
|
||||
(std::min(quant_max, std::max(quant_min, quantized_value)) - zero_point_int) * quant_scale;
|
||||
|
||||
// Compute mask needed for gradient computation
|
||||
quantization_mask_data[idx] = (quant_min <= quantized_value && quantized_value <= quant_max);
|
||||
}
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#define REGISTER_FAKEQUANT_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
FakeQuant, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCpuExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
FakeQuant<T>);
|
||||
|
||||
REGISTER_FAKEQUANT_KERNEL_TYPED(float)
|
||||
|
||||
template <typename T>
|
||||
Status FakeQuant<T>::Compute(OpKernelContext* ctx) const {
|
||||
// Prepare the input, scale, zero point
|
||||
const auto* input_tensor = ctx->Input<Tensor>(0);
|
||||
const T* input_data = input_tensor->Data<T>();
|
||||
const auto* scale = ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(IsScalarOr1ElementVector(scale), "Quantization scale must be a scalar or 1D tensor of size 1.");
|
||||
const T* quant_scale = scale->Data<T>();
|
||||
ORT_ENFORCE(*quant_scale != static_cast<T>(0),
|
||||
"Quantization scale cannot be 0. It may result in undefined behavior.");
|
||||
const auto* zero_point = ctx->Input<Tensor>(2);
|
||||
ORT_ENFORCE(IsScalarOr1ElementVector(zero_point), "Quantization zero point must be a scalar or 1D tensor of size 1.");
|
||||
const T* quant_zero_point = zero_point->Data<T>();
|
||||
|
||||
// Prepare the output, mask for gradient computation
|
||||
auto* fake_quantized_tensor = ctx->Output(0, input_tensor->Shape());
|
||||
T* fake_quantized_data = fake_quantized_tensor->MutableData<T>();
|
||||
bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData<bool>();
|
||||
|
||||
// Copmute
|
||||
// TODO(bmeswani): Add support for FakeQuantPerChannel
|
||||
FakeQuantPerTensor(ctx, input_tensor->Shape().Size(), input_data, *quant_scale, *quant_zero_point, quant_min_,
|
||||
quant_max_, fake_quantized_data, quantization_mask_data);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
class FakeQuant final : public OpKernel {
|
||||
public:
|
||||
FakeQuant(const OpKernelInfo& info) : OpKernel(info) {
|
||||
info.GetAttrOrDefault("quant_min", &quant_min_, static_cast<decltype(quant_min_)>(0));
|
||||
info.GetAttrOrDefault("quant_max", &quant_max_, static_cast<decltype(quant_max_)>(255));
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
int64_t quant_min_;
|
||||
int64_t quant_max_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -210,6 +210,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, InplaceClipGradNorm);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant);
|
||||
|
||||
#if defined(ORT_USE_NCCL) || defined(USE_MPI)
|
||||
// P2P communication operators.
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send);
|
||||
|
|
@ -435,6 +437,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, InplaceClipGradNorm)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
|
||||
|
||||
// P2P communication operators.
|
||||
#if defined(ORT_USE_NCCL) || defined(USE_MPI)
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send)>,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "orttraining/training_ops/cuda/quantization/fake_quant.h"
|
||||
#include "orttraining/training_ops/cuda/quantization/fake_quant_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_FAKEQUANT_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
FakeQuant, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 1) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 2) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
FakeQuant<T>);
|
||||
|
||||
REGISTER_FAKEQUANT_KERNEL_TYPED(float)
|
||||
|
||||
template <typename T>
|
||||
Status FakeQuant<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
// Prepare the input, scale, zero point
|
||||
const auto* input_tensor = ctx->Input<Tensor>(0);
|
||||
const CudaT* input_data = reinterpret_cast<const CudaT*>(input_tensor->Data<T>());
|
||||
const auto* scale = ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(IsScalarOr1ElementVector(scale), "Quantization scale must be a scalar or 1D tensor of size 1.");
|
||||
const CudaT* quant_scale = reinterpret_cast<const CudaT*>(scale->Data<T>());
|
||||
ORT_ENFORCE(*quant_scale != static_cast<const CudaT>(0),
|
||||
"Quantization scale cannot be 0. It may result in undefined behavior.");
|
||||
const auto* zero_point = ctx->Input<Tensor>(2);
|
||||
ORT_ENFORCE(IsScalarOr1ElementVector(zero_point), "Quantization zero point must be a scalar or 1D tensor of size 1.");
|
||||
const CudaT* quant_zero_point = reinterpret_cast<const CudaT*>(zero_point->Data<T>());
|
||||
|
||||
// Prepare the output, mask for gradient computation
|
||||
auto& fake_quantized_tensor = *ctx->Output(0, input_tensor->Shape());
|
||||
CudaT* fake_quantized_data = reinterpret_cast<CudaT*>(fake_quantized_tensor.MutableData<T>());
|
||||
bool* quantization_mask_data = ctx->Output(1, input_tensor->Shape())->MutableData<bool>();
|
||||
|
||||
// Fake quantize the input tensor
|
||||
// TODO(bmeswani): Add support for FakeQuantPerChannel
|
||||
FakeQuantPerTensor(Stream(), input_tensor->Shape().Size(), input_data, *quant_scale, *quant_zero_point, quant_min_,
|
||||
quant_max_, fake_quantized_data, quantization_mask_data);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
class FakeQuant final : public CudaKernel {
|
||||
public:
|
||||
FakeQuant(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
info.GetAttrOrDefault("quant_min", &quant_min_, static_cast<decltype(quant_min_)>(0));
|
||||
info.GetAttrOrDefault("quant_max", &quant_max_, static_cast<decltype(quant_max_)>(255));
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
int64_t quant_min_;
|
||||
int64_t quant_max_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/quantization/fake_quant_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
namespace {
|
||||
constexpr int NumElementsPerThread = GridDim::maxElementsPerThread;
|
||||
constexpr int NumThreadsPerBlock = GridDim::maxThreadsPerBlock;
|
||||
} // namespace
|
||||
|
||||
// Per-tensor fake quantization kernel.
//
// Each thread handles up to NumElementsPerThread elements, spaced NumThreadsPerBlock apart,
// starting at `start`. For every element x it computes
//   q       = nearbyint(x / quant_scale) + quant_zero_point
//   output  = (clamp(q, quant_min, quant_max) - quant_zero_point) * quant_scale
//   mask    = (quant_min <= q <= quant_max)
// The work is done in three passes (load, compute, write) over per-thread local arrays.
template <typename T>
__global__ void FakeQuantPerTensorImpl(const int64_t num_elements, const T* input_data, const T quant_scale,
                                       const T quant_zero_point, const int64_t quant_min, const int64_t quant_max,
                                       T* fake_quantized_data, bool* quantization_mask_data) {
  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;

  T values[NumElementsPerThread];
  T fake_quantized_values[NumElementsPerThread];
  bool mask_values[NumElementsPerThread];

  CUDA_LONG idx = start;
  // Load
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (idx < num_elements) {
      values[i] = input_data[idx];
      idx += NumThreadsPerBlock;
    }
  }

  // Compute
  // Guarded exactly like the load and write passes: slots past the end of the tensor were
  // never loaded, so they must not be read here.
  idx = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (idx < num_elements) {
      // Quantize
      const auto quantized_value = std::nearbyint(values[i] / quant_scale) + quant_zero_point;
      // Clamp and De-Quantize
      fake_quantized_values[i] =
          (fminf(quant_max, fmaxf(quant_min, quantized_value)) - quant_zero_point) * quant_scale;
      // Compute mask
      mask_values[i] = (quant_min <= quantized_value && quantized_value <= quant_max);
      idx += NumThreadsPerBlock;
    }
  }

  // Write
  idx = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (idx < num_elements) {
      fake_quantized_data[idx] = fake_quantized_values[i];
      quantization_mask_data[idx] = mask_values[i];
      idx += NumThreadsPerBlock;
    }
  }
}
|
||||
|
||||
// Launches FakeQuantPerTensorImpl on `stream` for `num_elements` elements.
//
// Each block has NumThreadsPerBlock threads and covers
// NumThreadsPerBlock * NumElementsPerThread elements. Nothing is launched for an empty tensor.
template <typename T>
void FakeQuantPerTensor(cudaStream_t stream, const int64_t num_elements, const T* input_data, const T quant_scale,
                        const T quant_zero_point, const int64_t quant_min, const int64_t quant_max,
                        T* fake_quantized_data, bool* quantization_mask_data) {
  // An empty tensor would yield a grid of 0 blocks, which is not a valid launch
  // configuration; there is nothing to compute, so return early.
  if (num_elements <= 0) {
    return;
  }

  int blocksPerGrid =
      static_cast<int>(CeilDiv(num_elements, NumThreadsPerBlock * NumElementsPerThread));
  FakeQuantPerTensorImpl<T><<<blocksPerGrid, NumThreadsPerBlock, 0, stream>>>(
      num_elements, input_data, quant_scale, quant_zero_point,
      quant_min, quant_max, fake_quantized_data, quantization_mask_data);
}
|
||||
|
||||
// Explicitly instantiates FakeQuantPerTensor for element type T so that the template,
// which is defined in this file, can be called from other translation units.
#define SPECIALIZED_FAKEQUANT_IMPL(T)                                                     \
  template void FakeQuantPerTensor<T>(                                                    \
      cudaStream_t stream, const int64_t num_elements, const T* input_data,               \
      const T quant_scale, const T quant_zero_point, const int64_t quant_min,             \
      const int64_t quant_max, T* fake_quantized_data, bool* quantization_mask_data);

// float is the only instantiation provided.
SPECIALIZED_FAKEQUANT_IMPL(float)

#undef SPECIALIZED_FAKEQUANT_IMPL
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
void FakeQuantPerTensor(cudaStream_t stream, const int64_t num_elements, const T* input_data, const T quant_scale,
|
||||
const T quant_zero_point, const int64_t quant_min, const int64_t quant_max,
|
||||
T* fake_quantized_data, bool* quantization_mask_data);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue