mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Improve handling of symbolic dimensions in the onnxruntime_test.py script. (#3959)
If a symbolic dimension is found allow the user to provide a value, or default to 1. `python .\onnxruntime_test.py --symbolic_dims batch=1,seqlen=4 onnxruntime\test\testdata\transform\fusion\fast_gelu_use_graph_input.onnx`
This commit is contained in:
parent
523d70f667
commit
fd8ea4e466
1 changed files with 20 additions and 2 deletions
|
|
@ -31,6 +31,11 @@ def main():
|
|||
parser.add_argument('num_iters', nargs='?', type=int, default=1000, help='model run iterations. default=1000')
|
||||
parser.add_argument('--debug', action='store_true', help='pause execution to allow attaching a debugger.')
|
||||
parser.add_argument('--profile', action='store_true', help='enable chrome timeline trace profiling.')
|
||||
parser.add_argument('--symbolic_dims', default=None, type=lambda s: dict(x.split("=") for x in s.split(",")),
|
||||
help='Comma separated name=value pairs for any symbolic dimensions in the model input. '
|
||||
'e.g. --symbolic_dims batch=1,seqlen=5. '
|
||||
'If not provided, the value of 1 will be used for all symbolic dimensions.')
|
||||
|
||||
args = parser.parse_args()
|
||||
iters = args.num_iters
|
||||
|
||||
|
|
@ -50,8 +55,21 @@ def main():
|
|||
|
||||
feeds = {}
|
||||
for input_meta in sess.get_inputs():
|
||||
# replace any symbolic dimensions (value is None) with 1
|
||||
shape = [dim if dim else 1 for dim in input_meta.shape]
|
||||
# replace any symbolic dimensions
|
||||
shape = []
|
||||
for dim in input_meta.shape:
|
||||
if not dim:
|
||||
# unknown dim
|
||||
shape.append(1)
|
||||
elif type(dim) == str:
|
||||
# symbolic dim. see if we have a value otherwise use 1
|
||||
if dim in args.symbolic_dims:
|
||||
shape.append(int(args.symbolic_dims[dim]))
|
||||
else:
|
||||
shape.append(1)
|
||||
else:
|
||||
shape.append(dim)
|
||||
|
||||
if input_meta.type in float_dict:
|
||||
feeds[input_meta.name] = np.random.rand(*shape).astype(float_dict[input_meta.type])
|
||||
elif input_meta.type in integer_dict:
|
||||
|
|
|
|||
Loading…
Reference in a new issue