Commit graph

484 commits

Author SHA1 Message Date
mindest
009209e016
Fix Orttraining Linux Lazy Tensor CI Pipeline (#21652)
### Description
Fix `Orttraining Linux Lazy Tensor CI Pipeline`
- Remove unused import of `torch.onnx._internal.exporter`, whose path is
changed in newer torch (pytorch/pytorch#132429).
- Move import of `register_custom_op_symbolic` from `torch.onnx` into
local function, which causes circular import when running `import
torch.onnx` (at least in the CI environment).
2024-08-21 18:10:08 +08:00
Justin Chu
c203d89958
Update ruff and clang-format versions (#21479)
ruff -> 0.5.4
clang-format -> 18
2024-07-24 11:50:11 -07:00
Prathik Rao
11ad299451
Adds ATen fallback for scaled_dot_product_attention (#21107)
### Description
<!-- Describe your changes. -->

Introduces an ATen fallback for
`torch.nn.functional.scaled_dot_product_attention`. This operator was
introduced in torch 2.0 and, since then, has had many updates including
the implementation of memory efficient attention for V100 machines. The
current torchscript exporter exports a subgraph for attention which does
not provide the same memory savings that PyTorch's memory efficient
attention kernel provides. Allowing fallback to PyTorch ATen op for
attention helps mitigate memory spike issues for models leveraging
memory efficient attention.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Memory issues arose when integrating ONNX Runtime Training with AML
Stable Diffusion.

---------

Co-authored-by: root <prathikrao@microsoft.com>
2024-07-22 16:37:04 -07:00
mindest
5b9369e93c
Fix typos according to reviewdog report. (#21335)
### Description
Fix typos based on reviewdog report but with some
exceptions/corrections.
2024-07-22 13:37:32 -07:00
pengwa
88336ffa92
Fix typos - 1st Wave (#21278)
### Description

There are so many typos reported by the review dog, [Optional Lint]
actions (example:
https://github.com/microsoft/onnxruntime/actions/runs/9864564489/job/27239732367),
this PR is to fix some of them.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
2024-07-11 13:35:08 +08:00
pengwa
4932e04053
ORTModule GraphTransitionManager (#19007)
### Problem

Currently, the codebase contains some logics pertaining to model
re-export checks and graph_builder reinitialization checks. Ideally,
these operations should function akin to a state machine. However, upon
inspecting the implementation, it becomes apparent that certain states
are checked or set in various scattered locations. This fragmentation
makes it challenging to comprehend when a re-export or re-initialization
will be triggered. For optimal clarity and maintainability, it is
advisable to consolidate these states into a cohesive component, rather
than dispersing them within the current graph execution manager.

Furthermore, the process of model exports and post-export processing for
stage 3 support or memory-efficient gradient management introduces
considerable complexity. To enhance the codebase's structure, it would
be beneficial to extract these intricate functionalities into a
dedicated component, divorcing them from the current graph execution
manager.

As part of the effort to improve the codebase, it's essential to address
inconsistencies in handling input/output flatten/unflatten operations.
Currently, there are several functions performing these operations
recursively, each with slightly different implementations. This
inconsistency leads to varying support for input/output data types and
structures in different parts of the code. To rectify this, the proposed
pull request simplifies these operations into a set of primitive
functions, ensuring uniformity. This not only streamlines the code but
also facilitates the maintenance of consistency when introducing bug
fixes or supporting new data types. One thing to mention here: input
output handling is deeply bound to the graph transition mentioned above,
so it is difficult to make this change separately.

While acknowledging the complexity of these logics, it is reassuring
that the codebase benefits from an extensive suite of unit tests that
cover all possible branches. Despite the intricacies, ensuring the
passage of all tests has been a time-intensive but necessary aspect of
this development effort.

### Design


Introduce `GraphTransitionManager` and put all model export and
post-export processing logics in it.
1. Re-export check
2. Do export
3. Re-post-export process check
4. Do post-export process
5. Return `PostExportProcessedModelInfo`, which contains all the
information we need, to pass to ORT to build gradient graph (currently
we do the same for training or evaluating, but ideally we should not do
it for evaluating, let's keep this behavior as it is now, and make the
change later).
    ```
          # Input names for the pre-gradient-build graph.
# This may be different with the one in ExportedGraph since we may
modify the graph inputs as needed
# for example when memory efficient gradient management is enabled.
self.onnx_graph_input_names: list[str] = onnx_graph_input_names
  
          # A subset of onnx_graph_input_names.
# Input names that require gradients for the pre-gradient-build graph.
self.onnx_graph_input_names_require_grad: list[str] =
onnx_graph_input_names_require_grad
  
# Create symbolic names for each dimension of the graph input (e.g.
onnx_graph_input_names).
# The key is the input name, the value is a dict of {dim_index:
symbolic_dim_name}
# e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0:
"input2_dim0"}}
self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] =
onnx_graph_input_dynamic_axes_map
  
self.buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict()
          self.onnx_graph_input_names_user_defined = (
onnx_graph_input_names_user_defined # The ONNX graph input names
excluding the parameters, buffers.
          )
  
# The ONNX graph input names excluding the parameters, buffers.
self.onnx_graph_input_names_require_grad_user_defined =
onnx_graph_input_names_require_grad_user_defined
  
self._post_export_processed_model: onnx.ModelProto | None =
post_export_processed_model
  
# A function to access the input data from the args and kwargs.
# If it is not None, the length is same as onnx_graph_input_names.
# For i-th input name, we can use the i-th function to get the input
data from args and kwargs.
          self.data_accessor: list[callable] | None = data_accessor
  
          # Used for unflattening the outputs from the ORT forward run.
self.module_forward_output_schema: ORTModelInputOutputSchemaType | None
= module_forward_output_schema```




The `GraphTransitionManager` instance is a property of
`GraphExecutionManager` (e.g. `TrainingManager` or ``InferenceManager),
1. Use
'self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs,
kwargs)' to check whether the PyTorch module need a re-export or
re-post-export-process.
2. Use
`self._graph_transition_manager._post_export_processed_model_info.construct_inputs`
to construct the list of inputs used for ORT runs.
3. Use
`self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)`
to restore the outputs in original PyTorch output structure.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-07-03 10:53:31 +08:00
Caroline Zhu
6236707c64
Enable >2GB models + allow model paths to be passed for generate_artifacts API (#20958)
### Description
Alternative design from #20942 

Allow users to pass in a model path for the generate_artifacts API. 

### Motivation and Context
- ONNX API calls such as the onnx checker + shape inference fail when
given a model > 2GB, but work if a path to a model >2GB is passed in.
2024-06-21 09:55:26 -07:00
pengwa
8a98874e7e
Flash attention recompute (#20603)
### Flash attn recompute

1. Allow PythonOp(FlashAttn) can be recomputed correctly.
45879ff5c2
2. Use JSON to pass the selected-to-recompute subgraphs.
3c374da678

#### Better Memory Efficiency 

Customer model can run both PyTorch SPDA and Flash Attn, this PR make it
possible to let the Flash Attn path work with ORTModule layerwise
recompute. The peak drop from 45.xGB to 32.xGB if we only compare the
layers (not including other pieces, BTW there are few more optimization
targeting other pieces as well later).

#### Better Perf

Using Flash ATTN bring additionally 16% end to end time reduction, with
highly aligned loss curve.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/bb63894a-f281-49bc-a8e6-ff818439be38)

#### Use JSON File to pass Recompute Plans

To overcome the limitation of max length of the strings defined in
session options.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-05-21 13:38:19 +08:00
guyang3532
cfe830b248
Generalize label input sparsity check and refactor (#20636)
### Description
The InsertGatherBeforeSceLoss optimization is enabled when the density
of label padding less than 90%. We need to check the density of the
label padding to decide whether enable the optimization.

Before this pr, we just check the inputs of graph and correlate one with
the SCE node by iterate graph from the SCE node back to one graph input.
This is hard to be general because there may be complicated pattern
between graph input and SCE node.

This pr check padding density by the direct input of SCE module rather
than the input of graph at the first graph execution when exporting onnx
graph.
And if the density < 90%, insert a flag PythonOp after the SCE node as:
```
           SoftmaxCrossEntropy
		  |
            PythonOp (func_name: FlagAndPrintDensity)   (insert if density < 90%)
		  |
            Following graph
```

When the InsertGatherBeforeSceLoss is invoked, it check if there is the
flag PythonOp(func_name: FlagAndPrintDensity) after the SCE node and if
it is, remove it and do the padding elimination optimization.

If the env of ORTMODULE_PRINT_INPUT_DENSITY is 1, we will print input
density each step by the PythonOp (func_name: FlagAndPrintDensity). In
this case the PythonOp will not be removed.
2024-05-10 21:55:43 +08:00
pengwa
56f7035521
Improve perf for mem efficient grad mgmt (#20480)
### Improve perf for mem efficient grad mgmt

When memory efficient gradient mangement feature is enabled, the weight
retrieval PythonOp for every layers will be launched at the beginning of
the forward, which would make GPU stream idle for few milliseconds. The
reason is the ReversedDFS ordering cannot ALWAYS handle such input
branching well, so we introduce a distantance-to-input_leaf concepts
when doing the reversedDFS, which not only move the problematical
PythonOp to the place where it is needed, but also those Cast ops
following the weight retrieval to the place where it is needed.

Main branch: 102.19 - 26.35s = 75.84s for 260 steps(4627samples),
61.04sample/second
This PR: 100.28s - 25.10s = 75.18s for 260 steps. 61.54samples/second
(+0.8% gains)

Main branch:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/75c4131e-dade-49b0-aa8b-ee1c637ad9a8)


This PR:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/e590a536-3b80-4f51-b89f-f25a55ddd7e2)


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-05-10 08:09:17 +08:00
Frank Dong
227c4419fc
add bf16 support for few ops (#20385)
### Description
Add bf16 support for below ops:
ConstantOfShape
Exp
Erf
convolution
PythonOp



### Motivation and Context
phimm model works on bf16, ORT need support bf16 on previous ops to work
with phimm on bf16
2024-04-25 11:28:34 -07:00
Adam Louly
4ce7bbf6f1
Add LayerSpec Support to ORTPipelineModule (#20410)
### Description
In Deepspeed's Pipeline Parallel Implementation, there is a class used
to instantiate the object after it's moved to the device and assigned in
a stage.

This approach helps reduce peak memory usage. 

In this PR, we're adding support to ORT for wrapping this LayerSpec.
2024-04-23 17:57:08 -07:00
guyang3532
ffb9c8d598
fix embedding sparsity log bug of -1% density (#20420)
### Description
When not checked valid embedding sparsity, the log print a wrong info of
"-1% density", this pr is to fix it.
2024-04-23 20:37:50 +08:00
pengwa
a7787a0bad
Introduce memory efficient topological sort (#20258)
### Introduce memory efficient topo sort (for training)

~~and laze initialize Priority-Based and Memory-Efficient topo sort.
Because in most cases, they are not needed, so we free the overheads of
GraphViewer construction for most use cases.~~

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-04-23 08:00:23 +08:00
Adam Louly
ee74fb6908
Introducing ORTPipelineModule - DeepSpeed Parallel Pipeline Support. (#20287)
### Description
Introducing a new class ORTPipelineModule to handle wrapping layers in
DeepSpeed pipeline parallel.


### Motivation and Context
To support pipeline parallelism on ORTModule.

This PR will include an initial support of deepspeed Pipeline
parallelism.

- [x] Support Pipeline parallel where layers are nn Modules in
Sequential.
- [ ] Support LayerSpec and TiedLayerSpec
- [ ] Enable partitioning to accept List
- [ ] Full-GPU Graph Consolidation
- [ ] Subgraph Merging for Inference
2024-04-18 11:30:15 -07:00
Vincent Wang
c47f446f25
Support BFloat16 for Triton Codegen (#20353)
Previous implementation used numpy array and numpy data_type to store
constant value and data type, which is not support BFloat16 natively.
This PR is to switch to use torch tensor which supports BFloat16.
2024-04-18 17:15:11 +08:00
guyang3532
471e969e2f
Check padding density by input of embedding module (#19821)
### Description
The PaddingElimination optimization is enabled when the density of
embedding padding less than 90%. We need to check the density of the
embedding padding to decide whether enable the optimization.

Before this pr, we just check the inputs of graph and correlate one with
the embedding node by iterate graph from the embedding node back to one
graph input.
This is hard to be general because there may be complicated pattern
between graph input and embedding node.

This pr check padding density by the direct input of embedding module
rather than the input of graph at the first graph execution when
exporting onnx graph.
And if the density < 90%, insert a flag PythonOp after the embedding
node as:
```
             Embedding
		  |
            PythonOp (func_name:_FlagPaddingElimination)   (insert if density < 90%)
		  |
            Following graph
```

When the PaddingElimination is invoked, it check if there is the flag
PythonOp(func_name:_FlagPaddingElimination) after the Embedding node and
if it is, remove it and do the padding elimination optimization.
2024-04-10 18:45:51 +08:00
pengwa
280b2634c5
Prompt layer-wise recompute when applicable (#20126)
### Prompt layer-wise when applicable

Give explicit prompts in export failures to users to enable layer-wise
memory optimization if we found the checkpoint function is used.
- Using checkpoint function is a strong indicator that the model is too
large to fit in GPU memory.
- If we don't override the checkpoint function here, mostly ONNX export
will be failed. 1. For old version PyTorch, when handling gradient
checkpoint feature, we just throw an exception. 2. For new version
PyTorch, an export failure happens.
- But both failures did not give users explicitly "HOW" to mitigate.
This PR did that.

``


![image](https://github.com/microsoft/onnxruntime/assets/10530022/c0476748-5818-4cc8-b2d6-88c7580fe4da)



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-04-10 11:50:28 +08:00
pengwa
dfa891a2d8
Fix memory stats printing (#20061)
### Fix memory stats printing

The mmeory stats printing is failed when module is in eval mode, doing
ORTModule wrap. At that time, runtime inspector for training manager
should have training model being true, but got a false (because existing
logic get the boolean from module.training). Runtime inspector as part
of training manager or inference manager should know it is serving
training or not explicitly, so we cannot depend on the stat of
module.training during ORTModule initialization.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-03-26 21:25:59 +08:00
pengwa
1a0ba3f69f
Fix softmax export (#20057)
### Description


Why we need to define softmax export logic here?
For the usage `nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32)` in the model,

76a33a1092/src/transformers/models/mistral/modeling_mistral.py (L302)
If dtype is specified, the input tensor is casted to dtype before the
operation is performed.
This is useful for preventing data type overflows. While existing ONNX
exporter do the cast after the operation, which is not correct.
(cf06189a2d/torch/onnx/symbolic_opset13.py (L27)).
This override can be a workaround before PyTorch fix the issues in
coming releases.
 (TODO: pengwa - add PyTorch versions when the issue is fixed).


@thiagocrepaldi We may need a fix in PyTorch repo as well. 

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-03-26 13:09:20 +08:00
Vincent Wang
d30c81d270
Add Symbolic Shape Hint to Triton Codegen Config (#20056)
Add symbolic shape hint to Triton codegen config so that we can avoid
unnecessary recompile when input shapes are keeping changing. Below
screenshot shows that with proper configuration, we can speed up the
training a lot by reducing unnecessary recompile.


![image](https://github.com/microsoft/onnxruntime/assets/11661208/699944d2-81cd-4c22-84e7-73a4fa0d2a28)
2024-03-25 15:05:02 +08:00
Baiju Meswani
226f60f2f1
Add support for SGD optimizer in minimal build (#19901) 2024-03-14 11:31:20 -07:00
Justin Chu
faea42af95
Bump ruff to 0.3.2 and black to 24 (#19878)
### Motivation and Context

Routing updates
2024-03-13 10:00:32 -07:00
pengwa
3fb8905393
Fix torch cpp extension build warnings (#19842)
### Fix torch cpp extension build warnings

For the warnings shown as below: 

```
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
[4/5] c++ -MMD -MF /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.o.d -pthread -B /opt/conda/envs/ptca/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/TH -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/THC -I/opt/conda/envs/ptca/include/python3.8 -c -c /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc -o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.o -O3 -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=torch_interop_utils -D_GLIBCXX_USE_CXX11_ABI=0
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
In file included from /opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/python_arg_parser.h:65,
                 from /opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/tensor_new.h:4,
                 from /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc:9:
/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/python_strings.h:104:19: warning: ‘pybind11::object PyObject_FastGetAttrString(PyObject*, const char*)’ defined but not used [-Wunused-function]
  104 | static py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) {
      |                   ^~~~~~~~~~~~~~~~~~~~~~~~~~
[5/5] c++ -MMD -MF /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.o.d -pthread -B /opt/conda/envs/ptca/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/TH -I/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/THC -I/opt/conda/envs/ptca/include/python3.8 -c -c /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc -o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.o -O3 -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=torch_interop_utils -D_GLIBCXX_USE_CXX11_ABI=0
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
In file included from /opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/python_arg_parser.h:65,
                 from /opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/tensor_new.h:4,
                 from /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc:13:
/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/include/torch/csrc/utils/python_strings.h:104:19: warning: ‘pybind11::object PyObject_FastGetAttrString(PyObject*, const char*)’ defined but not used [-Wunused-function]
  104 | static py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) {
      |                   ^~~~~~~~~~~~~~~~~~~~~~~~~~
g++ -pthread -B /opt/conda/envs/ptca/compiler_compat -Wl,--sysroot=/ -pthread -shared -B /opt/conda/envs/ptca/compiler_compat -L/opt/conda/envs/ptca/lib -Wl,-rpath=/opt/conda/envs/ptca/lib -Wl,--no-as-needed -Wl,--sysroot=/ /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.o /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/temp.linux-x86_64-cpython-38/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.o -L/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -o build/lib.linux-x86_64-cpython-38/torch_interop_utils.cpython-38-x86_64-linux-gnu.so
Installing /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/lib.linux-x86_64-cpython-38/fused_ops.cpython-38-x86_64-linux-gnu.so -> /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/fused_ops.cpython-38-x86_64-linux-gnu.so
Installing /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/lib.linux-x86_64-cpython-38/aten_op_executor.cpython-38-x86_64-linux-gnu.so -> /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor.cpython-38-x86_64-linux-gnu.so
Installing /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/lib.linux-x86_64-cpython-38/torch_gpu_allocator.cpython-38-x86_64-linux-gnu.so -> /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator.cpython-38-x86_64-linux-gnu.so
Installing /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/build/lib.linux-x86_64-cpython-38/torch_interop_utils.cpython-38-x86_64-linux-gnu.so -> /opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_interop_utils.cpython-38-x86_64-linux-gnu.so

```

Fix by replacing eixsting `PyObject_GetAttrString` with
`PyObject_FastGetAttrString` which claims to be faster in its
implementation comment.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-03-12 10:51:30 +08:00
pengwa
3e954da3e6
Fix and enable few ORTModule Unit Tests (#19847)
### Fix and enable few ORTModule Unit Tests

Fix 'test_bert_inputs_with_dynamic_shape' and
'test_bert_result_with_layerwise_recompute' generate Nan loss in ORT
run.

The root cause is, the logic to generatic attention mask test data is
not correct, only 0 or 1 is allowed in the dataset, but we see lots of
other numbers. ( The reason we don't have this using old version of
transformers for example v4.4.2 or 4.16.2 is because they don't contains
such
d3cb28886a,
which increase the scaling to a bigger number, causing a overflow to
inf)

Another improvement during the investigation using convergence tools:
Don't dump the activations during model export phase, otherwise, the
dumped data might contains some PyTorch run's result making us confused
during comparing with stock PyTorch run results.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-03-12 10:49:19 +08:00
Vincent Wang
1bfc26685b
ATen Op Supports Int Return Type and CPU Tensor Arguments (#19773)
This PR:
- add support for int as return type, will create a CPU scalar tensor
for it.
- add attributes to specify which arguments or returns are CPU tensors.
- adjust ATen efficient attn to match latest PyTorch native function.
- a Triton codegen bugfix by the way.
2024-03-06 10:11:46 +08:00
pengwa
d102569755
Fix seed for recomputed Dropout (#19715)
### Fix seed for recomputed Dropout

If Dropout node is recomputed in the backward, we should make sure its
execution is same as the run in the forward.
If we don't set seed attribute, then this cannot be guaranteed. 

Add ` export ORTMODULE_MEMORY_OPT_LEVEL=2` to enabled per layer
recompute with compromised recomputable subgraphs.
2024-03-06 10:06:25 +08:00
guyang3532
cd56ea4a74
enable embedding sparse optimization by default (#19714) 2024-03-05 13:15:30 +08:00
Adam Louly
d5606cd7ee
Introducing customizable input names for loss in generate_artifacts. (#19705)
# loss function extra inputs.
Currently, the loss functions in onnxblock expect exactly two inputs in
their build method.
Occasionally, models may pass additional inputs, causing the build
function to fail.
To solve this issue, we can let users pass a list of loss input names to
be used in the loss function.
2024-02-29 13:40:56 -08:00
Vincent Wang
937cdd651e
[ORTMODULE] Support Register Custom Triton Kernel (#19690)
Add support for registering custom Triton kernel function.
2024-02-29 23:03:57 +08:00
pengwa
026e3178ae
Improve memory matrix for ORTModule (#19620)
### Memory matrix for ORTModule

Collect  parameter/gradient/buffers sizes also. 
Exposed as a function, can be used externally for debugging purpose. 


```
2024-02-27 07:18:55,283 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,322 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,358 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,438 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▏                                                                                                                                                                                                                                              | 2/3200 [01:27<32:05:11, 36.12s/it]2024-02-27 07:18:55,498 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,537 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,576 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,657 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▏                                                                                                                                                                                                                                              | 3/3200 [01:27<17:30:57, 19.72s/it]2024-02-27 07:18:55,711 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,750 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,786 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,867 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
[2024-02-27 07:18:55,886] [INFO] [loss_scaler.py:190:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, but hysteresis is 2. Reducing hysteresis to 1
  0%|▎                                                                                                                                                                                                                                              | 4/3200 [01:28<10:39:52, 12.01s/it]2024-02-27 07:18:55,902 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,944 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,979 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,060 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▍                                                                                                                                                                                                                                               | 5/3200 [01:28<6:53:04,  7.76s/it]2024-02-27 07:18:56,115 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,154 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,190 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,270 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▍                                                                                                                                                                                                                                               | 6/3200 [01:28<4:36:19,  5.19s/it]2024-02-27 07:18:56,323 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,365 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,398 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,478 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▌                                                                                                                                                                                                                                               | 7/3200 [01:28<3:09:33,  3.56s/it]2024-02-27 07:18:56,533 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,572 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,608 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,727 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▌                                                                                                                                                                                                                                               | 8/3200 [01:28<2:13:48,  2.52s/it]2024-02-27 07:18:56,806 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,846 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,882 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,962 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▋                                                                                                                                                                                                                                               | 9/3200 [01:29<1:36:03,  1.81s/it]2024-02-27 07:18:57,053 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:57,094 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8

```
2024-02-28 15:57:05 +08:00
jingyanwangms
3bdb10d5ca
Move import to when needed to avoid circular dependency error (#19579)
### Description
Move import to when needed to avoid circular dependency error


### Motivation and Context
Fixes dependency error described here:
https://github.com/microsoft/DeepSpeed/issues/5140

---------

Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
2024-02-22 10:56:25 -08:00
Vincent Wang
3d88487c96
Minor Triton Fix (#19589)
Including removing a unnecessary assert, and add support of passing
string attribute from ONNX node attribute to python functoin kwargs
(mainly for passing debug info from graph to python for now).
2024-02-22 10:35:26 +08:00
zhijiang
8fadc6c913
Zhijxu/cleanup cached tensors when oom (#19306)
in pytorch, when oom happens at bp, user could decrease the batch size
and rerun it without restarting the process.

while in ORT, the intermediate tensors are kept even OOM, so decrease
batch size still fail.


this is torch run, we can see after oom failure, torch will release
tensor before next step

![image](https://github.com/microsoft/onnxruntime/assets/43435212/92b8a2e3-454b-448a-a223-17cb91d463c2)

this is from ort, we can see ort not release its tensors after OOM
failure.

![image](https://github.com/microsoft/onnxruntime/assets/43435212/bb6a3882-8e14-4f37-8079-e7f70fc2546b)

ort with the PR, we can see memory is released, **the 4GB memory is not
own by ort, and will be released by torch at the end**.

![image](https://github.com/microsoft/onnxruntime/assets/43435212/7f39d711-4e36-47d5-aecf-3805433a6d01)
2024-02-21 10:41:42 +08:00
Baiju Meswani
944d8f8513
Update the default std flag used during torch extensions compilation (#19516) 2024-02-14 12:49:34 -08:00
Prathik Rao
3b03b2e046
Upgrade default ORTModule opset from 15 to 17 (#19315)
### Description
<!-- Describe your changes. -->

This PR upgrades ORTModule's default opset from 15 to 17. Opset 17 is
the final opset supported by torchscript exporter
(https://github.com/pytorch/pytorch/pull/107829)

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Engineering excellence contribution for ORT Training DRI.

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
2024-02-14 11:19:33 -08:00
Justin Chu
3d2ddf96e3
Bump ruff linter to 0.2.1 (#19471)
### Motivation and Context

Include new lint rules
2024-02-08 16:08:27 -08:00
Prathik Rao
d120104dcd
add ATen support for bicubic interpolation (#19380)
### Description
<!-- Describe your changes. -->

Add ATen fallback support for bicubic interpolation algorithm.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Required for facebook/dinov2 model architecture as part of ONNX Runtime
integration with AML Vision models.
2024-02-05 13:11:37 -08:00
jingyanwangms
319481898c
Give a triton library missing warning instead of silently turn off (#19276)
### Description
When USE_ORTMODULE_TRITON is set to 1 but there's no triton library,
triton function is silently turned off. This adds a warning
2024-02-01 15:25:33 -08:00
Baiju Meswani
3262e8df2f
Introduce a Nominal Checkpoint for On-Device Training (#19232) 2024-01-30 22:11:25 -08:00
Vincent Wang
9f68a27c7a
[ORTModule] Handle Cast on Constant Number on Triton Code-gen (#19321)
When using scaled_dot_product_attention on float16 type, the exported
graph has Sqrt(float16(constant)), which cannot be ConstantFold in ORT
because Sqrt CPU kernel doesn't support float16. This causes Triton
code-gen generates code like:

result = 128.0.to(tl.float32)

This code cannot be compiled because .to() cannot be applied to
constant.

This PR is to handle such case that constant number will not do the
Cast.
2024-01-30 17:04:01 +08:00
Vincent Wang
2b87dd373a
[ORTModule] Remove Mod from Hash to Avoid Conflict for Triton Code-gen (#19256)
Remove mod (10**8) from hash to avoid conflict for Triton code-gen.
2024-01-25 10:16:41 +08:00
pengwa
1150b1f81e
ORTModule memory improvement (#18924)
## Dependency

https://github.com/microsoft/onnxruntime/pull/19007

## ORTModule memory efficient gradient management

Previously I have tried to solve the coarsed-grained gradient
accumulation/update problem in ORTModule with
https://github.com/microsoft/onnxruntime/pull/8979, while that
resolution somehow is not fully validated with DDP or there is user
hooks on the gradient accumulation on torch parameter.

This PR is addressing the problem in the similar approach as PR 8979,
e.g. trigger gradient accumulation once ORT computed the grad, but
instead of use a AccumulateGrad op, this time with a ONNX operator
PythonOp, internally it will call param.backward(grad), which will help
handle all related hooks correctly.


## Design

Check the details from


https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)

differences are on mostly 0.000x, sometimes 0.00x, which may comes from
the different order gradient apply happens before or after this change
(on deepspeed zero stage 2)


## TODO

Consolidate the logic with Stage3's similar logic.
2024-01-16 08:57:37 +08:00
Baiju Meswani
58bf836592
Offline tooling for training to use reduction with keepdims=False (#19027) 2024-01-11 10:51:23 -08:00
pengwa
d03e477b90
Fix missing subgraph candidates for recompute (#19077)
### Fix missing subgraph candidates for recompute

For subgraphs for example `MatMul+Transpose+Reshape`, since the ending
node is a Reshape, in ORT, it is reusing input buffers.

Currently, the subgraph detection logic has defect, as a result, those
subgraphs will be missing as recompute candidates.

Also append a few more node types for recompute support. 

TODO: add unit test later. This PR is needed for a customer model now.
2024-01-11 12:50:55 +08:00
Wei-Sheng Chin
658e30eb33
Remove DORT since it's in PyTorch main now (#18996)
Main code are removed and tests are modified to use DORT directly from
PyTorch.
2024-01-04 12:59:47 -08:00
pengwa
998517b209
Minor fixes (#18949)
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-12-28 20:01:06 +08:00
pengwa
5eda79bdd3
Improve perf for stage3 training (#18099)
### Improve perf for stage3 training - first wave

Port existing PythonOp/PythonOpGrad python runner to C++, also introduce
an unsafe run mode (to skip inplace, save for backward, materrialized
grad detection on the fly).

This reduce the overhead from XX~XXX us to X ~ lower end of XX us . In
LLAMA2 7B training with 8x32GV100, we have observed 6.7% gains over
PyTorch. (1.59 v.s. 1.49it/s)

Peak memory also dropped from 31GB to 28GB.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-12-15 13:32:19 +08:00
pengwa
ccf3b2054b
Allow layer-wise recompute (#18566)
### Allow layer-wise recompute 

Early, we need users/developers to specify the subgraphs to recompute,
now we introduced a more user-friendly way to enable recompute for all
detected stashed activation recomputation subgraphs. This scarifies
getting the best configs while makes it easier to support user
requirements when they switches from PyTorch per-layer gradient
checkpoint to ORTModule.

`ORTMODULE_MEMORY_OPT_LEVEL` is introduced to control the usage, by
default, it is 0, e.g. `USER_SPECIFIED`, all subgraphs definedin
`ORTMODULE_MEMORY_OPT_CONFIG` will be recomputed. So this is compatible
to existing recompute usage in ORTModule integrated models.

Using `ORTMODULE_MEMORY_OPT_LEVEL=1`, we will enable all recompute plans
detected, so those configs in `ORTMODULE_MEMORY_OPT_CONFIG` will not be
respected any more.


Add Unit Tests using 3 layer blooms. 



https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md
2023-12-12 08:44:05 +08:00
pengwa
4bfa84487c
Skip module clone for preparing large model export (#18663)
### Skip module clone for preparing large model export

For LLAMA2 13B, when running with Lora, DeepSpeed stage2 on 8 GPUs . It
failed during preparing outputs which will be used for
torch.onnx.export. The reason, we deep copy all the params including
both big sizes of frozen weights, + a little bit of Lora trainable
weight.

This PR will firstly check whether the GPU memmory is enough for a
cloned module, if not, skip the copy.

Copying the module is to guarantee the fw path run may change the
weight, while this case should be rare. But for now, Not-Able-To-Run is
worse than Runnable-with-A-little-bit-different-initial-weight,
especially for large models.
2023-12-05 12:41:17 -08:00