onnxruntime/orttraining/orttraining/python/training
pengwa fe0db63dee
Upstream reshape of merging batch/sequence (#15023)
### Upstream reshape of merging batch/sequence

For a Reshape node that meets the following requirements:
- input data rank = 3
- the shape input is a constant initializer, and the untouched dim value MUST be a constant value.
- Reshape merges the first two dimensions (batch and sequence), so output data rank = 2.
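The eligibility conditions above can be sketched as a small predicate. This is a hypothetical pure-Python illustration, not ONNX Runtime's implementation. The function name `is_merge_batch_seq_reshape` and the shape encoding (ints for known dims, strings for symbolic dims, `None` when the shape input is not a constant initializer) are assumptions made for this example.

```python
from typing import Optional, Sequence, Union

# Hypothetical encoding: int for a known dim, str for a symbolic dim like "batch".
Dim = Union[int, str]


def is_merge_batch_seq_reshape(
    input_shape: Sequence[Dim],
    target_shape: Optional[Sequence[int]],
) -> bool:
    """Hypothetical check mirroring the listed conditions.

    input_shape: shape of the Reshape data input (may contain symbolic dims).
    target_shape: values of the Reshape shape input when it is a constant
        initializer, otherwise None.
    """
    # The data input must be rank 3, e.g. [batch, sequence, hidden].
    if len(input_shape) != 3:
        return False
    # The shape input must be a constant initializer.
    if target_shape is None:
        return False
    # The output must be rank 2, i.e. the first two dims are merged.
    if len(target_shape) != 2:
        return False
    # The untouched (last) dim must be a concrete constant and kept as-is.
    last_dim = input_shape[2]
    if not isinstance(last_dim, int) or target_shape[1] != last_dim:
        return False
    # The merged dim is either inferred (-1) or equal to batch * sequence.
    merged = target_shape[0]
    if merged == -1:
        return True
    batch, seq = input_shape[0], input_shape[1]
    return isinstance(batch, int) and isinstance(seq, int) and merged == batch * seq


print(is_merge_batch_seq_reshape(["batch", "seq", 1024], [-1, 1024]))
```

Treating `-1` as the only symbolic value for the merged dim is a simplification; the real transformer may accept other encodings.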

We upstream it so that it runs as early as possible. This allows us to upstream other operators (such as Gather) that are blocked by this kind of Reshape node.
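Why the reordering is safe, and why it unblocks Gather, can be illustrated with a bias Add that broadcasts over the last (untouched) dim. This is a hypothetical pure-Python sketch with assumed helper names (`merge_batch_seq`, `add_bias_3d`, `add_bias_2d`, `gather_rows`) and made-up data. It shows that the result is identical whether the Reshape runs before or after the Add. Once the Reshape is upstream, a Gather along the merged axis (for example, one selecting a subset of tokens) can also move above the Add, so the Add processes fewer rows.

```python
from typing import List

Matrix = List[List[float]]
Tensor3 = List[Matrix]


def merge_batch_seq(x: Tensor3) -> Matrix:
    """Reshape [B, S, H] -> [B*S, H] in row-major order, like ONNX Reshape."""
    return [row for seq in x for row in seq]


def add_bias_3d(x: Tensor3, bias: List[float]) -> Tensor3:
    """Elementwise Add on [B, S, H] with bias broadcast over the last dim."""
    return [[[v + b for v, b in zip(row, bias)] for row in seq] for seq in x]


def add_bias_2d(x: Matrix, bias: List[float]) -> Matrix:
    """Elementwise Add on [B*S, H] with bias broadcast over the last dim."""
    return [[v + b for v, b in zip(row, bias)] for row in x]


def gather_rows(x: Matrix, indices: List[int]) -> Matrix:
    """Gather along axis 0 (the merged batch*sequence axis)."""
    return [x[i] for i in indices]


# Made-up input with B=2, S=3, H=2.
x = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
     [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]]
bias = [0.5, -0.5]
valid = [0, 4]  # hypothetical token indices on the merged axis

# Original order: Add on [B, S, H], then Reshape, then Gather.
before = gather_rows(merge_batch_seq(add_bias_3d(x, bias)), valid)
# Reshape upstreamed: Reshape first, then Add on [B*S, H], then Gather.
reshape_up = gather_rows(add_bias_2d(merge_batch_seq(x), bias), valid)
# Gather upstreamed as well: the Add now runs on 2 rows instead of 6.
gather_up = add_bias_2d(gather_rows(merge_batch_seq(x), valid), bias)

print(before == reshape_up == gather_up)
```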

It is not enabled in `graph_transformer_utils` yet, because the combined upstream Gather changes are not ready.

Before:


![image](https://user-images.githubusercontent.com/10530022/224698252-f9705082-9710-4385-95ec-f1ccf50dc0e3.png)


After:


![image](https://user-images.githubusercontent.com/10530022/224698381-7e124d0d-ba47-4f35-8e37-6015014cd1c4.png)
2023-04-05 18:51:07 +08:00
| Name | Last commit | Date |
| --- | --- | --- |
| amp | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| api | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| experimental | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| onnxblock | Miscellaneous updates to training artifact generation (#15315) | 2023-04-04 20:09:51 -07:00 |
| optim | Enable pylint and numpy rules (#15218) | 2023-03-27 20:37:53 -07:00 |
| ortmodule | Upstream reshape of merging batch/sequence (#15023) | 2023-04-05 18:51:07 +08:00 |
| torchdynamo | Enable pylint and numpy rules (#15218) | 2023-03-27 20:37:53 -07:00 |
| utils | Enable pylint and numpy rules (#15218) | 2023-03-27 20:37:53 -07:00 |
| __init__.py | Refining the offline tooling for training artifact generation (#15212) | 2023-03-30 18:05:51 -07:00 |
| _checkpoint_storage.py | Enable pylint and numpy rules (#15218) | 2023-03-27 20:37:53 -07:00 |
| _utils.py | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| artifacts.py | Refining the offline tooling for training artifact generation (#15212) | 2023-03-30 18:05:51 -07:00 |
| checkpoint.py | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| model_desc_validation.py | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| orttrainer.py | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| orttrainer_options.py | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| postprocess.py | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |