diff --git a/caffe2/python/trt/test_pt_onnx_trt.py b/caffe2/python/trt/test_pt_onnx_trt.py index e066e8363a1..3ba67ea2124 100644 --- a/caffe2/python/trt/test_pt_onnx_trt.py +++ b/caffe2/python/trt/test_pt_onnx_trt.py @@ -61,16 +61,18 @@ class Test_PT_ONNX_TRT(unittest.TestCase): self.image_files[index] = os.path.abspath(os.path.join(data_path, f)) if not os.path.exists(self.image_files[index]): raise FileNotFoundError(self.image_files[index] + " does not exist.") - self.labels = open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r').read().split('\n') + with open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r') as f: + self.labels = f.read().split('\n') def build_engine_onnx(self, model_file): with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags = 1) as network, trt.OnnxParser(network, TRT_LOGGER) as parser: - builder.max_workspace_size = 1 << 33 + builder_config = builder.create_builder_config() + builder_config.max_workspace_size = 1 << 33 with open(model_file, 'rb') as model: if not parser.parse(model.read()): for error in range(parser.num_errors): self.fail("ERROR: {}".format(parser.get_error(error))) - return builder.build_cuda_engine(network) + return builder.build_engine(network, builder_config) def _test_model(self, model_name, input_shape = (3, 224, 224), normalization_hint = 0):