Create profile for all dynamic shape input tensors (#5229)

This commit is contained in:
stevenlix 2020-09-20 05:55:21 -07:00 committed by GitHub
parent cd663d58f5
commit aefb2cc49b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -995,14 +995,12 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
dimension_update[input_name] = true;
}
if (dimension_update[input_name]) {
if (trt_profile == nullptr) {
trt_profile = trt_builder->createOptimizationProfile();
}
trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size);
trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size);
trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size);
if (trt_profile == nullptr) {
trt_profile = trt_builder->createOptimizationProfile();
}
trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size);
trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size);
trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size);
} else { // execution tensor
nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims);
for (int j = 0, end = nb_dims; j < end; ++j) {
@ -1030,14 +1028,12 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
}
}
if (dimension_update[input_name]) {
if (trt_profile == nullptr) {
trt_profile = trt_builder->createOptimizationProfile();
}
trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
if (trt_profile == nullptr) {
trt_profile = trt_builder->createOptimizationProfile();
}
trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
}
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
}