Change keepdims of ReduceMax/ReduceMin to always 1 when using quatization calibration MinMax approach (#9167)

* Change keepdims to always 1

* fix typo

* Refine code
This commit is contained in:
Chi Lo 2021-09-25 10:13:54 -07:00 committed by GitHub
parent fd91bf91c9
commit 9fda95fec9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -164,23 +164,17 @@ class MinMaxCalibrater(CalibraterBase):
for tensor in tensors:
# When doing ReduceMax/ReduceMin, keep dimension if tensor contains dim with value of 0,
# for example:
# dim = [ dim_value: 0 ]
#
# otherwise, don't keep dimension.
#
# When doing ReduceMax/ReduceMin, ORT can't reduce on dim with value of 0 if 'keepdims' is false.
# To make the code simple, we always let keepdims to be 1.
keepdims = 1
# dim could be:
# [dim_param: "batch_size", dim_value: 256, dim_value: 36, dim_value: 64],
# [dim_value: 0],
# ...
# Please see the definition of TensorShapeProto https://github.com/onnx/onnx/blob/master/onnx/onnx.proto#L651
dim = value_infos[tensor].type.tensor_type.shape.dim
keepdims = 0
shape = ()
for d in dim:
# A dimension can be either an integer value or a symbolic variable.
# Dimension with integer value and value of 0 is what we are looking for to keep dimension.
# Please see the def of TensorShapeProto https://github.com/onnx/onnx/blob/master/onnx/onnx.proto#L630
if d.WhichOneof('value') == 'dim_value' and d.dim_value == 0:
keepdims = 1
shape = (1,) if len(dim) == 1 else list(1 for i in range(len(dim)))
break
shape = (1,) if len(dim) == 1 else tuple(1 for i in range(len(dim)))
# Adding ReduceMin nodes
reduce_min_name = tensor + '_ReduceMin'