**Description**: Subgraph-level recompute
This PR adds an optional capability trading additional re-computation
for better memory efficiency. Specifically, a pre-defined operator list
used to iterate the Graph to find some subgraphs for recompute, to
reduce some stashed activations whose lifetime across forward and
backward pass.
When training with ORTModule, by default, the graph transformer will
scan the execution graph to find all eligible subgraph to recompute,
along with sizes that can save. An example looks like below.
If we want to enable some of them to recompute, we can define env
variable this way:
`export
ORTMODULE_ENABLE_MEMORY_ALLEVIATION="Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1:-1,BiasGelu+:1:-1,BitmaskDropout+Cast+:1:-1,FusedMatMul+:1:-1,Cast+:1:-1,Mul+Add+:1:-1,Mul+Sub+:1:-1"`
```
[1,0]<stderr>:2,022-10-12 14:47:39.302,954,530 [W:onnxruntime:, memory_alleviation.cc:595 PrintSummary]
[1,0]<stderr>:MemoryAlleviation Summary:
[1,0]<stderr>: User config:
[1,0]<stderr>: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1,BiasGelu+:1,BitmaskDropout+Cast+:1,FusedMatMul+:1,Cast+:1,Mul+Add+:1,Mul+Sub+:1
[1,0]<stderr>: =================================
[1,0]<stderr>: Subgraph: BitmaskDropout+
[1,0]<stderr>: AlleviationType: Disabled
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 1,024 x Frequency:1
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: BiasGelu+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Reshape[1,0]<stderr>:+
[1,0]<stderr>: AlleviationType: Disabled
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:labels_dim0 x Frequency:1
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Unsqueeze+Unsqueeze+Cast+Sub+Mul+Mul+FusedMatMul+Cast+Add+BiasSoftmaxDropout+Cast+
[1,0]<stderr>: AlleviationType: Disabled
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:1
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Mul+Add+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: FusedMatMul+Cast+Add+Reshape+Cast+
[1,0]<stderr>: AlleviationType: Disabled
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 2 x 4 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Mul+Sub+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Cast+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:1,024 x 1,024 x Frequency:97
[1,0]<stderr>: PatternShape:3 x 1,024 x Frequency:1
[1,0]<stderr>: PatternShape:8 x 64 x Frequency:24
[1,0]<stderr>: PatternShape:1,024 x 4,096 x Frequency:24
[1,0]<stderr>: PatternShape:4,096 x Frequency:24
[1,0]<stderr>: PatternShape:4,096 x 1,024 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: FusedMatMul+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: =================================
```
"Type config:" whether recompute is enabled by users. 0 - disable, 1-
enable.
"Subgraph" means what kind of subgraph will be recomputed, in this case,
it is a single node "Gelu", and it will be "Recompute".
"Shape && Frequency" means, for this recompute, one tensor of size
(batch size, 500) will be saved because it will be recomputed.
**Baseline**
On a 1P model (DEBERTA V2), sequence length 256, training with 16 A100
GPUs. With latest main branch, we can run batch size 16, and the maximum
batch size < 32. So 16 is usually chosen by data scientists. 65% of 40GB
memory is used during training. The SamplesPerSec=479.2543353561354.

**With this PR**
Gelu is recomputed for saving memory peak, batch size 32 can be run. The
97% of 40GB A100 is used, the SamplesPerSec=562.041593991271 (**1.17X**
of baseline).

**Motivation and Context**
- Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here.
This PR enables ORT to execute graphs captured by TorchDynamo. Major compilation code is in `OrtBackend.compile` in ort_backend.py. `register_backend.py` is for plugging `OrtBackend` into TorchDynamo as a compiler.
Some models have QuickGelu(x)=x*sigmoid(1.702x), which has 3 Ops for
forward and 5 Ops for backward. The PR is to fuse this to a single Op
named QuickGelu and its gradient QuickGeluGrad.
For CUDA, tested in V100 using input tensor with shape [64,128,2048] and
float16 type:
Before, FW takes 335us, BW takes 614us

After, FW takes 115us, BW takes 139us, which is much faster.

For CPU kernel, using same shape and float type:
Before, FW takes 10us, BW takes 49us
Mul: 3480[µs]
Sigmoid: 1996[µs]
Mul: 4789[µs]
Mul: 4642[µs]
Mul: 4195[µs]
SigmoidGrad: 18328[µs]
Mul: 2988[µs]
Sum: 18576[µs]
After, FW takes 4us, BW takes 5us, which is also much faster.
QuickGelu: 3939[µs]
QuickGeluGrad: 5089[µs]
Co-authored-by: Vincent Wang <weicwang@microsoft.com>
Add env variable to control disabling custom autogard function support.
When using ORTModule, if the torch model has torch.nn.Function, if user
confirms that it can be exported to ONNX (for example, by inline
PythonOp) and the backward implementation is matched to the forward
impl, user can export "ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT=1" to
disable the custom autograd support so that it won't use ORT's PythonOp
to fallback to PyTorch. Exporting to ONNX sometimes can leverage some
graph optimizations in ORT so that perf is better.
### Description
This is a fix for on device training wheel build.
### Motivation and Context
when building linux wheel it treats PathString same as std::string, but
when trying to build the wheel on windows it fails because we needed to
cast the std::string to a PathString.
This error was found manually because there is no pipeline that uses the
--enable_training_on_device for windows.
Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
### Description
<!-- Describe your changes. -->
Use SAS Token to fix error` failed to perform copy command due to error:
no SAS token or OAuth token is present and the resource is not public`
Generate SAS Token of target data, add it into Key vault, and use it as
Pipeline Variable.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
The PR applies some fixes to Hierarchical ORTModule and ORTModule
PythonOp.
For Hierarchical ORTModule:
- Don't wrap module if the caller is to call other function instead of
forward() function
- Support single module instance is call multiple times with different
types of inputs
- Check if module can be warped from top to bottom instead of from
bottom to top
For ORTModule PythonOp:
- Add env variable control to allow using
torch.utils.checkpoint.CheckpointFunction
- Add env variable control to skip register some autograd functions so
that there is no conflict for some models.
### Description
updating the ptca image used in the nightly pipeline
Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
**Description**: utils for federated learning.
**Motivation and Context**
- This PR includes utils that will be used on federated learning
scenarios.
- Exposing python bindings to some utils, and added a util to calculate
the difference between two buffers.
Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
`python setup.py develop` doesn't install PyTorch as a normal package in
site-packages anymore, and the user must stay at PyTorch's root
directory to call `import torch`. This will break LORT tests because
LORT tests contains `import torch` and are called outside PyTorch root
directory. To make PyTorch a normal package again, this PR build PyTorch
with `python setup.py install`.
This PR is to add support of using env variable to set provider option
cudnn_conv_algo_search so that user can choose better conv algo search
method to run model. This is a quick fix to unblock the test of MoE
model. Will have another PR to design and implement the ORTModule config
so that we can config ORTModule using Python script or config file
instead of env variable.
Model [huggingface's diffusers
library](https://github.com/huggingface/diffusers) has
torch.nn.GroupNorm which will be exported to sub-graph containing ONNX's
InstanceNormalization, which is lack of gradient. The implementation of
ORT's InstanceNormalization will call cuDNN's BatchNorm for part of
computation, which is not efficient compared to PyTorch's
implementation. This PR is to use ATen fallback to support this torch
module, including its forward and backward.
### Description
<!-- Describe your changes. -->
Unit test with ROCm5.3 slower than ROCm5.2.3. Revert to ROCm5.2.3.
We will update to ROCm5.3 when the issue resloved by AMD.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Current graph builder for ORTModule will apply the training's graph
optimizations for both training and eval mode. Take BatchNorm as
example, one of training's graph optimizations will replace
BatchNormalization Op to BatchNormInternal which is for training only.
This PR is to fix this, for eval mode, we will not apply the training's
graph optimizations. The inference's graph optimizations will be applied
when InferenceSession initialization.
### Description
Implemented gradient of cos as per the function below.

### Motivation and Context
Cos gradient required for [huggingface's diffusers
library](https://github.com/huggingface/diffusers)
### Testing
built ORT from source: `./build.sh --config RelWithDebInfo
--enable_training --use_cuda --cuda_home /usr/local/cuda --cudnn_home
/usr/local/cuda --build_wheel --parallel --skip_tests`
tested CosGrad implementation: `cd build/Linux/RelWithDebInfo/ &&
./onnxruntime_test_all --gtest_filter=GradientCheckerTest.CosGrad`
Co-authored-by: Prathik Rao <prathikrao@microsoft.com>
### Description
Implemented gradient of sin as a function op.
### Motivation and Context
Sin gradient currently implemented as cpu op which could hurt
performance.
### Testing
built ORT from source: `./build.sh --config RelWithDebInfo
--enable_training --use_cuda --cuda_home /usr/local/cuda --cudnn_home
/usr/local/cuda --build_wheel --parallel --skip_tests`
tested SinGrad implementation: `cd build/Linux/RelWithDebInfo/ &&
./onnxruntime_test_all --gtest_filter=GradientCheckerTest.SinGrad`
Co-authored-by: Prathik Rao <prathikrao@microsoft.com>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
This PR has two fixes:
- https://github.com/pytorch/pytorch/pull/85636 change the behavior of
register_custom_op_symbolic to only register the symbolic function at a
single version. For ORTModule we need to pass the op_set version when
calling it.
- Since torch_1.13 the signature of einsum is changed to have a new
argument, need to change our custom op symbolic registry code
accordingly.
Without the fixes, ORTModule will not work with the nightly torch, and
the new torch version will be released.
Update for ROCm CI before reland tunable GEMM #12853. This PR also update
composable kernel to use CMakes's HIP language support so that we can
mix C/C++ compiler with HIP compiler instead of locking to hip-clang
### Description
Training C# bindings (ReleaseTrainingSession and ReleaseCheckpointState)
broke after an API order change in Training C API. This PR fixes this
issue.
### Motivation and Context
Bug Fix for Training C# bindings
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
For below code in some transformers models:
```
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
```
The exported graph will contains 3 Gather nodes, currently ORT's
GatherGrad CUDA implementation is slow. This pattern can be fused to use
one Split, so that we can launch less kernels for the compute, the perf
of Split/Concat (for grad) is also better than Gather/GatherGrad.
In a real example, one GatherGrad will take 15ms and there are 3 for
each layer in the graph, after the fusion, one Concat takes only 35us.
The total time of a step is improved from 1.5s to 0.4s.
This PR is to fix https://github.com/microsoft/onnxruntime/issues/12930
and https://github.com/microsoft/onnxruntime/issues/12579.
In detail:
- For CPU EP, since current impl of SimplifiedLayerNormalization doesn't
support input and scale having different data types, so if the sub-graph
contains Cast Op, the sub-graph will not fused, this guarantee that both
inputs and output data type will be same
- For CUDA EP, add (fp16, float) support to (T,V) type constraints all
combinations of fp16 and float can be supported in the impl
With the fix, the original model can be run with
SimplifiedLayerNormalization, which also helps to improve the perf.
* `QuantizeBFP` and `DequantizeBFP` schemas - similar to
`QuantizeLinear` and `DeQuantizeLinear`.
* BFP datatype is represented as a `uint8` tensor with shape and stride
metadata. This is preferrable to adding a new datatype for BFP, which is
more disruptive and [discouraged by
PyTorch](https://discuss.pytorch.org/t/training-with-custom-quantized-datatype/152132/2).
Context:
The Microsoft Floating Point (BFP) datatype shares an exponent for every
n numbers called a “bounding box.” Each number still has its own
mantissa and sign bits. BFP has been shown to incur 3-4 less cost
(energy and area) than BFloat16 and INT8 counterparts without reductions
in accuracy for the ImageNet benchmark as described in [Rouhani
2020](https://proceedings.neurips.cc/paper/2020/file/747e32ab0fea7fbd2ad9ec03daa3f840-Paper.pdf).
Requirements:
* There are many variants of BFP (number of mantissa bits, number of
shared exponent bits, size of bounding box, custom bit fields, etc.)
* The size and layout of an BFP variant varies across hardware
* bounding box can be over arbitrary dimensions; for example, for the
channel "C" dimension in a N x C x H x W tensor for convolution
Goals of this PR:
* Add initial versions of QuantizeBFP and DequantizeBFP operators to
enable QDQ-style quantization with BFP. Once the schemas stabilize, we
can consider upstreaming to ONNX.
* Add some basic type and shape inferencing tests; tests that run on an
EP will be a follow-up.
# Motivation
Currently, ORT minimal builds use kernel def hashes to map from nodes to
kernels to execute when loading the model. As the kernel def hashes must
be known ahead of time, this works for statically registered kernels.
This works well for the CPU EP.
For this approach to work, the kernel def hashes must also be known at
ORT format model conversion time, which means the EP with statically
registered kernels must also be enabled then. This is not an issue for
the always-available CPU EP. However, we do not want to require that any
EP which statically registers kernels is always available too.
Consequently, we explore another approach to match nodes to kernels that
does not rely on kernel def hashes. An added benefit of this is the
possibility of moving away from kernel def hashes completely, which
would eliminate the maintenance burden of keeping the hashes stable.
# Approach
In a full build, ORT uses some information from the ONNX op schema to
match a node to a kernel. We want to avoid including the ONNX op schema
in a minimal build to reduce binary size. Essentially, we take the
necessary information from the ONNX op schema and make it available in a
minimal build.
We decouple the ONNX op schema from the kernel matching logic. The
kernel matching logic instead relies on per-op information which can
either be obtained from the ONNX op schema or another source.
This per-op information must be available in a minimal build when there
are no ONNX op schemas. We put it in the ORT format model.
Existing uses of kernel def hashes to look up kernels are replaced
with the updated kernel matching logic. We no longer store
kernel def hashes in the ORT format model’s session state and runtime
optimization representations. We no longer keep the logic to
generate and ensure stability of kernel def hashes.
**Description**: **Python API Bindings for on device training. **
**Motivation and Context**
- This PR contains api bindings so python users can perform a whole
training loop.
Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
Currently, CUDA hardware is not available to be leveraged by build
during `docker build`. because of that, CUDA capable hardware would not
have CUDA support
This PR adds an env varf ONNXRUNTIME_FORCE_CUDA in which it allows CUDA
extensions to be compiled even when CUDA support is not detected.
* drop nuphar code and configs
* refactor test case
* format python
* remove nuphar from training test
* remove commented nuphar logics
* restore llvm setting
* drop nuphar ci
* fix compile err
* fix compile err
Co-authored-by: Randy Shuai <rashuai@microsoft.com>