Improve handling of symbolic dimensions in the onnxruntime_test.py script. (#3959)

If a symbolic dimension is found allow the user to provide a value, or default to 1.

`python .\onnxruntime_test.py --symbolic_dims batch=1,seqlen=4 onnxruntime\test\testdata\transform\fusion\fast_gelu_use_graph_input.onnx`
This commit is contained in:
Scott McKay 2020-05-18 16:51:09 +10:00 committed by GitHub
parent 523d70f667
commit fd8ea4e466
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -31,6 +31,11 @@ def main():
parser.add_argument('num_iters', nargs='?', type=int, default=1000, help='model run iterations. default=1000')
parser.add_argument('--debug', action='store_true', help='pause execution to allow attaching a debugger.')
parser.add_argument('--profile', action='store_true', help='enable chrome timeline trace profiling.')
parser.add_argument('--symbolic_dims', default=None, type=lambda s: dict(x.split("=") for x in s.split(",")),
help='Comma separated name=value pairs for any symbolic dimensions in the model input. '
'e.g. --symbolic_dims batch=1,seqlen=5. '
'If not provided, the value of 1 will be used for all symbolic dimensions.')
args = parser.parse_args()
iters = args.num_iters
@ -50,8 +55,21 @@ def main():
feeds = {}
for input_meta in sess.get_inputs():
# replace any symbolic dimensions (value is None) with 1
shape = [dim if dim else 1 for dim in input_meta.shape]
# replace any symbolic dimensions
shape = []
for dim in input_meta.shape:
if not dim:
# unknown dim
shape.append(1)
elif type(dim) == str:
# symbolic dim. see if we have a value otherwise use 1
if dim in args.symbolic_dims:
shape.append(int(args.symbolic_dims[dim]))
else:
shape.append(1)
else:
shape.append(dim)
if input_meta.type in float_dict:
feeds[input_meta.name] = np.random.rand(*shape).astype(float_dict[input_meta.type])
elif input_meta.type in integer_dict: