Fix a bug in the percentile calibration module. (#13376)

### Description
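Fix a bug in the percentile calibration module: `HistogramCollector` now stores each tensor's observed minimum and maximum values alongside its histogram, and the computed percentile thresholds are clamped to that observed range.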


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Jian Chen 2022-10-20 12:33:07 -04:00 committed by GitHub
parent fc12abf6b1
commit ac5948cb48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -570,14 +570,23 @@ class HistogramCollector(CalibrationDataCollector):
for tensor, data_arr in name_to_arr.items():
data_arr = np.asarray(data_arr)
data_arr = data_arr.flatten()
if data_arr.size > 0:
min_value = np.min(data_arr)
max_value = np.max(data_arr)
else:
min_value = 0
max_value = 0
data_arr = np.absolute(data_arr) # only consider absolute value
if tensor not in self.histogram_dict:
# first time it uses num_bins to compute histogram.
hist, hist_edges = np.histogram(data_arr, bins=self.num_bins)
self.histogram_dict[tensor] = (hist, hist_edges)
self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value)
else:
old_histogram = self.histogram_dict[tensor]
old_min = old_histogram[2]
old_max = old_histogram[3]
old_hist = old_histogram[0]
old_hist_edges = old_histogram[1]
temp_amax = np.max(data_arr)
@ -589,7 +598,7 @@ class HistogramCollector(CalibrationDataCollector):
old_hist_edges = np.hstack((old_hist_edges, new_bin_edges))
hist, hist_edges = np.histogram(data_arr, bins=old_hist_edges)
hist[: len(old_hist)] += old_hist
self.histogram_dict[tensor] = (hist, hist_edges)
self.histogram_dict[tensor] = (hist, hist_edges, min(old_min, min_value), max(old_max, max_value))
def collect_value(self, name_to_arr):
"""
@ -688,6 +697,7 @@ class HistogramCollector(CalibrationDataCollector):
cdf = np.cumsum(hist / total)
if self.symmetric:
idx_right = np.searchsorted(cdf, percentile / 100.0)
thresholds_dict[tensor] = (
-float(hist_edges[idx_right]),
float(hist_edges[idx_right]),
@ -700,7 +710,12 @@ class HistogramCollector(CalibrationDataCollector):
float(hist_edges[idx_left]),
float(hist_edges[idx_right]),
)
min_value = histogram[2]
max_value = histogram[3]
if thresholds_dict[tensor][0] < min_value:
thresholds_dict[tensor] = (min_value, thresholds_dict[tensor][1])
if thresholds_dict[tensor][1] > max_value:
thresholds_dict[tensor] = (thresholds_dict[tensor][0], max_value)
# Plot histogram for debug only
if False:
apply_plot(hist, hist_edges)