mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Add ability to save calibration augmented models through external data format when model size exceeds 2Gb. (#10695)
This commit is contained in:
parent
62cc981599
commit
e5c6dc1fc8
2 changed files with 50 additions and 20 deletions
|
|
@ -39,12 +39,13 @@ class CalibrationDataReader(metaclass=abc.ABCMeta):
|
|||
|
||||
|
||||
class CalibraterBase:
|
||||
def __init__(self, model, op_types_to_calibrate=[], augmented_model_path='augmented_model.onnx', symmetric=False):
|
||||
def __init__(self, model, op_types_to_calibrate=[], augmented_model_path='augmented_model.onnx', symmetric=False, use_external_data_format=False):
|
||||
'''
|
||||
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
|
||||
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
|
||||
:param augmented_model_path: save augmented model to this path.
|
||||
:param symmetric: make range of tensor symmetric (central point is 0).
|
||||
:param use_external_data_format: use external data format to store model which size is >= 2Gb
|
||||
'''
|
||||
if isinstance(model, str):
|
||||
self.model = onnx.load(model)
|
||||
|
|
@ -56,6 +57,7 @@ class CalibraterBase:
|
|||
self.op_types_to_calibrate = op_types_to_calibrate
|
||||
self.augmented_model_path = augmented_model_path
|
||||
self.symmetric = symmetric
|
||||
self.use_external_data_format = use_external_data_format
|
||||
|
||||
# augment graph
|
||||
self.augment_model = None
|
||||
|
|
@ -138,14 +140,15 @@ class CalibraterBase:
|
|||
|
||||
|
||||
class MinMaxCalibrater(CalibraterBase):
|
||||
def __init__(self, model, op_types_to_calibrate=[], augmented_model_path='augmented_model.onnx', symmetric=False):
|
||||
def __init__(self, model, op_types_to_calibrate=[], augmented_model_path='augmented_model.onnx', symmetric=False, use_external_data_format=False):
|
||||
'''
|
||||
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
|
||||
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
|
||||
:param augmented_model_path: save augmented model to this path.
|
||||
:param symmetric: make range of tensor symmetric (central point is 0).
|
||||
:param use_external_data_format: use external data format to store model which size is >= 2Gb
|
||||
'''
|
||||
super(MinMaxCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, symmetric)
|
||||
super(MinMaxCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, symmetric, use_external_data_format)
|
||||
self.intermediate_outputs = []
|
||||
self.calibrate_tensors_range = None
|
||||
self.num_model_outputs = len(self.model.graph.output)
|
||||
|
|
@ -195,7 +198,7 @@ class MinMaxCalibrater(CalibraterBase):
|
|||
|
||||
model.graph.node.extend(added_nodes)
|
||||
model.graph.output.extend(added_outputs)
|
||||
onnx.save(model, self.augmented_model_path)
|
||||
onnx.save(model, self.augmented_model_path, save_as_external_data=self.use_external_data_format)
|
||||
self.augment_model = model
|
||||
|
||||
def clear_collected_data(self):
|
||||
|
|
@ -268,7 +271,6 @@ class MinMaxCalibrater(CalibraterBase):
|
|||
else:
|
||||
pairs.append(tuple([min_value, max_value]))
|
||||
|
||||
|
||||
new_calibrate_tensors_range = dict(zip(calibrate_tensor_names, pairs))
|
||||
if self.calibrate_tensors_range:
|
||||
self.calibrate_tensors_range = self.merge_range(self.calibrate_tensors_range, new_calibrate_tensors_range)
|
||||
|
|
@ -282,6 +284,7 @@ class HistogramCalibrater(CalibraterBase):
|
|||
model,
|
||||
op_types_to_calibrate=[],
|
||||
augmented_model_path='augmented_model.onnx',
|
||||
use_external_data_format=False,
|
||||
method='percentile',
|
||||
symmetric=False,
|
||||
num_bins=128,
|
||||
|
|
@ -291,13 +294,14 @@ class HistogramCalibrater(CalibraterBase):
|
|||
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
|
||||
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
|
||||
:param augmented_model_path: save augmented model to this path.
|
||||
:param use_external_data_format: use external data format to store model which size is >= 2Gb
|
||||
:param method: A string. One of ['entropy', 'percentile'].
|
||||
:param symmetric: make range of tensor symmetric (central point is 0).
|
||||
:param num_bins: number of bins to create a new histogram for collecting tensor values.
|
||||
:param num_quantized_bins: number of quantized bins. Default 128.
|
||||
:param percentile: A float number between [0, 100]. Default 99.99.
|
||||
'''
|
||||
super(HistogramCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path)
|
||||
super(HistogramCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, use_external_data_format)
|
||||
self.intermediate_outputs = []
|
||||
self.calibrate_tensors_range = None
|
||||
self.num_model_outputs = len(self.model.graph.output)
|
||||
|
|
@ -327,7 +331,7 @@ class HistogramCalibrater(CalibraterBase):
|
|||
|
||||
model.graph.node.extend(added_nodes)
|
||||
model.graph.output.extend(added_outputs)
|
||||
onnx.save(model, self.augmented_model_path)
|
||||
onnx.save(model, self.augmented_model_path, save_as_external_data=self.use_external_data_format)
|
||||
self.augment_model = model
|
||||
|
||||
def clear_collected_data(self):
|
||||
|
|
@ -343,7 +347,6 @@ class HistogramCalibrater(CalibraterBase):
|
|||
break
|
||||
self.intermediate_outputs.append(self.infer_session.run(None, inputs))
|
||||
|
||||
|
||||
if len(self.intermediate_outputs) == 0:
|
||||
raise ValueError("No data is collected.")
|
||||
|
||||
|
|
@ -384,6 +387,7 @@ class EntropyCalibrater(HistogramCalibrater):
|
|||
model,
|
||||
op_types_to_calibrate=[],
|
||||
augmented_model_path='augmented_model.onnx',
|
||||
use_external_data_format=False,
|
||||
method='entropy',
|
||||
symmetric=False,
|
||||
num_bins=128,
|
||||
|
|
@ -392,19 +396,21 @@ class EntropyCalibrater(HistogramCalibrater):
|
|||
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
|
||||
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
|
||||
:param augmented_model_path: save augmented model to this path.
|
||||
:param use_external_data_format: use external data format to store model which size is >= 2Gb
|
||||
:param method: A string. One of ['entropy', 'percentile'].
|
||||
:param symmetric: make range of tensor symmetric (central point is 0).
|
||||
:param num_bins: number of bins to create a new histogram for collecting tensor values.
|
||||
:param num_quantized_bins: number of quantized bins. Default 128.
|
||||
'''
|
||||
super(EntropyCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, method=method,
|
||||
symmetric=symmetric, num_bins=num_bins, num_quantized_bins=num_quantized_bins)
|
||||
super(EntropyCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, use_external_data_format,
|
||||
method=method, symmetric=symmetric, num_bins=num_bins, num_quantized_bins=num_quantized_bins)
|
||||
|
||||
class PercentileCalibrater(HistogramCalibrater):
|
||||
def __init__(self,
|
||||
model,
|
||||
op_types_to_calibrate=[],
|
||||
augmented_model_path='augmented_model.onnx',
|
||||
use_external_data_format=False,
|
||||
method='percentile',
|
||||
symmetric=False,
|
||||
num_bins=2048,
|
||||
|
|
@ -413,13 +419,14 @@ class PercentileCalibrater(HistogramCalibrater):
|
|||
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
|
||||
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
|
||||
:param augmented_model_path: save augmented model to this path.
|
||||
:param use_external_data_format: use external data format to store model which size is >= 2Gb
|
||||
:param method: A string. One of ['entropy', 'percentile'].
|
||||
:param symmetric: make range of tensor symmetric (central point is 0).
|
||||
:param num_quantized_bins: number of quantized bins. Default 128.
|
||||
:param percentile: A float number between [0, 100]. Default 99.99.
|
||||
'''
|
||||
super(PercentileCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, method=method,
|
||||
symmetric=symmetric, num_bins=num_bins, percentile=percentile)
|
||||
super(PercentileCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, use_external_data_format,
|
||||
method=method, symmetric=symmetric, num_bins=num_bins, percentile=percentile)
|
||||
|
||||
class CalibrationDataCollector(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
|
@ -635,13 +642,13 @@ class HistogramCollector(CalibrationDataCollector):
|
|||
# <--- quantized bins ---->
|
||||
# |======|===========|===========|=======|
|
||||
# zero bin index
|
||||
# ^ ^
|
||||
# ^ ^
|
||||
# | |
|
||||
# start index end index (start of iteration)
|
||||
# ^ ^
|
||||
# start index end index (start of iteration)
|
||||
# ^ ^
|
||||
# | |
|
||||
# start index end index ...
|
||||
# ^ ^
|
||||
# ^ ^
|
||||
# | |
|
||||
# start index end index (end of iteration)
|
||||
|
||||
|
|
@ -703,23 +710,40 @@ def create_calibrator(model,
|
|||
op_types_to_calibrate=[],
|
||||
augmented_model_path='augmented_model.onnx',
|
||||
calibrate_method=CalibrationMethod.MinMax,
|
||||
use_external_data_format=False,
|
||||
extra_options={}):
|
||||
|
||||
if calibrate_method == CalibrationMethod.MinMax:
|
||||
# default settings for min-max algorithm
|
||||
symmetric = False if 'symmetric' not in extra_options else extra_options['symmetric']
|
||||
return MinMaxCalibrater(model, op_types_to_calibrate, augmented_model_path, symmetric=symmetric)
|
||||
return MinMaxCalibrater(
|
||||
model, op_types_to_calibrate, augmented_model_path,
|
||||
use_external_data_format=use_external_data_format,
|
||||
symmetric=symmetric
|
||||
)
|
||||
elif calibrate_method == CalibrationMethod.Entropy:
|
||||
# default settings for entropy algorithm
|
||||
num_bins = 128 if 'num_bins' not in extra_options else extra_options['num_bins']
|
||||
num_quantized_bins = 128 if 'num_quantized_bins' not in extra_options else extra_options['num_quantized_bins']
|
||||
symmetric = False if 'symmetric' not in extra_options else extra_options['symmetric']
|
||||
return EntropyCalibrater(model, op_types_to_calibrate, augmented_model_path, symmetric=symmetric, num_bins=num_bins, num_quantized_bins=num_quantized_bins)
|
||||
return EntropyCalibrater(
|
||||
model, op_types_to_calibrate, augmented_model_path,
|
||||
use_external_data_format=use_external_data_format,
|
||||
symmetric=symmetric,
|
||||
num_bins=num_bins,
|
||||
num_quantized_bins=num_quantized_bins
|
||||
)
|
||||
elif calibrate_method == CalibrationMethod.Percentile:
|
||||
# default settings for percentile algorithm
|
||||
num_bins = 2048 if 'num_bins' not in extra_options else extra_options['num_bins']
|
||||
percentile = 99.999 if 'percentile' not in extra_options else extra_options['percentile']
|
||||
symmetric = True if 'symmetric' not in extra_options else extra_options['symmetric']
|
||||
return PercentileCalibrater(model, op_types_to_calibrate, augmented_model_path, symmetric=symmetric, num_bins=num_bins, percentile=percentile)
|
||||
return PercentileCalibrater(
|
||||
model, op_types_to_calibrate, augmented_model_path,
|
||||
use_external_data_format=use_external_data_format,
|
||||
symmetric=symmetric,
|
||||
num_bins=num_bins,
|
||||
percentile=percentile
|
||||
)
|
||||
|
||||
raise ValueError('Unsupported calibration method {}'.format(calibrate_method))
|
||||
|
|
|
|||
|
|
@ -236,7 +236,13 @@ def quantize_static(model_input,
|
|||
model = load_model(Path(model_input), optimize_model, False)
|
||||
|
||||
calib_extra_options = {} if 'CalibTensorRangeSymmetric' not in extra_options else {'symmetric': extra_options['CalibTensorRangeSymmetric']}
|
||||
calibrator = create_calibrator(model, op_types_to_quantize, calibrate_method=calibrate_method, extra_options=calib_extra_options)
|
||||
calibrator = create_calibrator(
|
||||
model,
|
||||
op_types_to_quantize,
|
||||
calibrate_method=calibrate_method,
|
||||
use_external_data_format=use_external_data_format,
|
||||
extra_options=calib_extra_options
|
||||
)
|
||||
calibrator.collect_data(calibration_data_reader)
|
||||
tensors_range = calibrator.compute_range()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue