Update ORT format model conversion utility to optionally fail fast on model conversion failure. (#8589)

This commit is contained in:
Edward Chen 2021-08-03 11:12:56 -07:00 committed by GitHub
parent deab284e4c
commit e09321f4db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -54,7 +54,7 @@ def _create_session_options(optimization_level: ort.GraphOptimizationLevel,
def _convert(model_path_or_dir: pathlib.Path, optimization_level_str: str, use_nnapi: bool, use_coreml: bool,
custom_op_library: pathlib.Path, create_optimized_onnx_model: bool):
custom_op_library: pathlib.Path, create_optimized_onnx_model: bool, allow_conversion_failures: bool):
optimization_level = _get_optimization_level(optimization_level_str)
@ -124,6 +124,8 @@ def _convert(model_path_or_dir: pathlib.Path, optimization_level_str: str, use_n
# onnx_target_path, ort_target_path, orig_size, new_size, new_size - orig_size, new_size / orig_size))
except Exception as e:
print("Error converting {}: {}".format(model, e))
if not allow_conversion_failures:
raise
num_failures += 1
print("Converted {} models. {} failures.".format(len(models), num_failures))
@ -192,6 +194,9 @@ def parse_args():
help='Save the optimized version of each ONNX model. '
'This will have the same optimizations applied as the ORT format model.')
parser.add_argument('--allow_conversion_failures', action='store_true',
help='Whether to proceed after encountering model conversion failures.')
parser.add_argument('model_path_or_dir', type=pathlib.Path,
help='Provide path to ONNX model or directory containing ONNX model/s to convert. '
'All files with a .onnx extension, including in subdirectories, will be processed.')
@ -222,7 +227,7 @@ def convert_onnx_models_to_ort():
raise ValueError('The CoreML Execution Provider was not included in this build of ONNX Runtime.')
_convert(model_path_or_dir, args.optimization_level, args.use_nnapi, args.use_coreml, custom_op_library,
args.save_optimized_onnx_model)
args.save_optimized_onnx_model, args.allow_conversion_failures)
_create_config_file_from_ort_models(model_path_or_dir, args.optimization_level, args.enable_type_reduction)