mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Add percentile method for PTQ (#9342)
* Add percentile method for calibration * Update configuration
This commit is contained in:
parent
2406a425a7
commit
c8151b4037
1 changed files with 140 additions and 15 deletions
|
|
@ -25,7 +25,7 @@ import itertools
|
|||
class CalibrationMethod(Enum):
|
||||
MinMax = 0
|
||||
Entropy = 1
|
||||
|
||||
Percentile = 2
|
||||
|
||||
class CalibrationDataReader(metaclass=abc.ABCMeta):
|
||||
@classmethod
|
||||
|
|
@ -269,19 +269,31 @@ class MinMaxCalibrater(CalibraterBase):
|
|||
|
||||
return self.calibrate_tensors_range
|
||||
|
||||
class EntropyCalibrater(CalibraterBase):
|
||||
def __init__(self, model, op_types_to_calibrate=[], augmented_model_path='augmented_model.onnx'):
|
||||
class HistogramCalibrater(CalibraterBase):
|
||||
def __init__(self,
|
||||
model,
|
||||
op_types_to_calibrate=[],
|
||||
augmented_model_path='augmented_model.onnx',
|
||||
method='percentile',
|
||||
num_quantized_bins=128,
|
||||
percentile=99.99):
|
||||
'''
|
||||
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
|
||||
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
|
||||
:param augmented_model_path: save augmented model to this path.
|
||||
:param method: A string. One of ['entropy', 'percentile'].
|
||||
:param num_quantized_bins: number of quantized bins. Default 128.
|
||||
:param percentile: A float number between [0, 100]. Default 99.99.
|
||||
'''
|
||||
super(EntropyCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path)
|
||||
super(HistogramCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path)
|
||||
self.intermediate_outputs = []
|
||||
self.calibrate_tensors_range = None
|
||||
self.num_model_outputs = len(self.model.graph.output)
|
||||
self.model_original_outputs = set(output.name for output in self.model.graph.output)
|
||||
self.collector = None
|
||||
self.method = method
|
||||
self.num_quantized_bins = num_quantized_bins
|
||||
self.percentile = percentile
|
||||
|
||||
def augment_graph(self):
|
||||
'''
|
||||
|
|
@ -334,7 +346,9 @@ class EntropyCalibrater(CalibraterBase):
|
|||
clean_merged_dict = dict((i, merged_dict[i]) for i in merged_dict if i not in self.model_original_outputs)
|
||||
|
||||
if not self.collector:
|
||||
self.collector = HistogramCollector()
|
||||
self.collector = HistogramCollector(method=self.method,
|
||||
num_quantized_bins=self.num_quantized_bins,
|
||||
percentile=self.percentile)
|
||||
self.collector.collect(clean_merged_dict)
|
||||
|
||||
self.clear_collected_data()
|
||||
|
|
@ -347,8 +361,44 @@ class EntropyCalibrater(CalibraterBase):
|
|||
if not self.collector:
|
||||
raise ValueError("No collector created and can't generate calibration data.")
|
||||
|
||||
return self.collector.get_optimal_collection_result()
|
||||
return self.collector.compute_collection_result()
|
||||
|
||||
class EntropyCalibrater(HistogramCalibrater):
|
||||
def __init__(self,
|
||||
model,
|
||||
op_types_to_calibrate=[],
|
||||
augmented_model_path='augmented_model.onnx',
|
||||
method='entropy',
|
||||
num_quantized_bins=128):
|
||||
'''
|
||||
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
|
||||
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
|
||||
:param augmented_model_path: save augmented model to this path.
|
||||
:param method: A string. One of ['entropy', 'percentile'].
|
||||
:param num_quantized_bins: number of quantized bins. Default 128.
|
||||
'''
|
||||
super(EntropyCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path,
|
||||
method=method, num_quantized_bins=num_quantized_bins)
|
||||
|
||||
class PercentileCalibrater(HistogramCalibrater):
|
||||
def __init__(self,
|
||||
model,
|
||||
op_types_to_calibrate=[],
|
||||
augmented_model_path='augmented_model.onnx',
|
||||
method='percentile',
|
||||
num_quantized_bins=2048,
|
||||
percentile=99.999):
|
||||
'''
|
||||
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
|
||||
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
|
||||
:param augmented_model_path: save augmented model to this path.
|
||||
:param method: A string. One of ['entropy', 'percentile'].
|
||||
:param num_quantized_bins: number of quantized bins. Default 128.
|
||||
:param percentile: A float number between [0, 100]. Default 99.99.
|
||||
'''
|
||||
super(PercentileCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path,
|
||||
method=method, num_quantized_bins=num_quantized_bins,
|
||||
percentile=percentile)
|
||||
|
||||
class CalibrationDataCollector(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
|
@ -365,7 +415,7 @@ class CalibrationDataCollector(metaclass=abc.ABCMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimal_collection_result(self):
|
||||
def compute_collection_result(self):
|
||||
"""
|
||||
Get the optimal result among collection data.
|
||||
"""
|
||||
|
|
@ -373,18 +423,57 @@ class CalibrationDataCollector(metaclass=abc.ABCMeta):
|
|||
|
||||
class HistogramCollector(CalibrationDataCollector):
|
||||
"""
|
||||
Implementation of collecting histogram data as dict for each tensor targeting on entropy calibration.
|
||||
Collecting histogram for each tensor. Percentile and Entropy method are supported.
|
||||
|
||||
ref: https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
|
||||
ref: https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/_modules/
|
||||
pytorch_quantization/calib/histogram.html
|
||||
"""
|
||||
def __init__(self, num_quantized_bins=128):
|
||||
def __init__(self, method, num_quantized_bins, percentile):
|
||||
self.histogram_dict = {}
|
||||
self.method = method
|
||||
self.num_quantized_bins= num_quantized_bins
|
||||
self.percentile = percentile
|
||||
|
||||
def get_histogram_dict(self):
|
||||
return self.histogram_dict
|
||||
|
||||
def collect(self, name_to_arr):
|
||||
# TODO: Currently we have different collect() for percentile and percentile method respectively.
|
||||
# Need unified collect in the future.
|
||||
if self.method == 'entropy':
|
||||
return self.collect_for_entropy(name_to_arr)
|
||||
elif self.method == 'percentile':
|
||||
return self.collect_for_percentile(name_to_arr)
|
||||
else:
|
||||
raise ValueError('Only \'entropy\' or \'percentile\' method are supported')
|
||||
|
||||
def collect_for_percentile(self, name_to_arr):
|
||||
for tensor, data_arr in name_to_arr.items():
|
||||
data_arr = np.asarray(data_arr)
|
||||
data_arr = data_arr.flatten()
|
||||
data_arr = np.absolute(data_arr) # only consider absolute value
|
||||
|
||||
if tensor not in self.histogram_dict:
|
||||
# first time it uses num_quantized_bins to compute histogram.
|
||||
hist, hist_edges = np.histogram(data_arr, bins=self.num_quantized_bins)
|
||||
self.histogram_dict[tensor] = (hist, hist_edges)
|
||||
else:
|
||||
old_histogram = self.histogram_dict[tensor]
|
||||
old_hist = old_histogram[0]
|
||||
old_hist_edges = old_histogram[1]
|
||||
temp_amax = np.max(data_arr)
|
||||
if temp_amax > old_hist_edges[-1]:
|
||||
# increase the number of bins
|
||||
width = old_hist_edges[1] - old_hist_edges[0]
|
||||
# NOTE: np.arange may create an extra bin after the one containing temp_amax
|
||||
new_bin_edges = np.arange(old_hist_edges[-1] + width, temp_amax + width, width)
|
||||
old_hist_edges = np.hstack((old_hist_edges, new_bin_edges))
|
||||
hist, hist_edges = np.histogram(data_arr, bins=old_hist_edges)
|
||||
hist[:len(old_hist)] += old_hist
|
||||
self.histogram_dict[tensor] = (hist, hist_edges)
|
||||
|
||||
def collect_for_entropy(self, name_to_arr):
|
||||
for tensor, data_arr in name_to_arr.items():
|
||||
data_arr = np.asarray(data_arr)
|
||||
data_arr = data_arr.flatten()
|
||||
|
|
@ -402,7 +491,6 @@ class HistogramCollector(CalibrationDataCollector):
|
|||
old_histogram = self.histogram_dict[tensor]
|
||||
self.histogram_dict[tensor] = self.merge_histogram(old_histogram, data_arr, min_value, max_value, threshold)
|
||||
else:
|
||||
# hist, hist_edges = np.histogram(data_arr, self.num_quantized_bins, range=(min_value, max_value))
|
||||
hist, hist_edges = np.histogram(data_arr, self.num_quantized_bins, range=(-threshold, threshold))
|
||||
self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value, threshold)
|
||||
|
||||
|
|
@ -415,8 +503,8 @@ class HistogramCollector(CalibrationDataCollector):
|
|||
return (new_hist + old_hist, old_hist_edges, min(old_min, new_min), max(old_max, new_max), old_threshold)
|
||||
else:
|
||||
if old_threshold == 0:
|
||||
hist, hist_edges = np.histogram(data_arr, new_num_bins, range=(-new_threshold, new_threshold))
|
||||
hist[len(hist) // 2] += len(old_hist)
|
||||
hist, hist_edges = np.histogram(data_arr, len(old_hist), range=(-new_threshold, new_threshold))
|
||||
hist += old_hist
|
||||
else:
|
||||
old_num_bins = len(old_hist)
|
||||
old_stride = 2 * old_threshold / old_num_bins
|
||||
|
|
@ -427,19 +515,54 @@ class HistogramCollector(CalibrationDataCollector):
|
|||
hist[half_increased_bins:new_num_bins-half_increased_bins] += old_hist
|
||||
return (hist, hist_edges, min(old_min, new_min), max(old_max, new_max), new_threshold)
|
||||
|
||||
def get_optimal_collection_result(self):
|
||||
def compute_collection_result(self):
|
||||
if not self.histogram_dict or len(self.histogram_dict) == 0:
|
||||
raise ValueError("Histogram has not been collected. Please run collect() first.")
|
||||
|
||||
if self.method == 'entropy':
|
||||
return self.compute_entropy()
|
||||
elif self.method == 'percentile':
|
||||
return self.compute_percentile()
|
||||
else:
|
||||
raise ValueError('Only \'entropy\' or \'percentile\' method are supported')
|
||||
|
||||
def compute_percentile(self):
|
||||
if self.percentile < 0 or self.percentile > 100:
|
||||
raise ValueError("Invalid percentile. Must be in range 0 <= percentile <= 100.")
|
||||
|
||||
histogram_dict = self.histogram_dict
|
||||
percentile = self.percentile
|
||||
|
||||
thresholds_dict = {} # per tensor thresholds
|
||||
|
||||
for tensor, histogram in histogram_dict.items():
|
||||
hist = histogram[0]
|
||||
hist_edges = histogram[1]
|
||||
total = hist.sum()
|
||||
cdf = np.cumsum(hist/total)
|
||||
idx = np.searchsorted(cdf, percentile/100)
|
||||
thresholds_dict[tensor] = (float(hist_edges[idx]), float(hist_edges[idx]))
|
||||
|
||||
return thresholds_dict
|
||||
|
||||
def compute_entropy(self):
|
||||
histogram_dict = self.histogram_dict
|
||||
num_quantized_bins = self.num_quantized_bins
|
||||
|
||||
thresholds_dict = {} # per tensor thresholds
|
||||
|
||||
for tensor, histogram in histogram_dict.items():
|
||||
optimal_threshold = self.get_optimal_threshold(histogram, num_quantized_bins)
|
||||
optimal_threshold = self.get_entropy_threshold(histogram, num_quantized_bins)
|
||||
thresholds_dict[tensor] = optimal_threshold
|
||||
|
||||
return thresholds_dict
|
||||
|
||||
def get_optimal_threshold(self, histogram, num_quantized_bins):
|
||||
def get_entropy_threshold(self, histogram, num_quantized_bins):
|
||||
"""Given a dataset, find the optimal threshold for quantizing it.
|
||||
The reference distribution is `q`, and the candidate distribution is `p`.
|
||||
`q` is a truncated version of the original distribution.
|
||||
Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
|
||||
"""
|
||||
from scipy.stats import entropy
|
||||
import copy
|
||||
|
||||
|
|
@ -513,5 +636,7 @@ def create_calibrator(model,
|
|||
return MinMaxCalibrater(model, op_types_to_calibrate, augmented_model_path)
|
||||
elif calibrate_method == CalibrationMethod.Entropy:
|
||||
return EntropyCalibrater(model, op_types_to_calibrate, augmented_model_path)
|
||||
elif calibrate_method == CalibrationMethod.Percentile:
|
||||
return PercentileCalibrater(model, op_types_to_calibrate, augmented_model_path)
|
||||
|
||||
raise ValueError('Unsupported calibration method {}'.format(calibrate_method))
|
||||
|
|
|
|||
Loading…
Reference in a new issue