update QuantizeLinear to process int8 input (#1576)

This commit is contained in:
Ashwini Khade 2019-08-07 10:09:15 -07:00 committed by GitHub
parent aeb0bcb4a3
commit a93ece2727
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 11 deletions

View file

@ -265,7 +265,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QuantizeLinear);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, MatMulInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger);
@ -543,7 +544,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, MatMulInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger)>,

View file

@ -63,13 +63,22 @@ Status DequantizeLinear<T>::Compute(OpKernelContext* ctx) const {
ONNX_CPU_OPERATOR_TYPED_KERNEL(
QuantizeLinear,
10,
float,
uint8_t,
KernelDefBuilder()
.TypeConstraint("x", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("y_scale", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("y_zero_point", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("y", DataTypeImpl::GetTensorType<uint8_t>()),
QuantizeLinear<float>);
QuantizeLinear<uint8_t>);
// Registers the int8_t-typed CPU kernel for QuantizeLinear (version 10),
// implemented by QuantizeLinear<int8_t>. The kernel definition constrains
// "x" to float tensors and both "y_zero_point" and "y" to int8_t tensors.
// NOTE(review): this registration declares no "y_scale" type constraint,
// while the uint8_t registration of the same operator does — confirm
// whether the omission is intentional.
ONNX_CPU_OPERATOR_TYPED_KERNEL(
QuantizeLinear,
10,
int8_t,
KernelDefBuilder()
.TypeConstraint("x", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("y_zero_point", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("y", DataTypeImpl::GetTensorType<int8_t>()),
QuantizeLinear<int8_t>);
// clamp doesn't exist in the version of <algorithm> that we're using, so
// make a local one.
@ -85,9 +94,9 @@ static float RoundHalfToEven(float input) {
return result;
}
template <>
template <typename T>
// formula is Y = X / Scale + ZeroPoint
Status QuantizeLinear<float>::Compute(OpKernelContext* ctx) const {
Status QuantizeLinear<T>::Compute(OpKernelContext* ctx) const {
auto& x = *ctx->Input<Tensor>(0);
auto& y_scale = *ctx->Input<Tensor>(1);
auto& y_zero_point = *ctx->Input<Tensor>(2);
@ -102,14 +111,18 @@ Status QuantizeLinear<float>::Compute(OpKernelContext* ctx) const {
ORT_ENFORCE(scale_shape.NumDimensions() == 0 || (scale_shape.NumDimensions() == 1 && scale_shape.GetDims().size() == 1), "x_scale must be a scalar.");
ORT_ENFORCE(zero_point_shape.NumDimensions() == 0 || (zero_point_shape.NumDimensions() == 1 && zero_point_shape.GetDims().size() == 1), "x_zero_point must be a scalar.");
const uint8_t zero_point = *(y_zero_point.template Data<uint8_t>());
const T zero_point = *(y_zero_point.template Data<T>());
const float scale = *(y_scale.template Data<float>());
const auto* input = x.template Data<float>();
auto* output = y.template MutableData<uint8_t>();
auto* output = y.template MutableData<T>();
const auto num_of_elements = x_shape.Size();
const float qmax = std::numeric_limits<T>::max();
const float qmin_default = std::numeric_limits<T>::min();
// adjust qmin for int8 inputs. This is required to keep zero point as zero
const float qmin = qmin_default == -128 ? -127 : qmin_default;
for (int i = 0; i < num_of_elements; ++i) {
output[i] = static_cast<uint8_t>(clamp(RoundHalfToEven(static_cast<float>(input[i]/scale)) + zero_point, 0.0f, float(UINT8_MAX)));
output[i] = static_cast<T>(clamp(RoundHalfToEven(static_cast<float>(input[i]/scale)) + zero_point, qmin, qmax));
}
return Status::OK();

View file

@ -47,7 +47,7 @@ TEST(DequantizeLinearOpTest, DequantizeLinear_2) {
// quantize with scalar zero point and scale
TEST(QuantizeLinearOpTest, QuantizeLinear_0) {
TEST(QuantizeLinearOpTest, QuantizeLinear_uint8) {
OpTester test("QuantizeLinear", 10);
std::vector<int64_t> dims{6};
test.AddInput<float>("x", dims, {0, 2, 3, 1000, -254, -1000});
@ -57,6 +57,16 @@ TEST(QuantizeLinearOpTest, QuantizeLinear_0) {
test.Run();
}
// quantize float input to int8 with scalar zero point and scale
TEST(QuantizeLinearOpTest, QuantizeLinear_int8) {
OpTester test("QuantizeLinear", 10);
// Six float inputs in a 1-D tensor, covering zero, positive and negative values.
std::vector<int64_t> dims{6};
test.AddInput<float>("x", dims, {0, 2, 3, 5, -2, -5});
// Scale and zero point are supplied with empty dims ({}); the scale is
// approximately 10/255, so 5 and -5 divide to roughly +/-127.5.
test.AddInput<float>("y_scale", {}, {.039215686f});
test.AddInput<int8_t>("y_zero_point", {}, {0});
// Expected output is symmetric: the extremes are 127 and -127 (not -128).
// NOTE(review): presumably this pins the kernel's clamp range for int8 —
// confirm against the QuantizeLinear<int8_t> implementation.
test.AddOutput<int8_t>("y", dims, {0, 51, 76, 127, -51, -127});
test.Run();
}
// quantize with 2D data
TEST(QuantizeLinearOpTest, QuantizeLinear_1) {