mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Fix bug for percentile calibration module. (#13376)
### Description Fix bug for percentile calibration module. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
fc12abf6b1
commit
ac5948cb48
1 changed files with 18 additions and 3 deletions
|
|
@ -570,14 +570,23 @@ class HistogramCollector(CalibrationDataCollector):
|
|||
for tensor, data_arr in name_to_arr.items():
|
||||
data_arr = np.asarray(data_arr)
|
||||
data_arr = data_arr.flatten()
|
||||
if data_arr.size > 0:
|
||||
min_value = np.min(data_arr)
|
||||
max_value = np.max(data_arr)
|
||||
else:
|
||||
min_value = 0
|
||||
max_value = 0
|
||||
|
||||
data_arr = np.absolute(data_arr) # only consider absolute value
|
||||
|
||||
if tensor not in self.histogram_dict:
|
||||
# first time it uses num_bins to compute histogram.
|
||||
hist, hist_edges = np.histogram(data_arr, bins=self.num_bins)
|
||||
self.histogram_dict[tensor] = (hist, hist_edges)
|
||||
self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value)
|
||||
else:
|
||||
old_histogram = self.histogram_dict[tensor]
|
||||
old_min = old_histogram[2]
|
||||
old_max = old_histogram[3]
|
||||
old_hist = old_histogram[0]
|
||||
old_hist_edges = old_histogram[1]
|
||||
temp_amax = np.max(data_arr)
|
||||
|
|
@ -589,7 +598,7 @@ class HistogramCollector(CalibrationDataCollector):
|
|||
old_hist_edges = np.hstack((old_hist_edges, new_bin_edges))
|
||||
hist, hist_edges = np.histogram(data_arr, bins=old_hist_edges)
|
||||
hist[: len(old_hist)] += old_hist
|
||||
self.histogram_dict[tensor] = (hist, hist_edges)
|
||||
self.histogram_dict[tensor] = (hist, hist_edges, min(old_min, min_value), max(old_max, max_value))
|
||||
|
||||
def collect_value(self, name_to_arr):
|
||||
"""
|
||||
|
|
@ -688,6 +697,7 @@ class HistogramCollector(CalibrationDataCollector):
|
|||
cdf = np.cumsum(hist / total)
|
||||
if self.symmetric:
|
||||
idx_right = np.searchsorted(cdf, percentile / 100.0)
|
||||
|
||||
thresholds_dict[tensor] = (
|
||||
-float(hist_edges[idx_right]),
|
||||
float(hist_edges[idx_right]),
|
||||
|
|
@ -700,7 +710,12 @@ class HistogramCollector(CalibrationDataCollector):
|
|||
float(hist_edges[idx_left]),
|
||||
float(hist_edges[idx_right]),
|
||||
)
|
||||
|
||||
min_value = histogram[2]
|
||||
max_value = histogram[3]
|
||||
if thresholds_dict[tensor][0] < min_value:
|
||||
thresholds_dict[tensor] = (min_value, thresholds_dict[tensor][1])
|
||||
if thresholds_dict[tensor][1] > max_value:
|
||||
thresholds_dict[tensor] = (thresholds_dict[tensor][0], max_value)
|
||||
# Plot histogram for debug only
|
||||
if False:
|
||||
apply_plot(hist, hist_edges)
|
||||
|
|
|
|||
Loading…
Reference in a new issue