### Flash attn recompute
1. Allow PythonOp(FlashAttn) can be recomputed correctly.
45879ff5c2
2. Use JSON to pass the selected-to-recompute subgraphs.
3c374da678
#### Better Memory Efficiency
Customer model can run both PyTorch SPDA and Flash Attn, this PR make it
possible to let the Flash Attn path work with ORTModule layerwise
recompute. The peak drop from 45.xGB to 32.xGB if we only compare the
layers (not including other pieces, BTW there are few more optimization
targeting other pieces as well later).
#### Better Perf
Using Flash ATTN bring additionally 16% end to end time reduction, with
highly aligned loss curve.

#### Use JSON File to pass Recompute Plans
To overcome the limitation of max length of the strings defined in
session options.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Description
The InsertGatherBeforeSceLoss optimization is enabled when the density
of label padding less than 90%. We need to check the density of the
label padding to decide whether enable the optimization.
Before this pr, we just check the inputs of graph and correlate one with
the SCE node by iterate graph from the SCE node back to one graph input.
This is hard to be general because there may be complicated pattern
between graph input and SCE node.
This pr check padding density by the direct input of SCE module rather
than the input of graph at the first graph execution when exporting onnx
graph.
And if the density < 90%, insert a flag PythonOp after the SCE node as:
```
SoftmaxCrossEntropy
|
PythonOp (func_name: FlagAndPrintDensity) (insert if density < 90%)
|
Following graph
```
When the InsertGatherBeforeSceLoss is invoked, it check if there is the
flag PythonOp(func_name: FlagAndPrintDensity) after the SCE node and if
it is, remove it and do the padding elimination optimization.
If the env of ORTMODULE_PRINT_INPUT_DENSITY is 1, we will print input
density each step by the PythonOp (func_name: FlagAndPrintDensity). In
this case the PythonOp will not be removed.
### Improve perf for mem efficient grad mgmt
When memory efficient gradient mangement feature is enabled, the weight
retrieval PythonOp for every layers will be launched at the beginning of
the forward, which would make GPU stream idle for few milliseconds. The
reason is the ReversedDFS ordering cannot ALWAYS handle such input
branching well, so we introduce a distantance-to-input_leaf concepts
when doing the reversedDFS, which not only move the problematical
PythonOp to the place where it is needed, but also those Cast ops
following the weight retrieval to the place where it is needed.
Main branch: 102.19 - 26.35s = 75.84s for 260 steps(4627samples),
61.04sample/second
This PR: 100.28s - 25.10s = 75.18s for 260 steps. 61.54samples/second
(+0.8% gains)
Main branch:

This PR:

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Description
Add bf16 support for below ops:
ConstantOfShape
Exp
Erf
convolution
PythonOp
### Motivation and Context
phimm model works on bf16, ORT need support bf16 on previous ops to work
with phimm on bf16
### Description
In Deepspeed's Pipeline Parallel Implementation, there is a class used
to instantiate the object after it's moved to the device and assigned in
a stage.
This approach helps reduce peak memory usage.
In this PR, we're adding support to ORT for wrapping this LayerSpec.
### Introduce memory efficient topo sort (for training)
~~and laze initialize Priority-Based and Memory-Efficient topo sort.
Because in most cases, they are not needed, so we free the overheads of
GraphViewer construction for most use cases.~~
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Description
Introducing a new class ORTPipelineModule to handle wrapping layers in
DeepSpeed pipeline parallel.
### Motivation and Context
To support pipeline parallelism on ORTModule.
This PR will include an initial support of deepspeed Pipeline
parallelism.
- [x] Support Pipeline parallel where layers are nn Modules in
Sequential.
- [ ] Support LayerSpec and TiedLayerSpec
- [ ] Enable partitioning to accept List
- [ ] Full-GPU Graph Consolidation
- [ ] Subgraph Merging for Inference
Previous implementation used numpy array and numpy data_type to store
constant value and data type, which is not support BFloat16 natively.
This PR is to switch to use torch tensor which supports BFloat16.
### Description
The PaddingElimination optimization is enabled when the density of
embedding padding less than 90%. We need to check the density of the
embedding padding to decide whether enable the optimization.
Before this pr, we just check the inputs of graph and correlate one with
the embedding node by iterate graph from the embedding node back to one
graph input.
This is hard to be general because there may be complicated pattern
between graph input and embedding node.
This pr check padding density by the direct input of embedding module
rather than the input of graph at the first graph execution when
exporting onnx graph.
And if the density < 90%, insert a flag PythonOp after the embedding
node as:
```
Embedding
|
PythonOp (func_name:_FlagPaddingElimination) (insert if density < 90%)
|
Following graph
```
When the PaddingElimination is invoked, it check if there is the flag
PythonOp(func_name:_FlagPaddingElimination) after the Embedding node and
if it is, remove it and do the padding elimination optimization.
### Prompt layer-wise when applicable
Give explicit prompts in export failures to users to enable layer-wise
memory optimization if we found the checkpoint function is used.
- Using checkpoint function is a strong indicator that the model is too
large to fit in GPU memory.
- If we don't override the checkpoint function here, mostly ONNX export
will be failed. 1. For old version PyTorch, when handling gradient
checkpoint feature, we just throw an exception. 2. For new version
PyTorch, an export failure happens.
- But both failures did not give users explicitly "HOW" to mitigate.
This PR did that.
``

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Fix memory stats printing
The mmeory stats printing is failed when module is in eval mode, doing
ORTModule wrap. At that time, runtime inspector for training manager
should have training model being true, but got a false (because existing
logic get the boolean from module.training). Runtime inspector as part
of training manager or inference manager should know it is serving
training or not explicitly, so we cannot depend on the stat of
module.training during ORTModule initialization.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Description
Why we need to define softmax export logic here?
For the usage `nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32)` in the model,
76a33a1092/src/transformers/models/mistral/modeling_mistral.py (L302)
If dtype is specified, the input tensor is casted to dtype before the
operation is performed.
This is useful for preventing data type overflows. While existing ONNX
exporter do the cast after the operation, which is not correct.
(cf06189a2d/torch/onnx/symbolic_opset13.py (L27)).
This override can be a workaround before PyTorch fix the issues in
coming releases.
(TODO: pengwa - add PyTorch versions when the issue is fixed).
@thiagocrepaldi We may need a fix in PyTorch repo as well.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Fix torch cpp extension build warnings
For the warnings shown as below:
```
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
[4/5] c++ -MMD -MF /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.o.d -pthread -B /opt/conda/envs/ptca/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/TH -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/THC -I/opt/conda/envs/ptca/include/python3.8 -c -c /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc -o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.o -O3 -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=torch_interop_utils -D_GLIBCXX_USE_CXX11_ABI=0
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
In file included from /opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/python_arg_parser.h:65,
from /opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/tensor_new.h:4,
from /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc:9:
/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/python_strings.h:104:19: warning: ‘pybind11::object PyObject_FastGetAttrString(PyObject*, const char*)’ defined but not used [-Wunused-function]
104 | static py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) {
| ^~~~~~~~~~~~~~~~~~~~~~~~~~
[5/5] c++ -MMD -MF /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.o.d -pthread -B /opt/conda/envs/ptca/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/TH -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/THC -I/opt/conda/envs/ptca/include/python3.8 -c -c /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc -o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.o -O3 -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=torch_interop_utils -D_GLIBCXX_USE_CXX11_ABI=0
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
In file included from /opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/python_arg_parser.h:65,
from /opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/tensor_new.h:4,
from /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc:13:
/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/python_strings.h:104:19: warning: ‘pybind11::object PyObject_FastGetAttrString(PyObject*, const char*)’ defined but not used [-Wunused-function]
104 | static py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) {
| ^~~~~~~~~~~~~~~~~~~~~~~~~~
g++ -pthread -B /opt/conda/envs/ptca/compiler_compat -Wl,--sysroot=/ -pthread -shared -B /opt/conda/envs/ptca/compiler_compat -L/opt/conda/envs/ptca/lib -Wl,-rpath=/opt/conda/envs/ptca/lib -Wl,--no-as-needed -Wl,--sysroot=/ /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.o -L/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -o build/lib.linux-x86_64-cpython-38/torch_interop_utils.cpython-38-x86_64-linux-gnu.so
Installing /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/lib.linux-x86_64-cpython-38/fused_ops.cpython-38-x86_64-linux-gnu.so -> /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/fused_ops.cpython-38-x86_64-linux-gnu.so
Installing /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/lib.linux-x86_64-cpython-38/aten_op_executor.cpython-38-x86_64-linux-gnu.so -> /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor.cpython-38-x86_64-linux-gnu.so
Installing /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/lib.linux-x86_64-cpython-38/torch_gpu_allocator.cpython-38-x86_64-linux-gnu.so -> /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator.cpython-38-x86_64-linux-gnu.so
Installing /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/lib.linux-x86_64-cpython-38/torch_interop_utils.cpython-38-x86_64-linux-gnu.so -> /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_interop_utils.cpython-38-x86_64-linux-gnu.so
```
Fix by replacing eixsting `PyObject_GetAttrString` with
`PyObject_FastGetAttrString` which claims to be faster in its
implementation comment.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Fix and enable few ORTModule Unit Tests
Fix 'test_bert_inputs_with_dynamic_shape' and
'test_bert_result_with_layerwise_recompute' generate Nan loss in ORT
run.
The root cause is, the logic to generatic attention mask test data is
not correct, only 0 or 1 is allowed in the dataset, but we see lots of
other numbers. ( The reason we don't have this using old version of
transformers for example v4.4.2 or 4.16.2 is because they don't contains
such
d3cb28886a,
which increase the scaling to a bigger number, causing a overflow to
inf)
Another improvement during the investigation using convergence tools:
Don't dump the activations during model export phase, otherwise, the
dumped data might contains some PyTorch run's result making us confused
during comparing with stock PyTorch run results.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This PR:
- add support for int as return type, will create a CPU scalar tensor
for it.
- add attributes to specify which arguments or returns are CPU tensors.
- adjust ATen efficient attn to match latest PyTorch native function.
- a Triton codegen bugfix by the way.
### Fix seed for recomputed Dropout
If Dropout node is recomputed in the backward, we should make sure its
execution is same as the run in the forward.
If we don't set seed attribute, then this cannot be guaranteed.
Add ` export ORTMODULE_MEMORY_OPT_LEVEL=2` to enabled per layer
recompute with compromised recomputable subgraphs.
# loss function extra inputs.
Currently, the loss functions in onnxblock expect exactly two inputs in
their build method.
Occasionally, models may pass additional inputs, causing the build
function to fail.
To solve this issue, we can let users pass a list of loss input names to
be used in the loss function.
Including removing a unnecessary assert, and add support of passing
string attribute from ONNX node attribute to python functoin kwargs
(mainly for passing debug info from graph to python for now).
### Description
<!-- Describe your changes. -->
This PR upgrades ORTModule's default opset from 15 to 17. Opset 17 is
the final opset supported by torchscript exporter
(https://github.com/pytorch/pytorch/pull/107829)
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Engineering excellence contribution for ORT Training DRI.
---------
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
### Description
<!-- Describe your changes. -->
Add ATen fallback support for bicubic interpolation algorithm.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Required for facebook/dinov2 model architecture as part of ONNX Runtime
integration with AML Vision models.
When using scaled_dot_product_attention on float16 type, the exported
graph has Sqrt(float16(constant)), which cannot be ConstantFold in ORT
because Sqrt CPU kernel doesn't support float16. This causes Triton
code-gen generates code like:
result = 128.0.to(tl.float32)
This code cannot be compiled because .to() cannot be applied to
constant.
This PR is to handle such case that constant number will not do the
Cast.
### Fix missing subgraph candidates for recompute
For subgraphs for example `MatMul+Transpose+Reshape`, since the ending
node is a Reshape, in ORT, it is reusing input buffers.
Currently, the subgraph detection logic has defect, as a result, those
subgraphs will be missing as recompute candidates.
Also append a few more node types for recompute support.
TODO: add unit test later. This PR is needed for a customer model now.
### Description
<!-- Describe your changes. -->
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Improve perf for stage3 training - first wave
Port existing PythonOp/PythonOpGrad python runner to C++, also introduce
an unsafe run mode (to skip inplace, save for backward, materrialized
grad detection on the fly).
This reduce the overhead from XX~XXX us to X ~ lower end of XX us . In
LLAMA2 7B training with 8x32GV100, we have observed 6.7% gains over
PyTorch. (1.59 v.s. 1.49it/s)
Peak memory also dropped from 31GB to 28GB.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Allow layer-wise recompute
Early, we need users/developers to specify the subgraphs to recompute,
now we introduced a more user-friendly way to enable recompute for all
detected stashed activation recomputation subgraphs. This scarifies
getting the best configs while makes it easier to support user
requirements when they switches from PyTorch per-layer gradient
checkpoint to ORTModule.
`ORTMODULE_MEMORY_OPT_LEVEL` is introduced to control the usage, by
default, it is 0, e.g. `USER_SPECIFIED`, all subgraphs definedin
`ORTMODULE_MEMORY_OPT_CONFIG` will be recomputed. So this is compatible
to existing recompute usage in ORTModule integrated models.
Using `ORTMODULE_MEMORY_OPT_LEVEL=1`, we will enable all recompute plans
detected, so those configs in `ORTMODULE_MEMORY_OPT_CONFIG` will not be
respected any more.
Add Unit Tests using 3 layer blooms.
https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md
### Skip module clone for preparing large model export
For LLAMA2 13B, when running with Lora, DeepSpeed stage2 on 8 GPUs . It
failed during preparing outputs which will be used for
torch.onnx.export. The reason, we deep copy all the params including
both big sizes of frozen weights, + a little bit of Lora trainable
weight.
This PR will firstly check whether the GPU memmory is enough for a
cloned module, if not, skip the copy.
Copying the module is to guarantee the fw path run may change the
weight, while this case should be rare. But for now, Not-Able-To-Run is
worse than Runnable-with-A-little-bit-different-initial-weight,
especially for large models.
ORT's default topo-order is a reversed DFS algorithm, while the
priority-based topo-order is a forward BFS algorithm. It's likely that
the default order is better than priority-based order on memory because
tensor memory is more likely to be released right after it's consumed.
Currently ORTModule uses priority-based order, for some models, it sorts
lots of small Ops to the beginning, this introduces big CPU overhead at
the beginning (see below screenshot), this PR is to use default order
for training. The priority-based order is heavily used for some
recompute optimization, so if there is recompute enabled, we will still
use priority-based order.
This PR also adds an optimization to the default order, which is to move
all Shape/Size Ops to right after their parent nodes. This is to make
sure the shape and size nodes are executed right after their parents so
it's possible the input tensor memory can be released as soon as
possible. This is especially important for non-CPU devices or for
training case where some gradient graphs use only shape/size of tensors
from forward.
Profiling result:
Before
<img width="910" alt="截屏2023-11-13 12 09 02"
src="https://github.com/microsoft/onnxruntime/assets/11661208/e54d5ead-274f-4725-923e-521bbcfce752">
After
<img width="910" alt="截屏2023-11-13 12 10 44"
src="https://github.com/microsoft/onnxruntime/assets/11661208/f50d196d-11ac-43a2-9493-517e4552ffab">
This PR:
- Remove unused arguments from generated triton code,
- Remove unnecessary mask for symbolic shape case from generated triton
code.
- Add doc for usage of ORTMODULE_TRITON_CONFIG_FILE.
### Description
`generate_artifacts` generates 4 graphs for training. All graphs should
share the same opset version, the one coming from the model to train,
but the optimizer is left undefined. onnxruntime is using the latest
version defined by onnx but onnxruntime does not necessarily support it.
### Motivation and Context
The code does not let the user change it.