mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Add python example of TensorRT INT8 inference on ResNet model (#6255)
* add trt int8 example on resnet model * Update e2e_tensorrt_resnet_example.py * remove keras dependency and update class names * move ImageNetDataReader and ImageClassificationEvaluator to tensorrt resnet example * simplify e2e_tensorrt_resnet_example.py * Update preprocessing.py * merge tensorrt_calibrate * Update calibrate.py * Update calibrate.py * generalize calibrate * Update calibrate.py * fix issues * fix formatting * remove augment_all
This commit is contained in:
parent
f5a4f7fc2a
commit
eab164e1a5
3 changed files with 368 additions and 30 deletions
|
|
@ -0,0 +1,332 @@
|
|||
import os
|
||||
import onnx
|
||||
import glob
|
||||
import scipy.io
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import onnx
|
||||
import onnxruntime
|
||||
from onnxruntime.quantization import CalibrationDataReader, calibrate, write_calibration_table
|
||||
|
||||
class ImageNetDataReader(CalibrationDataReader):
    '''
    CalibrationDataReader that serves preprocessed ILSVRC2012 validation images.

    Images are read from the 'val' subfolder of the dataset folder in sorted file-name order,
    resized, normalized and handed out one feed dictionary ({input_name: ndarray}) at a time
    through get_next().
    '''

    def __init__(self, image_folder,
                 width=224,
                 height=224,
                 start_index=0,
                 end_index=0,
                 stride=1,
                 batch_size=1,
                 model_path='augmented_model.onnx',
                 input_name='data'):
        '''
        :param image_folder: image dataset folder; images are read from its 'val' subfolder
        :param width: image width
        :param height: image height
        :param start_index: start index of images
        :param end_index: end index of images (exclusive); 0 means the number of files in the 'val' subfolder
        :param stride: number of images loaded each time the internal iterator is refilled (values < 1 become 1)
        :param batch_size: batch size of inference
        :param model_path: model name and path
        :param input_name: model input name
        '''

        self.image_folder = image_folder + "/val"
        self.model_path = model_path
        self.preprocess_flag = True
        self.enum_data_dicts = iter([])
        self.datasize = 0
        self.width = width
        self.height = height
        self.start_index = start_index
        self.end_index = len(os.listdir(self.image_folder)) if end_index == 0 else end_index
        self.stride = stride if stride >= 1 else 1
        self.batch_size = batch_size
        self.input_name = input_name

    def get_dataset_size(self):
        '''
        :return: number of entries in the image folder
        '''
        return len(os.listdir(self.image_folder))

    def get_input_name(self):
        '''
        Fill in self.input_name from the first input of the model at self.model_path
        when no input name has been set. Nothing is returned.
        '''
        if self.input_name:
            return
        session = onnxruntime.InferenceSession(self.model_path, providers=['CPUExecutionProvider'])
        self.input_name = session.get_inputs()[0].name

    def get_next(self):
        '''
        :return: next feed dictionary {input_name: ndarray}, or None when the
                 [start_index, end_index) range has been consumed
        '''
        iter_data = next(self.enum_data_dicts, None)
        if iter_data:
            return iter_data

        # Keep a valid (empty) iterator so that calling get_next() again after the data
        # is exhausted keeps returning None instead of failing on next(None, ...).
        self.enum_data_dicts = iter([])
        if self.start_index >= self.end_index:
            return None

        if self.batch_size == 1:
            data = self.load_serial()
        else:
            data = self.load_batches()

        self.start_index += self.stride
        self.enum_data_dicts = iter(data)
        return next(self.enum_data_dicts, None)

    def load_serial(self):
        '''
        Load 'stride' images starting at start_index, one image per feed dictionary.
        :return: list of {input_name: ndarray}
        '''
        nchw_data_list, _, _ = self.preprocess_imagenet(self.image_folder, self.height, self.width,
                                                        self.start_index, self.stride)
        input_name = self.input_name
        return [{input_name: nchw_data} for nchw_data in nchw_data_list]

    def load_batches(self):
        '''
        Load 'stride' images starting at start_index, grouped into batches of batch_size images.
        :return: list of {input_name: ndarray}
        '''
        input_name = self.input_name

        batches = []
        for index in range(0, self.stride, self.batch_size):
            start_index = self.start_index + index
            nchw_data_list, _, _ = self.preprocess_imagenet(self.image_folder, self.height, self.width,
                                                            start_index, self.batch_size)

            if nchw_data_list.size == 0:
                break

            # Every image carries its own leading axis of length 1; drop it so that
            # the first axis of the batch indexes the images.
            batch_data = np.squeeze(nchw_data_list, axis=1)
            batches.append({input_name: batch_data})

        return batches

    def preprocess_imagenet(self, images_folder, height, width, start_index=0, size_limit=0):
        '''
        Loads a batch of images and preprocess them
        parameter images_folder: path to folder storing images
        parameter height: image height in pixels
        parameter width: image width in pixels
        parameter start_index: image index to start with
        parameter size_limit: number of images to load. Default is 0 which means all images are picked.
                              All images are also picked when the folder holds fewer than size_limit images.
        return: (preprocessed image data, image file names, original image sizes as [height, width])
        '''

        def preprocess_images(image, height, width):
            # Converting to RGB gives every image exactly three channels
            # (grayscale is replicated, alpha/CMYK are reduced), matching mean/std below.
            # LANCZOS is the current name of the filter formerly exposed as ANTIALIAS.
            resized = image.convert("RGB").resize((width, height), Image.LANCZOS)
            input_data = np.asarray(resized).astype(np.float32).transpose([2, 0, 1])
            mean = np.array([0.079, 0.05, 0]) + 0.406  # 0.485, 0.456, 0.406
            std = np.array([0.005, 0, 0.001]) + 0.224  # 0.229, 0.224, 0.225
            for channel in range(input_data.shape[0]):
                input_data[channel, :, :] = (input_data[channel, :, :] / 255 - mean[channel]) / std[channel]
            return input_data

        image_names = sorted(os.listdir(images_folder))
        if size_limit > 0 and len(image_names) >= size_limit:
            end_index = min(start_index + size_limit, len(image_names))
            batch_filenames = image_names[start_index:end_index] if start_index < end_index else []
        else:
            batch_filenames = image_names

        unconcatenated_batch_data = []
        image_size_list = []

        for image_name in batch_filenames:
            image_filepath = images_folder + '/' + image_name
            # The context manager releases the file handle as soon as the pixels are read
            with Image.open(image_filepath) as img:
                image_data = preprocess_images(img, height, width)
                image_width, image_height = img.size
            unconcatenated_batch_data.append(np.expand_dims(image_data, 0))
            image_size_list.append(np.array([image_height, image_width], dtype=np.float32).reshape(1, 2))

        batch_data = np.array(unconcatenated_batch_data)
        return batch_data, batch_filenames, image_size_list

    def get_synset_id(self, image_folder, offset, dataset_size):
        '''
        Build one-hot ground-truth labels for a slice of the validation images.
        :param image_folder: dataset folder holding 'val', 'devkit/data' and 'synset_words.txt'
        :param offset: index of the first image (in sorted file-name order)
        :param dataset_size: number of images
        :return: float32 array of shape (number of selected images, 1000)
        '''
        ilsvrc2012_meta = scipy.io.loadmat(image_folder + "/devkit/data/meta.mat")
        id_to_synset = {}
        for i in range(1000):
            ilsvrc2012_id = int(ilsvrc2012_meta["synsets"][i, 0][0][0][0])
            id_to_synset[ilsvrc2012_id] = ilsvrc2012_meta["synsets"][i, 0][1][0]

        # The line number of a synset in synset_words.txt is used as its class index
        synset_to_id = {}
        with open(image_folder + "/synset_words.txt", "r") as synset_file:
            for index, line in enumerate(synset_file):
                parts = line.split(" ")
                synset_to_id[parts[0]] = index

        with open(image_folder + "/devkit/data/ILSVRC2012_validation_ground_truth.txt", "r") as truth_file:
            ground_truth = list(map(int, truth_file.read().strip().split("\n")))

        image_names = sorted(os.listdir(image_folder + "/val"))
        image_names = image_names[offset : offset + dataset_size]

        # The number between the last '_' and the first following '.' of the file name is
        # used as the 1-based line number in the ground truth file
        seq_num = [int(image_name.split("_")[-1].split(".")[0]) for image_name in image_names]
        ilsvrc2012_ids = np.array([ground_truth[index - 1] for index in seq_num])
        synset_id = np.array([synset_to_id[id_to_synset[index]] for index in ilsvrc2012_ids])

        # one-hot encoding
        synset_id_onehot = np.zeros((len(synset_id), 1000), dtype=np.float32)
        for i, class_index in enumerate(synset_id):
            synset_id_onehot[i, class_index] = 1.0
        return synset_id_onehot
|
||||
|
||||
class ImageClassificationEvaluator:
    '''
    Runs a classification model over the inputs of a data reader and reports
    top-1 / top-5 accuracy against one-hot ground-truth labels.
    '''

    def __init__(self, model_path, synset_id,
                 data_reader: "CalibrationDataReader",
                 providers=None
                 ):
        '''
        :param model_path: ONNX model to validate
        :param synset_id: ILSVRC2012 synset id
        :param data_reader: user implemented object to read in and preprocess calibration dataset
                            based on CalibrationDataReader Interface
        :param providers: ORT execution provider type, default = ["TensorrtExecutionProvider"]
        '''

        self.model_path = model_path
        self.data_reader = data_reader
        # A fresh list per instance; a list literal as default value would be shared by all instances
        self.providers = ["TensorrtExecutionProvider"] if providers is None else providers
        self.prediction_result_list = []
        self.synset_id = synset_id

    def get_result(self):
        '''
        :return: list of session outputs collected by predict()
        '''
        return self.prediction_result_list

    def predict(self):
        '''
        Run inference on every input returned by the data reader and store the outputs.
        '''
        sess_options = onnxruntime.SessionOptions()
        sess_options.log_severity_level = 0
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
        session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=self.providers)

        inference_outputs_list = []
        while True:
            inputs = self.data_reader.get_next()
            if not inputs:
                break
            inference_outputs_list.append(session.run(None, inputs))
        self.prediction_result_list = inference_outputs_list

    def top_k_accuracy(self, truth, prediction, k=1):
        '''From https://github.com/chainer/chainer/issues/606

        :param truth: one-hot labels, one row per sample
        :param prediction: class scores, one row per sample
        :param k: a sample counts as correct when its label is among the k highest scores
        :return: fraction of correct samples
        '''

        # indices of the k highest scores of every row
        y = np.argsort(prediction)[:, -k:]
        return np.any(y.T == truth.argmax(axis=1), axis=0).mean()

    def evaluate(self, prediction_results):
        '''
        Print top-1 and top-5 accuracy of the given prediction results.
        :param prediction_results: list of session outputs; the first output of each entry holds the class scores of one batch
        :return: (top 1 accuracy, top 5 accuracy), or None when there is nothing to evaluate
        '''
        if not prediction_results:
            print("Warning: No prediction result is available for evaluation")
            return None

        # Concatenating the batches copes with a last batch that is smaller than the others
        # and does not tie the evaluation to a fixed number of classes.
        y_prediction = np.concatenate([np.asarray(res[0], dtype=np.float32) for res in prediction_results], axis=0)

        top_1 = self.top_k_accuracy(self.synset_id, y_prediction, k=1)
        top_5 = self.top_k_accuracy(self.synset_id, y_prediction, k=5)
        print("top 1: ", top_1)
        print("top 5: ", top_5)
        return top_1, top_5
