Only quantize gather with initializer (#4469)

This commit is contained in:
Yufeng Li 2020-07-09 13:33:43 -07:00 committed by GitHub
parent bec18eb3f4
commit d4db83858b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -348,7 +348,7 @@ class ONNXQuantizer:
new_list += self._quantize_convolution(node, new_list)
elif node.op_type == 'MatMul':
new_list += self._quantize_matmul(node, new_list)
elif node.op_type == 'Gather' and self._is_valid_quantize_value(node.input[0]):
elif node.op_type == 'Gather' and self._is_valid_initializer_value(node.input[0]):
new_list += self._quantize_gather_ops(node, new_list)
elif node.op_type == 'Add' or node.op_type == 'Mul':
new_list += self._quantize_binary_math_ops(node, new_list)
@ -395,6 +395,9 @@ class ONNXQuantizer:
value_info = self.value_infos[value_name]
return value_info.type.HasField(
'tensor_type') and value_info.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT
return self._is_valid_initializer_value(value_name)
def _is_valid_initializer_value(self, value_name):
weight = _find_by_name(value_name, self.model.graph.initializer)
return weight is not None and weight.data_type == onnx_proto.TensorProto.FLOAT