### Description
DML Implementation for
[com.microsoft.MatMulIntegerToFloat](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulIntegerToFloat)
```
.\onnxruntime_test_all.exe --gtest_filter="*MatMulIntegerToFloat.*"
Note: Google Test filter = *MatMulIntegerToFloat.*
[==========] Running 22 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 22 tests from MatMulIntegerToFloat
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8 (620 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8 (497 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8S8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8S8 (488 ms)
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8S8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8S8 (503 ms)
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8U8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8U8 (495 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8U8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8U8 (488 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8U8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8U8 (492 ms)
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8X8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8X8 (502 ms)
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8U8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8U8 (452 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8U8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8U8 (454 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8U8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8U8 (446 ms)
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8U8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8U8 (508 ms)
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8S8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8S8 (456 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8S8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8S8 (455 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8S8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8S8 (447 ms)
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8S8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8S8 (465 ms)
[ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8U8
[ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8U8 (111 ms)
[ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8S8
[ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8S8 (115 ms)
[ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8S8
[ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8S8 (114 ms)
[ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8U8
[ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8U8 (110 ms)
[ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16
[ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16 (112 ms)
[ RUN ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint
[ OK ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint (337 ms)
[----------] 22 tests from MatMulIntegerToFloat (8679 ms total)
[----------] Global test environment tear-down
[==========] 22 tests from 1 test suite ran. (8680 ms total)
[ PASSED ] 22 tests.
memleakdbg:
----- No memory leaks detected -----
```
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
* `CalculateMatMulIntegerToFloat` to replace CPU EP run reference
* Added more FP32 testcases to isolate all input datatype combinations
* Added fixed input to `MatMulIntegerToFloat_FP16*` test cases as for
FP16 test cases.
* onnxruntime/test/testdata/matmul_integer_to_float.py` is capable of
generating FP16 models, but we do not produce any for now
### Follow up fix for Gelu impl
There are two minor comments in
https://github.com/microsoft/onnxruntime/pull/19560.
Fix them in this pull request.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
I've added NHWC GridSample support to the CUDA EP to reduce the number
of layout transforms. Also I've enabled the full set of GridSampleTests
for all EPs. I've also added the GridSample OpSet 16 to the registered
kernels.
### Motivation and Context
This is the first PR is a series of enhancements of the CUDA EP
improving NHWC support to avoid costly layout transforms between NWHC
and NCHW nodes which are layout sensitive. Also testing was quite
rudimentary for the CUDA EP while it was great for the CPU path. I've
regenerated grid_sample_test.cc enabling tests for other platforms as
well. Those tests resurfaced #10607 again which is fixed as well.
### ONNX Gelu Op in Opset 20
Refactor code to support MSDomain Gelu and ONNX Gelu-opset20 Op
1. Move CPU-GELU implmentation from
`onnxruntime/contrib_ops/cpu/activations.h/cc` to
`onnxruntime/core/providers/cpu/tensor/gelu.h/cc`, as the implementation
for approximate attribute to be 'none'.
2. Dumplicate some logic from
`onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc` to
`onnxruntime/core/providers/cpu/tensor/gelu.h/cc`, as the implementation
for approximate attribute to be 'tanh'.
3. Register ONNX domain Gelu CPU kernel from opset 20 in
`onnxruntime/core/providers/cpu/cpu_execution_provider.cc`.
4. Move `onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h/cu` to
`onnxruntime/core/providers/cuda/tensor/gelu_impl.h` and
`onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu`
respectively, as the implementation for approximate attribute to be
'tanh'.
5. Implement the logic for approximate attribute to be 'none' in
`onnxruntime/core/providers/cuda/tensor/gelu_impl.cu`.
6. Register ONNX domain Gelu CUDA kernel from opset 20 in
`onnxruntime/core/providers/cuda/cuda_execution_provider.cc`.
7. ROCM ep related changes.
8. Enrich the tests for ONNX domain Gelu in
`onnxruntime/test/providers/cpu/activation/activation_op_test.cc`.
### Description
This PR updates exporting and running the Whisper model with beam search
by adding the following.
- Adds temperature as a graph input to the exported model
- Fixes the token ids by adding them as attributes to
`WhisperBeamSearch`
- Fixes the timestamps test cases so they pass now
- Fixes a bug with invoking `torch.onnx.export`
- Cleans up the Whisper scripts and groups the arguments in
`convert_to_onnx.py`
- Adds a `requirements.txt` file to specify package dependencies
- Adds `whisper-large-v3` to list of pretrained models
- Fixes a bug with missing cross-attention KV cache inputs in the
decoder subgraph
### Motivation and Context
- This is a follow-up to [this
PR](https://github.com/microsoft/onnxruntime/pull/19188).
- The incorrect token ids in the timestamps processor were first noticed
during [this PR
review](https://github.com/microsoft/onnxruntime/pull/17500#discussion_r1333520007).
When they were originally added in [this
PR](https://github.com/microsoft/onnxruntime/pull/15853), the offsets
were previously constant across the Whisper model sizes. When comparing
the new `whisper-large-v3` variant, the English-only variants (e.g.
`whisper-tiny.en`), and the original variants (e.g. `whisper-tiny`),
both the values and the offsets differ. Therefore, it is easier to set
the token ids as attributes to `WhisperBeamSearch` when exporting to
ensure the right values are used in the timestamps processor.
- The Hugging Face API for returning timestamps and the expected outputs
from the PyTorch model have both changed.
- The fix for `torch.onnx.export` is a follow-up to [this PR
review](https://github.com/microsoft/onnxruntime/pull/17179#issuecomment-1683001470).
- The argument grouping is a follow-up to [this PR
review](https://github.com/microsoft/onnxruntime/pull/17500#discussion_r1333521721).
- Specific package versions are needed to run the Whisper scripts and
the `requirements.txt` file ensures that these versions are installed.
- The `whisper-large-v3` variant is released and should be in the list
of official pretrained models.
- After the changes from [this
PR](https://github.com/microsoft/onnxruntime/pull/17316), the exported
model is not loading in an ORT inference session because the
cross-attention KV cache inputs are missing in the decoder subgraph.
### Description
Sqrt does not have BF16 support yet. Adding that with this PR
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Description
<!-- Describe your changes. -->
Adds bfloat16 as a supported dtype for SimplifiedLayerNormFusion which
will provide speedup for Llama-v2 on A100 using bfloat16 numerical
format.
_layernorm_optimized_training.onnx exported in bfloat16 vs. float16:_

### Repro Instructions
```python
from torch import nn
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
import torch
dtype = torch.bfloat16
# dtype = torch.float16
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(784, 10, dtype=dtype)
self.layernorm = nn.LayerNorm([784], dtype=dtype)
def forward(self, x):
x = x.view(x.shape[0], -1)
x = self.layernorm(x)
x = self.fc(x)
return x
model = Net()
model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='layernorm', log_level=LogLevel.INFO))
model.to("cuda")
images = torch.randn((8, 28, 28), dtype=dtype).to("cuda")
output = model(images)
```
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
ONNX Runtime integration with Llama-v2 family of LLMs.
---------
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
`ScatterElements` in opset 18 has been around for a while. However, the
highest opset supporting `ScatterElements` in ORT is 13. This PR
implement this op in CUDA EP by replacing `assignment` in the current
CDUA kernel with `atomic reduction` (e.g., atomic add, atomic max). A
series of fundamental atomic functions (e.g., atomic max for int8_t and
half) are implemented in `common.cuh`; the implementation is general
enough to cover old CUDA and new CUDA versions.
- The core changes are in `cuda/atomic/common.cuh` with very detailed
documentation including `bit-wise operation's visualization`. They are
also copied to `rocm/atomic/common.cuh` to support AMD GPU.
- `/cuda/tensor/gather_elements_impl.cu` contains small changes to call
the new atomic functions to support new `reduction` behavior in new
`ScatterElements`.
- New `ScatterElements` are defined in `rocm_execution_provider.cc` and
`cuda_execution_provider.cc`.
### Description
Implement Pad-18 for Cuda.
### Motivation and Context
Latest models converted by Dynamo fall back on CPU for Pad with
performance degradation.
This contributes to
https://github.com/microsoft/onnx-rewriter/issues/126
### Description
These changes add rotary embedding and packed qkv input to gqa. As of
now, the changes are only supported with Flash-Attention (SM >= 80) but
should soon be supported with Memory Efficient Attention as well.
### Motivation and Context
With the fusion of rotary embedding into this Attention op, we hope to
observe some perf gain. The packed QKV should also provide some perf
gain in the context of certain models, like Llama2, that would benefit
from running ops on the fused QKV matrix, rather than the separate Q, K,
and V.
---------
Co-authored-by: Yufeng Li <liyufeng1987@gmail.com>
### Description
<!-- Describe your changes. -->
Add `temperature` as an input to WhisperBeamSearch op and initialize
correctly in parameter setup.
### Motivation and Context
Currently, temperature is included as an attribute to the BeamSearch op,
which doesn't let the model act dynamically in a single inference
session. By including this variable as an input, the temperature value
can be altered in any inference call (important for 1P teams)
---------
Co-authored-by: Peter McAughan <petermca@microsoft.com>
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com>
### Description
<!-- Describe your changes. -->
Register DML operators for opset 19.
- Cast19
- Castlike19
- Constant19
- Equal19
- Identity19
- QuantizeLinear19
- DequantizeLinear19
- Reshape19
- Shape19
- Size
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
---------
Co-authored-by: linnealovespie <linneamay@microsoft.com>
### Description
<!-- Describe your changes. -->
1. support causal mask in MHA cpu
2. support custom rotary_dim in rotary_emb
3. add bf16 for rotary_emb
4. fix a bug in attention rotary
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Description
<!-- Describe your changes. -->
Bump up version to 1.18.0 since the release branch has been cut.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
### Description
<!-- Describe your changes. -->
Implements LabelEncoder as per `ai.onnx.ml` opset 4 for the upcoming
ONNX 1.15 release. ~~This currently depends on a new ONNX release
candidate and so is marked as draft in the meantime.~~
### Motivation and Context
Closes https://github.com/microsoft/onnxruntime/issues/17602
When the TRT engine cache (precompiled engine) is present, it doesn't
make sense to go over the processes of model verification, model
optimization, TRT EP's GetCapability(), TRT EP's model proto
reconstruction, calling TRT parser and engine compilation.
This PR makes TRT EP skip those processes and directly load the engine
to perform inference.
The feature request:
https://github.com/microsoft/onnxruntime/issues/18072
Features:
- Replace original model with TRT engine wrapped ONNX model. It can save
a lot of time as mentioned above.
- How to get TRT engine wrapped ONNX model?
1. Set `trt_dump_ep_context_model` provider option to "true" and run the
inference. You will find the "xxx_wrapper.onnx" at the engine cache
path. (The same logic of generating engine cache)
2. Use gen_trt_engine_wrapper_onnx_model.py
- Three provider options are added,
`trt_dump_ep_context_model`: Enable dump wrapped onnx model by TRT EP
`trt_ep_context_embed_mode`: Add embed_mode as attribute. 0 means engine
cache path, 1 means engine binary data.
`trt_ep_context_compute_capability_enable`: Add hardware_arch as
attribute. When running the model, TRT EP will check consistency between
model's hardware_arch and GPU's compute capability.
- When the engine cache path is given in the wrapped model, TRT EP will
first search for the engine file using the path (relative to model
path), if it can't find it, it will change to use the path as it is
(depends on user, could be relative to working dir or absolute path)
Note:
1. This PR includes the change of
https://github.com/microsoft/onnxruntime/pull/17751
Constraints:
1. The whole model should be fully supported by TRT.
4. Users need to make sure the engine is built with min/max/opt
optimization profiles that large enough to cover the range of all
inputs. TRT EP will simply fail and won't rebuild the engine if the
input shape is out of range during runtime.
### Description
<!-- Describe your changes. -->
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Description
reducemax/min have been updated in onnx(20). implement it in ort
### Motivation and Context
this is for ort1.17.0 release
---------
Signed-off-by: Liqun Fu <liqfu@microsoft.com>
### Description
dft is updated in opset20. implement it in ort
### Motivation and Context
this is for ort 1.17.0 release
Fixes#17723
---------
Signed-off-by: Liqun Fu <liqfu@microsoft.com>
### Allow layer-wise recompute
Early, we need users/developers to specify the subgraphs to recompute,
now we introduced a more user-friendly way to enable recompute for all
detected stashed activation recomputation subgraphs. This scarifies
getting the best configs while makes it easier to support user
requirements when they switches from PyTorch per-layer gradient
checkpoint to ORTModule.
`ORTMODULE_MEMORY_OPT_LEVEL` is introduced to control the usage, by
default, it is 0, e.g. `USER_SPECIFIED`, all subgraphs definedin
`ORTMODULE_MEMORY_OPT_CONFIG` will be recomputed. So this is compatible
to existing recompute usage in ORTModule integrated models.
Using `ORTMODULE_MEMORY_OPT_LEVEL=1`, we will enable all recompute plans
detected, so those configs in `ORTMODULE_MEMORY_OPT_CONFIG` will not be
respected any more.
Add Unit Tests using 3 layer blooms.
https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md
Fix a bug that can't create context binary if the model has inputs/outputs with different data type
### Description
Update EPContext op schema to unblock nodes with different data type among inputs & outputs
### Skip module clone for preparing large model export
For LLAMA2 13B, when running with Lora, DeepSpeed stage2 on 8 GPUs . It
failed during preparing outputs which will be used for
torch.onnx.export. The reason, we deep copy all the params including
both big sizes of frozen weights, + a little bit of Lora trainable
weight.
This PR will firstly check whether the GPU memmory is enough for a
cloned module, if not, skip the copy.
Copying the module is to guarantee the fw path run may change the
weight, while this case should be rare. But for now, Not-Able-To-Run is
worse than Runnable-with-A-little-bit-different-initial-weight,
especially for large models.
This PR:
- Remove unused arguments from generated triton code,
- Remove unnecessary mask for symbolic shape case from generated triton
code.
- Add doc for usage of ORTMODULE_TRITON_CONFIG_FILE.
### Description
<!-- Describe your changes. -->
Add bfloat16 support for `MatMulBnb4` contrib op. This is useful for
QLoRA fine-tuning.
- On GPUs with SM80+ (A100, etc), it uses the native cuda bfloat16
dtype, `nv_bfloat16`. On other GPUs, it uses the onnxruntime `BFloat16`
type which uses float for compute.
- I have validated the op in a llama2-7b training scenario. The losses
match pytorch training and the training throughput is better.
- Cannot add a bfloat16 case in the op unit test since casting BFloat16
to and from float multiple times during the test causes the required
tolerances to be unachievable.
The custom autograd function exporter in onnxruntime-training is updated
to support the latest version of bitsandbytes. They changed how the
`quant_state` is stored.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Enable QLoRA fine-tuning with bfloat16.
### Description
<!-- Describe your changes. -->
change RotaryEmbeddings op implementation, add support for 4D input
tensor that is with shape of [batch, num_heads, seq_len, head_size].
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Current RotaryEmbedding op only support 3d input tensor with shape
[batch, seq_len, hidden_size]
For llamav2 model, when using FusionRotaryEmbeddings to only fuse
RotaryEmbeddings op, there will be a transpose operation for query and
key, and then the input tensor of RotaryEmbeddings becomes 4D [batch,
num_heads, seq_len, head_size].
This scenario can't be supported by current RotaryEmbeddings
implementation. So it needs to support 4D input tensor.
### Description
Implement preliminary version of local (sliding window) attention.
Currently only supported by Flash Attention (sm >= 80, Linux). Currently
only supports sliding attention with a large cached kv.
### Motivation and Context
This change enables to run Mistral and other models which use sliding
window attention.
### Description
<!-- Describe your changes. -->
1. Introduce MoE CUDA op to ORT based on FT implementation.
2. Upgrade cutlass to 3.1.0 to avoid some build failures on Windows.
Remove patch file for cutlass 3.0.0.
3. Sharded MoE implementation will come with another PR
limitation: __CUDA_ARCH__ >= 700
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
### Description
<!-- Describe your changes. -->
Registers BFloat16 datatype as valid input type for CUDA Neg Kernel.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Enabling `meta-llama/Llama-2-70b` to be finetuned with ONNX Runtime
training.
---------
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
### Tune logging experience a bit
After last time we update the ORTModule log experience, we found few
issues:
1. `INFO` level output too many things, including PyTorch exporter
verbose logs (tracing graphs) on every ranks. On this level, we only
want to
- Output a little bit more information to Users than `WARNING` level,
for example the memory recomputation recommendations or other
not-fully-ready features.
- Output a little bit more information for a quick diagnostic, collected
on rank-0 only.
2. ONNX Runtime logging filter during graph build, session init
sometimes will hide the issues (for example segement fault), there is no
useful information in `WARNING`/`INFO` for users to report to us. This
is not good!
3. Some of our devs like using `pdb` to debug Python code, but if we add
`import pdb; pdb.set_trace()` in models' code might hang when they use
`INFO` or `WARNING`, where exporter happens and all output got
redirected due to log filtering. The only workaround is to switch to
VERBOSE, which output toooooooooooo many logs.
The corresponding changes proposed here are:
1. For `INFO` logging,
- We only logs rank-0.
- We restricted the ORT backend logging level to be WARNING in this
case, because ORT backend code output way too many logs that should be
under verbose, while we cannot guarantee we can get them cleaned up
immediately once they are added.
- We output the PyTorch exporter verbose log (including tracing graph),
which is useful for a quick diagnostic when an issue happens.
2. Remove all logging filtering on ORT backend, then the segment fault
issue details will not be hidden once it happens again.
3. Introduced a `DEVINFO` logging,
- Log logs on all ranks
- Log ORT backend logging level INFO
- PyTorch exporter logging filtering are all turned OFF (to unblock the
pdb debugging).
4. Currently, to use Memory Optimizer, need use DEVINFO (which will
output ORT backend INFO log). So update memory optimizer document to
reflect this. https://github.com/microsoft/onnxruntime/pull/17481 will
update the requirement back to INFO for show memory optimization infos.
You can check
https://github.com/microsoft/onnxruntime/blob/pengwa/devinfo_level/docs/ORTModule_Training_Guidelines.md#log-level-explanations
for a better view of different log levels.
This PR also extract some changes from a bigger one
https://github.com/microsoft/onnxruntime/pull/17481, to reduce its
complexity for review.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
---------
Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>
### Description
GQA now only works with Flash Attention with Attention Mask input,
allowing for batched input. Note: This PR Disables Memory Efficient
Attention, only allowing Flash Attention kernel to be used.
### Motivation and Context
Allows GQA to work with batched input.
---------
Co-authored-by: Yufeng Li <liyufeng1987@gmail.com>
This is a graph implementation of RotaryEmbedding since there's no time
to add it to DML before 1.16.2, but it eventually should move into
DirectML since we're bandwidth-bound.
### Description
<!-- Describe your changes. -->
Adds bfloat16 as a valid input parameter type for where node for ONNX
opset 16+.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Enabling `meta-llama/Llama-2-70b` to be finetuned with ONNX Runtime
training.
---------
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
### Optimize 4bit Qlora training
Extent existing `MatmulBnb4bit` to its usage in training scenarios.
The PR includes following changes:
1. Add special `torch.autograd.Function` export logic for
`bitsandbytes.autograd._functions.MatMul4Bit` that is preferred before
common PythonOp exporter.
2. Add `training_mode` optional attribute for op `MatmulBnb4bit`, which
help skip some inference specific logic in implementation.
3. Add `transB` optional attribute, which is by default be 1; setting it
to be 0 is needed by backward usage.
Changing from `PythonOp` to this `MatmulBnb4bit` brings roughly ~2.9%
throughput gains. The reason is:
`bitsandbytes.autograd._functions.MatMul4Bit` has logic
`ctx.save_for_backward`, which would need an additional copy in
PythonOp, otherwise, the tensor might be released by ORT, while backward
op still references it.
Removing the clones also reduce the peak memory consumptions because
`bitsandbytes.autograd._functions.MatMul4Bit` saved tensors that are not
needed in backward compute.
Implement Cutlass Memory Efficient Attention Kernel into Group Query
Attention Operator.
### Motivation and Context
Before this change, Group Query Attention Operator was supported only by
Flash-Attention. While this is the most efficient kernel for the
operation, it only supports sm >= 80. Cutlass Memory Efficient Attention
Kernel supports sm >= 53, allowing us to support a broader range of GPU
hardware.