From d63b90538eb97ac281f7a28a809759846901e5d0 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Tue, 2 Jun 2020 19:39:46 -0700 Subject: [PATCH] Symbolic shape inference exit on models without onnx opset used (#4090) * Symbolic shape inference exit on models without onnx opset used * Temporary fix for ConvTranspose with symbolic input dims Co-authored-by: Changming Sun --- .../nuphar/scripts/symbolic_shape_infer.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py index 2a339b9c28..65b3d6ccae 100755 --- a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py +++ b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py @@ -1139,6 +1139,13 @@ class SymbolicShapeInference: self._onnx_infer_single_node(node) if node.op_type in self.dispatcher_: self.dispatcher_[node.op_type](node) + elif node.op_type in ['ConvTranspose']: + # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input + # before adding symbolic compute for them + # mark the output type as UNDEFINED to allow guessing of rank + vi = self.known_vi_[node.output[0]] + if len(vi.type.tensor_type.shape.dim) == 0: + vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED if self.verbose_ > 2: print(node.op_type + ': ' + node.name) @@ -1246,8 +1253,9 @@ class SymbolicShapeInference: @staticmethod def infer_shapes(input_model, output_model, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0): in_mp = onnx.load(input_model) - if get_opset(in_mp) < 7: - print('Only support models of opset 7 and above.') + onnx_opset = get_opset(in_mp) + if not onnx_opset or onnx_opset < 7: + print('Only support models of onnx opset 7 and above.') return symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose) all_shapes_inferred = False @@ -1277,4 +1285,4 @@ if __name__ == '__main__': print('output model ' + args.output) print('Doing symbolic shape inference...') out_mp = SymbolicShapeInference.infer_shapes(args.input, args.output, args.int_max, args.auto_merge, args.guess_output_rank, args.verbose) - print('Done!') \ No newline at end of file + print('Done!')