Add python example of TensorRT INT8 inference on ResNet model (#6255)

* add trt int8 example on resnet model

* Update e2e_tensorrt_resnet_example.py

* remove keras dependency and update class names

* move ImageNetDataReader and ImageClassificationEvaluator to tensorrt resnet example

* simplify e2e_tensorrt_resnet_example.py

* Update preprocessing.py

* merge tensorrt_calibrate

* Update calibrate.py

* Update calibrate.py

* generalize calibrate

* Update calibrate.py

* fix issues

* fix formating

* remove augment_all
This commit is contained in:
stevenlix 2021-01-15 09:59:56 -08:00 committed by GitHub
parent f5a4f7fc2a
commit eab164e1a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 368 additions and 30 deletions

View file

@ -0,0 +1,332 @@
import os
import onnx
import glob
import scipy.io
import numpy as np
from PIL import Image
import onnx
import onnxruntime
from onnxruntime.quantization import CalibrationDataReader, calibrate, write_calibration_table
class ImageNetDataReader(CalibrationDataReader):
def __init__(self, image_folder,
width=224,
height=224,
start_index=0,
end_index=0,
stride=1,
batch_size=1,
model_path='augmented_model.onnx',
input_name='data'):
'''
:param image_folder: image dataset folder
:param width: image width
:param height: image height
:param start_index: start index of images
:param end_index: end index of images
:param stride: image size of each data get
:param batch_size: batch size of inference
:param model_path: model name and path
:param input_name: model input name
'''
self.image_folder = image_folder + "/val"
self.model_path = model_path
self.preprocess_flag = True
self.enum_data_dicts = iter([])
self.datasize = 0
self.width = width
self.height = height
self.start_index = start_index
self.end_index = len(os.listdir(self.image_folder)) if end_index == 0 else end_index
self.stride = stride if stride >= 1 else 1
self.batch_size = batch_size
self.input_name = input_name
def get_dataset_size(self):
return len(os.listdir(self.image_folder))
def get_input_name(self):
if self.input_name:
return
session = onnxruntime.InferenceSession(self.model_path, providers=['CPUExecutionProvider'])
self.input_name = session.get_inputs()[0].name
def get_next(self):
iter_data = next(self.enum_data_dicts, None)
if iter_data:
return iter_data
self.enum_data_dicts = None
if self.start_index < self.end_index:
if self.batch_size == 1:
data = self.load_serial()
else:
data = self.load_batches()
self.start_index += self.stride
self.enum_data_dicts = iter(data)
return next(self.enum_data_dicts, None)
else:
return None
def load_serial(self):
width = self.width
height = self.width
nchw_data_list, filename_list, image_size_list = self.preprocess_imagenet(self.image_folder, height, width, self.start_index, self.stride)
input_name = self.input_name
data = []
for i in range(len(nchw_data_list)):
nhwc_data = nchw_data_list[i]
file_name = filename_list[i]
data.append({input_name: nhwc_data})
return data
def load_batches(self):
width = self.width
height = self.height
batch_size = self.batch_size
stride = self.stride
input_name = self.input_name
batches = []
for index in range(0, stride, batch_size):
start_index = self.start_index + index
nchw_data_list, filename_list, image_size_list = self.preprocess_imagenet(self.image_folder, height, width, start_index, batch_size)
if nchw_data_list.size == 0:
break
nchw_data_batch = []
for i in range(len(nchw_data_list)):
nhwc_data = np.squeeze(nchw_data_list[i], 0)
nchw_data_batch.append(nhwc_data)
batch_data = np.concatenate(np.expand_dims(nchw_data_batch, axis=0), axis=0)
data = {input_name: batch_data}
batches.append(data)
return batches
def preprocess_imagenet(self, images_folder, height, width, start_index=0, size_limit=0):
'''
Loads a batch of images and preprocess them
parameter images_folder: path to folder storing images
parameter height: image height in pixels
parameter width: image width in pixels
parameter start_index: image index to start with
parameter size_limit: number of images to load. Default is 0 which means all images are picked.
return: list of matrices characterizing multiple images
'''
def preprocess_images(input, channels=3, height=224, width=224):
image = input.resize((width, height), Image.ANTIALIAS)
input_data = np.asarray(image).astype(np.float32)
if len(input_data.shape) != 2:
input_data = input_data.transpose([2, 0, 1])
else:
input_data = np.stack([input_data] * 3)
mean = np.array([0.079, 0.05, 0]) + 0.406
std = np.array([0.005, 0, 0.001]) + 0.224
for channel in range(input_data.shape[0]):
input_data[channel, :, :] = (input_data[channel, :, :] / 255 - mean[channel]) / std[channel]
return input_data
image_names = os.listdir(images_folder)
image_names.sort()
if size_limit > 0 and len(image_names) >= size_limit:
end_index = start_index + size_limit
if end_index > len(image_names):
end_index = len(image_names)
batch_filenames = [image_names[i] for i in range(start_index, end_index)]
else:
batch_filenames = image_names
unconcatenated_batch_data = []
image_size_list = []
for image_name in batch_filenames:
image_filepath = images_folder + '/' + image_name
img = Image.open(image_filepath)
image_data = preprocess_images(img)
image_data = np.expand_dims(image_data, 0)
unconcatenated_batch_data.append(image_data)
image_size_list.append(np.array([img.size[1], img.size[0]], dtype=np.float32).reshape(1, 2))
batch_data = np.concatenate(np.expand_dims(unconcatenated_batch_data, axis=0), axis=0)
return batch_data, batch_filenames, image_size_list
def get_synset_id(self, image_folder, offset, dataset_size):
ilsvrc2012_meta = scipy.io.loadmat(image_folder + "/devkit/data/meta.mat")
id_to_synset = {}
for i in range(1000):
id = int(ilsvrc2012_meta["synsets"][i,0][0][0][0])
id_to_synset[id] = ilsvrc2012_meta["synsets"][i,0][1][0]
synset_to_id = {}
file = open(image_folder + "/synset_words.txt","r")
index = 0
for line in file:
parts = line.split(" ")
synset_to_id[parts[0]] = index
index = index + 1
file.close()
file = open(image_folder + "/devkit/data/ILSVRC2012_validation_ground_truth.txt","r")
id = file.read().strip().split("\n")
id = list(map(int, id))
file.close()
image_names = os.listdir(image_folder + "/val")
image_names.sort()
image_names = image_names[offset : offset + dataset_size]
seq_num = []
for file in image_names:
seq_num.append(int(file.split("_")[-1].split(".")[0]))
id = np.array([id[index - 1] for index in seq_num])
synset_id = np.array([synset_to_id[id_to_synset[index]] for index in id])
# one-hot encoding
synset_id_onehot = np.zeros((len(synset_id), 1000), dtype=np.float32)
for i, id in enumerate(synset_id):
synset_id_onehot[i, id] = 1.0
return synset_id_onehot
class ImageClassificationEvaluator:
def __init__(self, model_path, synset_id,
data_reader: CalibrationDataReader,
providers=["TensorrtExecutionProvider"]
):
'''
:param model_path: ONNX model to validate
:param synset_id: ILSVRC2012 synset id
:param data_reader: user implemented object to read in and preprocess calibration dataset
based on CalibrationDataReader Interface
:param providers: ORT execution provider type
'''
self.model_path = model_path
self.data_reader = data_reader
self.providers = providers
self.prediction_result_list = []
self.synset_id = synset_id
def get_result(self):
return self.prediction_result_list
def predict(self):
sess_options = onnxruntime.SessionOptions()
sess_options.log_severity_level = 0
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=self.providers)
inference_outputs_list = []
while True:
inputs = self.data_reader.get_next()
if not inputs:
break
output = session.run(None, inputs)
inference_outputs_list.append(output)
self.prediction_result_list = inference_outputs_list
def top_k_accuracy(self, truth, prediction, k=1):
'''From https://github.com/chainer/chainer/issues/606
'''
y = np.argsort(prediction)[:,-k:]
return np.any(y.T == truth.argmax(axis=1), axis=0).mean()
def evaluate(self, prediction_results):
batch_size = len(prediction_results[0][0])
total_val_images = len(prediction_results) * batch_size
y_prediction = np.empty((total_val_images, 1000), dtype=np.float32)
i = 0
for res in prediction_results:
y_prediction[i:i + batch_size,:] = res[0]
i = i + batch_size
print("top 1: ", self.top_k_accuracy(self.synset_id, y_prediction, k=1))
print("top 5: ", self.top_k_accuracy(self.synset_id, y_prediction, k=5))
def convert_model_batch_to_dynamic(model_path):
model = onnx.load(model_path)
input = model.graph.input
input_name = input[0].name
shape = input[0].type.tensor_type.shape
dim = shape.dim
if not dim[0].dim_param:
dim[0].dim_param = 'N'
model = onnx.shape_inference.infer_shapes(model)
model_name = model_path.split(".")
model_path = model_name[0] + "_dynamic.onnx"
onnx.save(model, model_path)
return [model_path, input_name]
def get_dataset_size(dataset_path, calibration_dataset_size):
total_dataset_size = len(os.listdir(dataset_path + "/val"))
if calibration_dataset_size > total_dataset_size:
print("Warning: calibration data size is bigger than available dataset. Will assign half of the dataset for calibration")
calibration_dataset_size = total_dataset_size // 2
calibration_dataset_size = (calibration_dataset_size // batch_size) * batch_size
if calibration_dataset_size == 0:
print("Warning: No dataset is assigned for calibration. Please use bigger dataset")
prediction_dataset_size = ((total_dataset_size - calibration_dataset_size) // batch_size) * batch_size
if prediction_dataset_size <= 0:
print("Warning: No dataset is assigned for evaluation. Please use bigger dataset")
return [calibration_dataset_size, prediction_dataset_size]
if __name__ == '__main__':
'''
TensorRT EP INT8 Inference on Resnet model
The script is using ILSVRC2012 ImageNet dataset for calibration and prediction.
Please prepare the dataset as below,
1. Create dataset folder 'ILSVRC2012' in workspace.
2. Download ILSVRC2012 validation dataset and development kit from http://www.image-net.org/challenges/LSVRC/2012/downloads.
3. Extract validation dataset JPEG files to 'ILSVRC2012/val'.
4. Extract development kit to 'ILSVRC2012/devkit'. Two files in the development kit are used, 'ILSVRC2012_validation_ground_truth.txt' and 'meta.mat'.
5. Download 'synset_words.txt' from https://github.com/HoldenCaulfieldRye/caffe/blob/master/data/ilsvrc12/synset_words.txt into 'ILSVRC2012/'.
Please download Resnet50 model from ONNX model zoo https://github.com/onnx/models/blob/master/vision/classification/resnet/model/resnet50-v2-7.tar.gz
Untar the model into the workspace
'''
# Dataset settings
model_path = "./resnet50-v2-7.onnx"
ilsvrc2012_dataset_path = "./ILSVRC2012"
augmented_model_path = "./augmented_model.onnx"
batch_size = 20
calibration_dataset_size = 1000 # Size of dataset for calibration
# INT8 calibration setting
calibration_table_generation_enable = True # Enable/Disable INT8 calibration
# TensorRT EP INT8 settings
os.environ["ORT_TENSORRT_FP16_ENABLE"] = "1" # Enable FP16 precision
os.environ["ORT_TENSORRT_INT8_ENABLE"] = "1" # Enable INT8 precision
os.environ["ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
os.environ["ORT_TENSORRT_ENGINE_CACHE_ENABLE"] = "1" # Enable engine caching
execution_provider = ["TensorrtExecutionProvider"]
# Convert static batch to dynamic batch
[new_model_path, input_name] = convert_model_batch_to_dynamic(model_path)
# Get calibration and prediction dataset size
[calibration_dataset_size, prediction_dataset_size] = get_dataset_size(ilsvrc2012_dataset_path, calibration_dataset_size)
# Generate INT8 calibration table
if calibration_table_generation_enable:
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,start_index=0, end_index=calibration_dataset_size, stride=calibration_dataset_size, batch_size=batch_size, model_path=augmented_model_path, input_name=input_name)
# For TensorRT calibration, augment all FP32 tensors (empty op_types), disable ORT graph optimization and skip quantization parameter calculation
calibration_cache = calibrate(new_model_path, data_reader, op_types=[], providers=["CUDAExecutionProvider"], ort_graph_optimization_enable=False, quantization_params_calculation_enable=False)
write_calibration_table(calibration_cache)
# Run prediction in Tensorrt EP
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path, start_index=calibration_dataset_size, end_index=calibration_dataset_size + prediction_dataset_size, stride=prediction_dataset_size, batch_size=batch_size, model_path=new_model_path, input_name=input_name)
synset_id = data_reader.get_synset_id(ilsvrc2012_dataset_path, calibration_dataset_size, prediction_dataset_size) # Generate synset id
evaluator = ImageClassificationEvaluator(new_model_path, synset_id, data_reader, providers=execution_provider)
evaluator.predict()
result = evaluator.get_result()
evaluator.evaluate(result)

View file

@ -43,7 +43,6 @@ class ONNXCalibrater:
:param black_nodes: operator names that should not be quantized, default = ''
:param white_nodes: operator names that force to be quantized, default = ''
:param augmented_model_path: save augmented_model to this path
'''
if isinstance(model, string_types):
self.model = onnx.load(model)
@ -66,7 +65,7 @@ class ONNXCalibrater:
def get_calibration_cache(self):
return self.calibration_cache
def augment_graph(self, augment_all_ops=False):
def augment_graph(self):
'''
Adds ReduceMin and ReduceMax nodes to all quantization_candidates op type nodes in
model and ensures their outputs are stored as part of the graph output
@ -84,11 +83,8 @@ class ONNXCalibrater:
tensors_to_calibrate = set()
for node in model.graph.node:
if augment_all_ops:
should_be_calibrate = True
else:
should_be_calibrate = ((node.op_type in self.calibrate_op_types) and
(node.name not in self.black_nodes)) or (node.name in self.white_nodes)
should_be_calibrate = ((node.op_type in self.calibrate_op_types) and
(node.name not in self.black_nodes)) or (node.name in self.white_nodes) or ((not self.calibrate_op_types) and (node.name not in self.black_nodes))
if should_be_calibrate:
for tensor_name in itertools.chain(node.input, node.output):
if tensor_name in value_infos.keys():
@ -99,10 +95,10 @@ class ONNXCalibrater:
# If augmenting all ops, it's possible that some nodes' input value are 0.
# Can't reduce on dim with value of 0 if 'keepdims' is false, therefore set keepdims to 1.
if augment_all_ops:
keepdims_value = 1
else:
if self.calibrate_op_types:
keepdims_value = 0
else:
keepdims_value = 1
for tensor in tensors_to_calibrate:
# Adding ReduceMin nodes
@ -129,26 +125,28 @@ class ONNXCalibrater:
return model
#Using augmented outputs to generate inputs for quantization
def get_intermediate_outputs(self, calib_mode='naive', providers=None):
def get_intermediate_outputs(self, calib_mode='naive', providers=None, ort_graph_optimization_enable=True):
'''
Gather intermediate model outputs after running inference
parameter calib_mode: type 'naive' gives (ReduceMin, ReduceMax) pairs
for each augmented node across test data sets, where
the first element is a minimum of all ReduceMin values
and the second element is a maximum of all ReduceMax
values;
:return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
Gather intermediate model outputs after running inference
parameter calib_mode: type 'naive' gives (ReduceMin, ReduceMax) pairs
for each augmented node across test data sets, where
the first element is a minimum of all ReduceMin values
and the second element is a maximum of all ReduceMax
values;
parameter providers: Onnxruntime execution providers
parameter ort_graph_optimization_enable: Enable all OnnxRuntime graph optimizations, default = True
:return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
'''
#conduct inference session and get intermediate outputs
if providers:
if ort_graph_optimization_enable:
session = onnxruntime.InferenceSession(self.augmented_model_path, None)
else:
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL #ORT_ENABLE_BASIC
session = onnxruntime.InferenceSession(self.augmented_model_path,
sess_options=sess_options,
providers=providers)
else:
session = onnxruntime.InferenceSession(self.augmented_model_path, None)
#number of outputs in original model
num_model_outputs = len(self.model.graph.output)
@ -329,8 +327,8 @@ def calculate_calibration_data(model,
calibrator = get_calibrator(model,
calibration_data_reader,
op_types_to_quantize,
nodes_to_exclude,
nodes_to_quantize,
nodes_to_exclude,
augmented_model_path=augmented_model_path)
if not os.path.exists(augmented_model_path):
@ -365,15 +363,21 @@ def calibrate(model,
op_types=['Conv', 'MatMul'],
black_nodes=[],
white_nodes=[],
augmented_model_path='augmented_model.onnx'):
augmented_model_path='augmented_model.onnx',
providers=["CPUExecutionProvider"],
ort_graph_optimization_enable=True,
quantization_params_calculation_enable=True):
'''
Given an onnx model, augment and run the augmented model on calibration data set, aggregate and calculate the quantization parameters.
Given an onnx model, augment and run the augmented model on calibration data set, aggregate and calculate the quantization parameters.
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
:param data_reader: user implemented object to read in and preprocess calibration dataset based on CalibrationDataReader interface
:param op_types: operator types to be calibrated and quantized, default = 'Conv,MatMul'
:param op_types: operator types to be calibrated and quantized, default = 'Conv,MatMul'. Empty means to quantize all FP32 tensors (except black_nodes)
:param black_nodes: operator names that should not be quantized, default = ''
:param white_nodes: operator names that force to be quantized, default = ''
:param augmented_model_path: save augmented_model to this path
:param providers: execution providers to run calibration
:param ort_graph_optimization_enable: enable all OnnxRuntime graph optimizations, default = True
:param quantization_params_calculation_enable: enable quantization parameter calculation, default = True
'''
#1. initialize a calibrater
calibrater = ONNXCalibrater(model, data_reader, op_types, black_nodes, white_nodes, augmented_model_path)
@ -381,9 +385,11 @@ def calibrate(model,
augmented_model = calibrater.augment_graph()
onnx.save(augmented_model, augmented_model_path)
#3. generate quantization thresholds
dict_for_quantization = calibrater.get_intermediate_outputs()
dict_for_quantization = calibrater.get_intermediate_outputs(providers=providers, ort_graph_optimization_enable=ort_graph_optimization_enable)
#4. generate quantization parameters dict
quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
quantization_params_dict = {}
if quantization_params_calculation_enable:
quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
print("Calibrated,quantized parameters calculated and returned.")
return quantization_params_dict
return quantization_params_dict if quantization_params_calculation_enable else dict_for_quantization

View file

@ -197,7 +197,7 @@ def write_calibration_table(calibration_cache):
import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
print(calibration_cache)
print("calibration cache: ", calibration_cache)
with open("calibration.json", 'w') as file:
file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse