mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Symbolic shape inference exit on models without onnx opset used (#4090)
* Symbolic shape inference exit on models without onnx opset used * Temporary fix for ConvTranspose with symbolic input dims Co-authored-by: Changming Sun <me@sunchangming.com>
This commit is contained in:
parent
6f8a4f4cad
commit
d63b90538e
1 changed files with 11 additions and 3 deletions
|
|
@ -1139,6 +1139,13 @@ class SymbolicShapeInference:
|
|||
self._onnx_infer_single_node(node)
|
||||
if node.op_type in self.dispatcher_:
|
||||
self.dispatcher_[node.op_type](node)
|
||||
elif node.op_type in ['ConvTranspose']:
|
||||
# onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
|
||||
# before adding symbolic compute for them
|
||||
# mark the output type as UNDEFINED to allow guessing of rank
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
if len(vi.type.tensor_type.shape.dim) == 0:
|
||||
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
|
||||
|
||||
if self.verbose_ > 2:
|
||||
print(node.op_type + ': ' + node.name)
|
||||
|
|
@ -1246,8 +1253,9 @@ class SymbolicShapeInference:
|
|||
@staticmethod
|
||||
def infer_shapes(input_model, output_model, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
|
||||
in_mp = onnx.load(input_model)
|
||||
if get_opset(in_mp) < 7:
|
||||
print('Only support models of opset 7 and above.')
|
||||
onnx_opset = get_opset(in_mp)
|
||||
if not onnx_opset or onnx_opset < 7:
|
||||
print('Only support models of onnx opset 7 and above.')
|
||||
return
|
||||
symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
|
||||
all_shapes_inferred = False
|
||||
|
|
@ -1277,4 +1285,4 @@ if __name__ == '__main__':
|
|||
print('output model ' + args.output)
|
||||
print('Doing symbolic shape inference...')
|
||||
out_mp = SymbolicShapeInference.infer_shapes(args.input, args.output, args.int_max, args.auto_merge, args.guess_output_rank, args.verbose)
|
||||
print('Done!')
|
||||
print('Done!')
|
||||
|
|
|
|||
Loading…
Reference in a new issue