Commit graph

7 commits

Author SHA1 Message Date
KeDengMS
ee908eb0aa
Symbolic shape inference: fix rank for ConstantOfShape (#5912) 2020-11-24 14:50:41 -08:00
KeDengMS
32bf6390ad
Some fixes to symbolic shape inference (#5642)
* Some fixes to symbolic shape inference

1. Topological sort before iteration in graph
2. Fix a case in slice: start=100000, end=-100000, step=-1, dim=2
3. Fix Nuphar Gemm test's random seed
4. Slice opset 1 axes is optional
2020-10-30 19:28:47 -07:00
KeDengMS
e1a54c4090
Symbolic shape inference: fix a bug in shape merge (#5519)
* Symbolic shape inference: fix a bug in shape merge

OpType Where:
input0: ['mt_src_tokens_batch', 1, 1, 'mt_src_tokens_len']
input1: []
input2: ['mt_prev_output_tokens_batch', 12, 'mt_prev_output_tokens_len', 'floor(mt_src_tokens_batch*mt_src_tokens_len/mt_prev_output_tokens_batch)'] 1
output: [None, 12, 'mt_prev_output_tokens_len', None]

* Undo unintended TRT change
2020-10-16 17:54:57 -07:00
KeDengMS
7495dc167a
Symbolic shape inference: fix a bug in auto_merge when broadcasting (#5349)
The bug happens when merging following shapes:

input0: [1, 1, 'Min(1024, input1_dynamic_axes_3)', 'Min(1024, input1_dynamic_axes_3)']
input1: ['input1_dynamic_axes_1*input1_dynamic_axes_2', 12, 'input1_dynamic_axes_3', 'input1_dynamic_axes_3']
input2: []

The fix is to avoid broadcasting merge on input2
2020-10-01 15:24:00 -07:00
KeDengMS
5a71819be6
Symbolic shape inference: fix a case for concat (#5277)
* Symbolic shape inference: fix a case when concat requires merge multiple dims

* Fix a bug triggered in newer version of sympy
Fix a bug in output data type guessing
2020-09-24 08:16:47 -07:00
KeDengMS
8dceebda0e
[Training/Python] Add option to enable symbolic shape inference (#5107)
This change adds symbolic shape inference to ORT training which helps static memory planning for model like BART.
2020-09-22 10:49:07 -07:00
KeDengMS
ce3b67e0cd
[Python] Move symbolic_shape_infer from nuphar to tools (#5162)
* [Python] Move symbolic shape inference from nuphar to tools

* Fix PEP8 ERROR
2020-09-18 09:31:06 -07:00
Renamed from onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py (Browse further)