Replace TensorRT's deprecated API in caffe2/python/trt/test_pt_onnx_trt.py (#60236)

Summary:
TensorRT v8 is going to remove some functions/methods that used in test.

ref:
- getMaxWorkspaceSize deprecation: b2d60b6e10/include/NvInfer.h (L6984-L6993)
- buildCudaEngine deprecation: b2d60b6e10/include/NvInfer.h (L7079-L7087)

cc ptrblck

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60236

Reviewed By: gchanan

Differential Revision: D29232376

Pulled By: ngimel

fbshipit-source-id: 2b8a48787bf61c68a81568b6026d6afd5a83e751
This commit is contained in:
Masaki Kozuki 2021-06-19 19:55:26 -07:00 committed by Facebook GitHub Bot
parent 5ec4ad7f54
commit c19acf816f

View file

@ -61,16 +61,18 @@ class Test_PT_ONNX_TRT(unittest.TestCase):
self.image_files[index] = os.path.abspath(os.path.join(data_path, f))
if not os.path.exists(self.image_files[index]):
raise FileNotFoundError(self.image_files[index] + " does not exist.")
self.labels = open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r').read().split('\n')
with open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r') as f:
self.labels = f.read().split('\n')
def build_engine_onnx(self, model_file):
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags = 1) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
builder.max_workspace_size = 1 << 33
builder_config = builder.create_builder_config()
builder_config.max_workspace_size = 1 << 33
with open(model_file, 'rb') as model:
if not parser.parse(model.read()):
for error in range(parser.num_errors):
self.fail("ERROR: {}".format(parser.get_error(error)))
return builder.build_cuda_engine(network)
return builder.build_engine(network, builder_config)
def _test_model(self, model_name, input_shape = (3, 224, 224), normalization_hint = 0):