|
||||
|
||||
def convert_model_batch_to_dynamic(model_path):
    '''
    Make the first dimension of the first graph input symbolic and save the result
    next to the original model as '<model path without extension>_dynamic.onnx'.

    :param model_path: path of the ONNX model to convert
    :return: [path of the saved model, name of the first graph input]
    '''
    model = onnx.load(model_path)
    graph_inputs = model.graph.input
    input_name = graph_inputs[0].name
    first_dim = graph_inputs[0].type.tensor_type.shape.dim[0]
    if not first_dim.dim_param:
        first_dim.dim_param = 'N'
        model = onnx.shape_inference.infer_shapes(model)

    # splitext only strips the file extension; splitting on the first '.' would
    # reduce a relative path such as './model.onnx' to an empty string.
    path_without_extension, _ = os.path.splitext(model_path)
    model_path = path_without_extension + "_dynamic.onnx"
    onnx.save(model, model_path)
    return [model_path, input_name]
|
||||
|
||||
def get_dataset_size(dataset_path, calibration_dataset_size, batch_size=None):
    '''
    Split the images in '<dataset_path>/val' into a calibration part and a prediction part,
    both rounded down to a multiple of the batch size.

    :param dataset_path: dataset folder containing the 'val' subfolder
    :param calibration_dataset_size: requested number of calibration images
    :param batch_size: inference batch size. When omitted, the module level variable 'batch_size'
                       is used if it exists, otherwise 1
    :return: [calibration dataset size, prediction dataset size]
    :raises ValueError: if the batch size is smaller than 1
    '''
    if batch_size is None:
        # Fallback for callers that only define 'batch_size' at module level
        batch_size = globals().get("batch_size", 1)
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    total_dataset_size = len(os.listdir(os.path.join(dataset_path, "val")))
    if calibration_dataset_size > total_dataset_size:
        print("Warning: calibration data size is bigger than available dataset. Will assign half of the dataset for calibration")
        calibration_dataset_size = total_dataset_size // 2
    calibration_dataset_size = (calibration_dataset_size // batch_size) * batch_size
    if calibration_dataset_size == 0:
        print("Warning: No dataset is assigned for calibration. Please use bigger dataset")

    prediction_dataset_size = ((total_dataset_size - calibration_dataset_size) // batch_size) * batch_size
    if prediction_dataset_size <= 0:
        print("Warning: No dataset is assigned for evaluation. Please use bigger dataset")
    return [calibration_dataset_size, prediction_dataset_size]
|
||||
|
||||
if __name__ == '__main__':
    '''
    TensorRT EP INT8 Inference on Resnet model

    The script is using ILSVRC2012 ImageNet dataset for calibration and prediction.
    Please prepare the dataset as below,
    1. Create dataset folder 'ILSVRC2012' in workspace.
    2. Download ILSVRC2012 validation dataset and development kit from http://www.image-net.org/challenges/LSVRC/2012/downloads.
    3. Extract validation dataset JPEG files to 'ILSVRC2012/val'.
    4. Extract development kit to 'ILSVRC2012/devkit'. Two files in the development kit are used, 'ILSVRC2012_validation_ground_truth.txt' and 'meta.mat'.
    5. Download 'synset_words.txt' from https://github.com/HoldenCaulfieldRye/caffe/blob/master/data/ilsvrc12/synset_words.txt into 'ILSVRC2012/'.

    Please download Resnet50 model from ONNX model zoo https://github.com/onnx/models/blob/master/vision/classification/resnet/model/resnet50-v2-7.tar.gz
    Untar the model into the workspace
    '''

    # Dataset settings
    model_path = "./resnet50-v2-7.onnx"
    ilsvrc2012_dataset_path = "./ILSVRC2012"
    augmented_model_path = "./augmented_model.onnx"
    batch_size = 20
    calibration_dataset_size = 1000  # Size of dataset for calibration

    # INT8 calibration setting
    calibration_table_generation_enable = True  # Enable/Disable INT8 calibration

    # TensorRT EP INT8 settings
    os.environ.update({
        "ORT_TENSORRT_FP16_ENABLE": "1",  # Enable FP16 precision
        "ORT_TENSORRT_INT8_ENABLE": "1",  # Enable INT8 precision
        "ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME": "calibration.flatbuffers",  # Calibration table name
        "ORT_TENSORRT_ENGINE_CACHE_ENABLE": "1",  # Enable engine caching
    })
    execution_provider = ["TensorrtExecutionProvider"]

    # Convert static batch to dynamic batch
    new_model_path, input_name = convert_model_batch_to_dynamic(model_path)

    # Get calibration and prediction dataset size
    calibration_dataset_size, prediction_dataset_size = get_dataset_size(ilsvrc2012_dataset_path,
                                                                         calibration_dataset_size)

    # Generate INT8 calibration table
    if calibration_table_generation_enable:
        # The calibration images are the first 'calibration_dataset_size' images of the dataset
        data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
                                         start_index=0,
                                         end_index=calibration_dataset_size,
                                         stride=calibration_dataset_size,
                                         batch_size=batch_size,
                                         model_path=augmented_model_path,
                                         input_name=input_name)
        # For TensorRT calibration, augment all FP32 tensors (empty op_types), disable ORT graph optimization and skip quantization parameter calculation
        calibration_cache = calibrate(new_model_path,
                                      data_reader,
                                      op_types=[],
                                      providers=["CUDAExecutionProvider"],
                                      ort_graph_optimization_enable=False,
                                      quantization_params_calculation_enable=False)
        write_calibration_table(calibration_cache)

    # Run prediction in Tensorrt EP
    # The prediction images directly follow the calibration images
    prediction_end_index = calibration_dataset_size + prediction_dataset_size
    data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
                                     start_index=calibration_dataset_size,
                                     end_index=prediction_end_index,
                                     stride=prediction_dataset_size,
                                     batch_size=batch_size,
                                     model_path=new_model_path,
                                     input_name=input_name)
    # Generate synset id
    synset_id = data_reader.get_synset_id(ilsvrc2012_dataset_path, calibration_dataset_size, prediction_dataset_size)
    evaluator = ImageClassificationEvaluator(new_model_path, synset_id, data_reader, providers=execution_provider)
    evaluator.predict()
    result = evaluator.get_result()
    evaluator.evaluate(result)
|
||||
|
|
@ -43,7 +43,6 @@ class ONNXCalibrater:
|
|||
:param black_nodes: operator names that should not be quantized, default = ''
|
||||
:param white_nodes: operator names that force to be quantized, default = ''
|
||||
:param augmented_model_path: save augmented_model to this path
|
||||
|
||||
'''
|
||||
if isinstance(model, string_types):
|
||||
self.model = onnx.load(model)
|
||||
|
|
@ -66,7 +65,7 @@ class ONNXCalibrater:
|
|||
def get_calibration_cache(self):
|
||||
return self.calibration_cache
|
||||
|
||||
def augment_graph(self, augment_all_ops=False):
|
||||
def augment_graph(self):
|
||||
'''
|
||||
Adds ReduceMin and ReduceMax nodes to all quantization_candidates op type nodes in
|
||||
model and ensures their outputs are stored as part of the graph output
|
||||
|
|
@ -84,11 +83,8 @@ class ONNXCalibrater:
|
|||
tensors_to_calibrate = set()
|
||||
|
||||
for node in model.graph.node:
|
||||
if augment_all_ops:
|
||||
should_be_calibrate = True
|
||||
else:
|
||||
should_be_calibrate = ((node.op_type in self.calibrate_op_types) and
|
||||
(node.name not in self.black_nodes)) or (node.name in self.white_nodes)
|
||||
should_be_calibrate = ((node.op_type in self.calibrate_op_types) and
|
||||
(node.name not in self.black_nodes)) or (node.name in self.white_nodes) or ((not self.calibrate_op_types) and (node.name not in self.black_nodes))
|
||||
if should_be_calibrate:
|
||||
for tensor_name in itertools.chain(node.input, node.output):
|
||||
if tensor_name in value_infos.keys():
|
||||
|
|
@ -99,10 +95,10 @@ class ONNXCalibrater:
|
|||
|
||||
# If augmenting all ops, it's possible that some nodes' input value are 0.
|
||||
# Can't reduce on dim with value of 0 if 'keepdims' is false, therefore set keepdims to 1.
|
||||
if augment_all_ops:
|
||||
keepdims_value = 1
|
||||
else:
|
||||
if self.calibrate_op_types:
|
||||
keepdims_value = 0
|
||||
else:
|
||||
keepdims_value = 1
|
||||
|
||||
for tensor in tensors_to_calibrate:
|
||||
# Adding ReduceMin nodes
|
||||
|
|
@ -129,26 +125,28 @@ class ONNXCalibrater:
|
|||
return model
|
||||
|
||||
#Using augmented outputs to generate inputs for quantization
|
||||
def get_intermediate_outputs(self, calib_mode='naive', providers=None):
|
||||
def get_intermediate_outputs(self, calib_mode='naive', providers=None, ort_graph_optimization_enable=True):
|
||||
'''
|
||||
Gather intermediate model outputs after running inference
|
||||
parameter calib_mode: type 'naive' gives (ReduceMin, ReduceMax) pairs
|
||||
for each augmented node across test data sets, where
|
||||
the first element is a minimum of all ReduceMin values
|
||||
and the second element is a maximum of all ReduceMax
|
||||
values;
|
||||
:return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
|
||||
Gather intermediate model outputs after running inference
|
||||
parameter calib_mode: type 'naive' gives (ReduceMin, ReduceMax) pairs
|
||||
for each augmented node across test data sets, where
|
||||
the first element is a minimum of all ReduceMin values
|
||||
and the second element is a maximum of all ReduceMax
|
||||
values;
|
||||
parameter providers: Onnxruntime execution providers
|
||||
parameter ort_graph_optimization_enable: Enable all OnnxRuntime graph optimizations, default = True
|
||||
:return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
|
||||
'''
|
||||
|
||||
#conduct inference session and get intermediate outputs
|
||||
if providers:
|
||||
if ort_graph_optimization_enable:
|
||||
session = onnxruntime.InferenceSession(self.augmented_model_path, None)
|
||||
else:
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL #ORT_ENABLE_BASIC
|
||||
session = onnxruntime.InferenceSession(self.augmented_model_path,
|
||||
sess_options=sess_options,
|
||||
providers=providers)
|
||||
else:
|
||||
session = onnxruntime.InferenceSession(self.augmented_model_path, None)
|
||||
|
||||
#number of outputs in original model
|
||||
num_model_outputs = len(self.model.graph.output)
|
||||
|
|
@ -329,8 +327,8 @@ def calculate_calibration_data(model,
|
|||
calibrator = get_calibrator(model,
|
||||
calibration_data_reader,
|
||||
op_types_to_quantize,
|
||||
nodes_to_exclude,
|
||||
nodes_to_quantize,
|
||||
nodes_to_exclude,
|
||||
augmented_model_path=augmented_model_path)
|
||||
|
||||
if not os.path.exists(augmented_model_path):
|
||||
|
|
@ -365,15 +363,21 @@ def calibrate(model,
|
|||
op_types=['Conv', 'MatMul'],
|
||||
black_nodes=[],
|
||||
white_nodes=[],
|
||||
augmented_model_path='augmented_model.onnx'):
|
||||
augmented_model_path='augmented_model.onnx',
|
||||
providers=["CPUExecutionProvider"],
|
||||
ort_graph_optimization_enable=True,
|
||||
quantization_params_calculation_enable=True):
|
||||
'''
|
||||
Given an onnx model, augment and run the augmented model on calibration data set, aggregate and calculate the quantization parameters.
|
||||
Given an onnx model, augment and run the augmented model on calibration data set, aggregate and calculate the quantization parameters.
|
||||
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
|
||||
:param data_reader: user implemented object to read in and preprocess calibration dataset based on CalibrationDataReader interface
|
||||
:param op_types: operator types to be calibrated and quantized, default = 'Conv,MatMul'
|
||||
:param op_types: operator types to be calibrated and quantized, default = 'Conv,MatMul'. Empty means to quantize all FP32 tensors (except black_nodes)
|
||||
:param black_nodes: operator names that should not be quantized, default = ''
|
||||
:param white_nodes: operator names that force to be quantized, default = ''
|
||||
:param augmented_model_path: save augmented_model to this path
|
||||
:param providers: execution providers to run calibration
|
||||
:param ort_graph_optimization_enable: enable all OnnxRuntime graph optimizations, default = True
|
||||
:param quantization_params_calculation_enable: enable quantization parameter calculation, default = True
|
||||
'''
|
||||
#1. initialize a calibrater
|
||||
calibrater = ONNXCalibrater(model, data_reader, op_types, black_nodes, white_nodes, augmented_model_path)
|
||||
|
|
@ -381,9 +385,11 @@ def calibrate(model,
|
|||
augmented_model = calibrater.augment_graph()
|
||||
onnx.save(augmented_model, augmented_model_path)
|
||||
#3. generate quantization thresholds
|
||||
dict_for_quantization = calibrater.get_intermediate_outputs()
|
||||
dict_for_quantization = calibrater.get_intermediate_outputs(providers=providers, ort_graph_optimization_enable=ort_graph_optimization_enable)
|
||||
#4. generate quantization parameters dict
|
||||
quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
|
||||
|
||||
quantization_params_dict = {}
|
||||
if quantization_params_calculation_enable:
|
||||
quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
|
||||
print("Calibrated,quantized parameters calculated and returned.")
|
||||
return quantization_params_dict
|
||||
|
||||
return quantization_params_dict if quantization_params_calculation_enable else dict_for_quantization
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ def write_calibration_table(calibration_cache):
|
|||
import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
|
||||
import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
|
||||
|
||||
print(calibration_cache)
|
||||
print("calibration cache: ", calibration_cache)
|
||||
|
||||
with open("calibration.json", 'w') as file:
|
||||
file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse
|
||||
|
|
|
|||
Loading…
Reference in a new issue