diff --git a/onnxruntime/python/tools/onnxruntime_test.py b/onnxruntime/python/tools/onnxruntime_test.py index 3542a51f62..3cf0a93b6e 100644 --- a/onnxruntime/python/tools/onnxruntime_test.py +++ b/onnxruntime/python/tools/onnxruntime_test.py @@ -31,6 +31,11 @@ def main(): parser.add_argument('num_iters', nargs='?', type=int, default=1000, help='model run iterations. default=1000') parser.add_argument('--debug', action='store_true', help='pause execution to allow attaching a debugger.') parser.add_argument('--profile', action='store_true', help='enable chrome timeline trace profiling.') + parser.add_argument('--symbolic_dims', default=None, type=lambda s: dict(x.split("=") for x in s.split(",")), + help='Comma separated name=value pairs for any symbolic dimensions in the model input. ' + 'e.g. --symbolic_dims batch=1,seqlen=5. ' + 'If not provided, the value of 1 will be used for all symbolic dimensions.') + args = parser.parse_args() iters = args.num_iters @@ -50,8 +55,21 @@ def main(): feeds = {} for input_meta in sess.get_inputs(): - # replace any symbolic dimensions (value is None) with 1 - shape = [dim if dim else 1 for dim in input_meta.shape] + # replace any symbolic dimensions + shape = [] + for dim in input_meta.shape: + if not dim: + # unknown dim + shape.append(1) + elif type(dim) == str: + # symbolic dim. see if we have a value otherwise use 1 + if dim in args.symbolic_dims: + shape.append(int(args.symbolic_dims[dim])) + else: + shape.append(1) + else: + shape.append(dim) + if input_meta.type in float_dict: feeds[input_meta.name] = np.random.rand(*shape).astype(float_dict[input_meta.type]) elif input_meta.type in integer_dict: