Symbolic shape inference exit on models without onnx opset used (#4090)

* Symbolic shape inference exit on models without onnx opset used

* Temporary fix for ConvTranspose with symbolic input dims

Co-authored-by: Changming Sun <me@sunchangming.com>
This commit is contained in:
KeDengMS 2020-06-02 19:39:46 -07:00 committed by GitHub
parent 6f8a4f4cad
commit d63b90538e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1139,6 +1139,13 @@ class SymbolicShapeInference:
self._onnx_infer_single_node(node)
if node.op_type in self.dispatcher_:
self.dispatcher_[node.op_type](node)
elif node.op_type in ['ConvTranspose']:
# onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
# before adding symbolic compute for them
# mark the output type as UNDEFINED to allow guessing of rank
vi = self.known_vi_[node.output[0]]
if len(vi.type.tensor_type.shape.dim) == 0:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
if self.verbose_ > 2:
print(node.op_type + ': ' + node.name)
@ -1246,8 +1253,9 @@ class SymbolicShapeInference:
@staticmethod
def infer_shapes(input_model, output_model, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
in_mp = onnx.load(input_model)
if get_opset(in_mp) < 7:
print('Only support models of opset 7 and above.')
onnx_opset = get_opset(in_mp)
if not onnx_opset or onnx_opset < 7:
print('Only support models of onnx opset 7 and above.')
return
symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
all_shapes_inferred = False
@ -1277,4 +1285,4 @@ if __name__ == '__main__':
print('output model ' + args.output)
print('Doing symbolic shape inference...')
out_mp = SymbolicShapeInference.infer_shapes(args.input, args.output, args.int_max, args.auto_merge, args.guess_output_rank, args.verbose)
print('Done!')
print('Done!')