Commit graph

123 commits

Author SHA1 Message Date
Ye Wang
f9af94009b
onboard MoE (#18279)
### Description
1. Introduce MoE CUDA op to ORT based on FT implementation.
2. Upgrade cutlass to 3.1.0 to avoid some build failures on Windows.
Remove the patch file for cutlass 3.0.0.
3. The sharded MoE implementation will come in a follow-up PR.

Limitation: requires `__CUDA_ARCH__ >= 700`.


### Motivation and Context
2023-11-14 16:48:51 -08:00
aciddelgado
3dece27f51
GQA Flash Attention with Attention Mask (#18283)
### Description
GQA now works only with Flash Attention, and supports an attention mask
input, allowing for batched input. Note: this PR disables Memory Efficient
Attention, so only the Flash Attention kernel can be used.



### Motivation and Context
Allows GQA to work with batched input.

---------

Co-authored-by: Yufeng Li <liyufeng1987@gmail.com>
2023-11-07 17:47:51 -08:00
pengwa
c8e1038eab
Optimize 4bit Qlora training (#18131)
### Optimize 4bit Qlora training

Extend the existing `MatmulBnb4bit` op to training scenarios.

The PR includes the following changes:
1. Add special `torch.autograd.Function` export logic for
`bitsandbytes.autograd._functions.MatMul4Bit` that is preferred before
common PythonOp exporter.
2. Add a `training_mode` optional attribute for op `MatmulBnb4bit`, which
helps skip some inference-specific logic in the implementation.
3. Add a `transB` optional attribute, which defaults to 1; setting it
to 0 is needed for the backward pass.

Changing from `PythonOp` to this `MatmulBnb4bit` brings roughly 2.9%
throughput gains. The reason:
`bitsandbytes.autograd._functions.MatMul4Bit` calls
`ctx.save_for_backward`, which would require an additional copy in
PythonOp; otherwise the tensor might be released by ORT while the
backward op still references it.

Removing the clones also reduces peak memory consumption, because
`bitsandbytes.autograd._functions.MatMul4Bit` saved tensors that are not
needed in the backward computation.
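The `transB` point above can be illustrated with a minimal numpy sketch (not the ORT kernel): the forward pass uses the stored weight transposed, while the input gradient reuses the same stored weight untransposed, which is why the op needs a `transB=0` mode for backward.

```python
import numpy as np

# Sketch (numpy, not the ORT kernel) of why backward needs transB=0:
# forward with transB=1 computes Y = A @ W.T using the stored weight W;
# the input gradient dA = dY @ W uses the *same* W without the transpose.
A = np.random.rand(2, 4).astype(np.float32)
W = np.random.rand(3, 4).astype(np.float32)  # stored once, as the quantized weight would be

Y = A @ W.T            # forward, transB = 1
dY = np.ones_like(Y)   # upstream gradient
dA = dY @ W            # backward, transB = 0
```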
2023-11-02 09:46:11 -07:00
aciddelgado
178f7caaeb
GQA Memory Efficient Kernel (#17920)
Implement Cutlass Memory Efficient Attention Kernel into Group Query
Attention Operator.

### Motivation and Context
Before this change, Group Query Attention Operator was supported only by
Flash-Attention. While this is the most efficient kernel for the
operation, it only supports sm >= 80. Cutlass Memory Efficient Attention
Kernel supports sm >= 53, allowing us to support a broader range of GPU
hardware.
2023-11-01 20:04:22 -07:00
Tianlei Wu
95f053c652
[CUDA] Update GroupNorm and Add SkipGroupNorm (#18091)
* Add a new operator SkipGroupNorm to support skip and bias inputs.
* Update GroupNorm kernel to support the number of channels used in the SD XL Refiner.
* Add epsilon in kernel
* Add parity and performance test script
* Remove many limitations including max batch size, max number of groups, `c % cPerBlock == 0`, etc.

### Motivation and Context

Update GroupNorm to support SD XL Refiner and beyond.
2023-10-31 10:27:20 -07:00
Xavier Dupré
b5f242e978
GemmFloat8 as a contrib ops (#16051)
### Description
Add support for Gemm with float8 as a contrib op.

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: Scott McKay <Scott.McKay@microsoft.com>
Co-authored-by: Xavier Dupre <xadupre@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
2023-10-27 14:33:55 +02:00
Jambay Kinley
d30d4d372a
Add MatMul FP4 and NF4 Support (#18066)
### Description
Add a contrib op MatMulBnb4 (FP4 and NF4) and related toolchain to
support quantization on weight.

This PR adds:
- schema for contrib op MatMulBnb4 which can support FP4 (4-bit floating
point) and NF4 (4-bit NormalFloat) quantization on weight.
- a naive implementation for MatMulBnb4 on CPU and GPU, i.e.,
implemented like MatMul(A, Dequantize(B)).
- a special implementation for GemV for MatMulBnb4 and related benchmark
tool.
- tool to quantize model to FP4 or NF4.
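The naive formulation MatMul(A, Dequantize(B)) mentioned above can be sketched as follows: B is stored as 4-bit code indices plus one scale per block. The code table below is a toy symmetric grid, not the real FP4/NF4 tables, and all helper names are illustrative.

```python
import numpy as np

# Toy 4-bit code table: 16 representable values (NOT the real FP4/NF4 tables).
CODE = np.linspace(-1.0, 1.0, 16).astype(np.float32)

def dequantize_b(codes, scales, block_size, shape):
    # one fp32 scale per block of `block_size` quantized values
    vals = CODE[codes] * np.repeat(scales, block_size)
    return vals[: shape[0] * shape[1]].reshape(shape).astype(np.float32)

def matmul_bnb4(a, codes, scales, block_size, b_shape):
    # the "naive" formulation: MatMul(A, Dequantize(B))
    return a @ dequantize_b(codes, scales, block_size, b_shape)
```

The specialized GemV path mentioned above would fuse the lookup and rescale into the dot product instead of materializing B.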
2023-10-25 15:34:58 -07:00
kunal-vaishnavi
2a17d5cf32
LLaMA Model Optimization (#18021)
### Description
This PR contains fusion-level and kernel-level optimizations for [Meta's
LLaMA-2](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/).

Some of the added optimizations include:

- SimplifiedLayerNorm changes
  - Fusions for multiple variants
- SkipSimplifiedLayerNorm changes
  - Kernel support for CPU
- Rotary embeddings (previously did not exist)
  - Fusions for multiple variants
  - CPU and CUDA kernels
  - Supports interleaving and non-interleaving in the same kernels
  - Optimized cache that requires half of its originally exported sizes
- Reduced from `(max_sequence_length, head_size)` to
`(max_sequence_length, head_size / 2)`
- Multi-head attention
  - Support for 2D and 3D attention masks
- Group query attention (for FP16 CUDA and INT4 CUDA)
  - Integration with flash attention v2 and past-present buffer sharing
- Removes need for `attention_mask` input as it is supported in the
kernel
- 4 bit quantization
  - `block_size` parameter is available for customizing
- Support the new changes for [Microsoft
version](https://github.com/microsoft/Llama-2-Onnx)
- Support combinations of the below variants (ex: export ORT version and
run with Optimum)

Supported variants of LLaMA-2 include:
- [ORT
version](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama)
- Produces one ONNX file that is already optimized (and quantized if
requested)
  - Integrates with Optimum
- [Another Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
  - Already exported and available off-the-shelf
  - Faster versions of those models will be uploaded there soon
- [Hugging Face version](https://huggingface.co/meta-llama)
  - Models that end with `-hf`
- Some older and current versions of
[`transformers`](https://github.com/huggingface/transformers) and
[`optimum`](https://github.com/huggingface/optimum) that export the
model to ONNX differently
- Note that while some older versions are supported, it is recommended
to use the latest package versions.
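The rotary-embedding cache item above (only `head_size / 2` angles per position) can be sketched as follows; the function names and the non-interleaved layout shown are illustrative, not ORT's actual kernels:

```python
import numpy as np

def rotary_cache(max_seq_len, head_size, base=10000.0):
    # cos/sin are only needed for head_size/2 angles, so the cache is
    # (max_seq_len, head_size // 2) rather than (max_seq_len, head_size)
    inv_freq = 1.0 / base ** (np.arange(0, head_size, 2) / head_size)
    angles = np.outer(np.arange(max_seq_len), inv_freq)
    return np.cos(angles), np.sin(angles)

def apply_rotary(x, cos, sin, pos):
    # non-interleaved layout: rotate first half against second half
    h = x.shape[-1] // 2
    x1, x2 = x[..., :h], x[..., h:]
    return np.concatenate([x1 * cos[pos] - x2 * sin[pos],
                           x1 * sin[pos] + x2 * cos[pos]], axis=-1)
```

An interleaved variant would pair adjacent elements (even/odd indices) instead of the two halves; the cache itself is identical, which is how one kernel can serve both layouts.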

### Usage

To use the optimizations, please see `README.md` for details. Please
note the various `requirements.txt` files for the package versions
recommended in order to use these changes.

To run the ORT transformer optimizer separately, run the script as
follows:
```shell
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 optimizer.py --input <filename>.onnx --output <filename>.onnx --model_type gpt2 --num_heads <number of attention heads> --hidden_size <attention hidden size> --use_external_data_format --opt_level 0
```

### Motivation and Context
This PR helps the following issues:
- https://github.com/microsoft/onnxruntime/issues/14997
- https://github.com/microsoft/onnxruntime/issues/16254
- https://github.com/microsoft/onnxruntime/issues/17681
- https://github.com/microsoft/onnxruntime/issues/17925
- https://github.com/microsoft/onnxruntime-inference-examples/issues/320

This PR uses changes from the following PRs:
- https://github.com/pytorch/pytorch/pull/104468
- https://github.com/pytorch/pytorch/pull/109759
- https://github.com/microsoft/onnxruntime/pull/17020
- https://github.com/microsoft/onnxruntime/pull/17674
- https://github.com/microsoft/onnxruntime/pull/17890
- https://github.com/microsoft/onnxruntime/pull/17920
- https://github.com/huggingface/transformers/pull/26162
- https://github.com/huggingface/optimum/pull/1257
- https://github.com/huggingface/optimum/pull/1289
- https://github.com/huggingface/optimum/pull/1462

### New TorchDynamo Exporter (experimental stage)

This PR uses changes from the following issues and PRs to begin
supporting the [new TorchDynamo
exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter):
- https://github.com/huggingface/transformers/pull/26307
- https://github.com/pytorch/pytorch/issues/104903
- https://github.com/pytorch/pytorch/pull/105040
- https://github.com/microsoft/onnxscript/pull/847
- https://github.com/microsoft/onnxscript/pull/862
- https://github.com/microsoft/onnxscript/issues/493
2023-10-23 13:00:56 -07:00
Yufeng Li
11af34440a
Add MatMul 4bits support on GPU (#17890)
### Description
Add a contrib op MatMulNBits and related toolchain to support
quantization on weight. This PR only adds support for 4bits. It:

- adds a schema for contrib op MatMulNBits, which can support 1-7 bit
quantization on weight.
- adds a naive implementation for 4-bit MatMulNBits on CPU and GPU, i.e.,
implemented like MatMul(A, Dequantize(B)).
- adds a special implementation of GemV for 4-bit MatMulNBits and a related
benchmark tool.
- adds a tool to quantize models to 4 bits.

Next:
- add general and more efficient kernels for 4bits MatMulNBits on CPU
and GPU
2023-10-13 16:55:30 -07:00
Zhang Lei
762703e037
Support output cross qk, dtw and more for whisper model (#17500)
Support cross qk in beam search for whisper model and related features
Make whisper exporting tools support cross qk and some related features,
* extra_decoding_ids
* no_speech_prob

Implement DTW kernel and unfold tensor kernel with unit tests. Several
fixes related to multiple sessions running in parallel, such as:

* guard multihead_attention, fused_fp16_runner_
* some memory allocation with stream awareness
* add use_ep_level_unified_stream option
2023-10-13 11:47:15 -07:00
aciddelgado
406cd324e0
[CUDA] GroupQueryAttention operator using FlashAttention (#17674)
### Description
Added the Group Query Attention op, supporting a number of Q heads that is
an integer multiple of the number of KV heads. As of now, this op can only
use the FlashAttention kernel, meaning it only supports sm >= 80 on Linux.

Results from onnxruntime/test/python/transformers/benchmark_gqa.py show
an average ~37% speed-up over Decoder Masked Multi-Head Attention,
with even greater improvements for long past sequence lengths.

```
op      batch   s_kv    heads   h_dim   ms      TFLOPS
gqa     16      2048    8       32      0.34    0.10
dmmha   16      2048    8       32      0.39    0.09
---------
gqa     16      2048    8       64      0.45    0.15
dmmha   16      2048    8       64      0.61    0.11
---------
gqa     16      2048    8       128     0.54    0.25
dmmha   16      2048    8       128     0.83    0.16
---------
gqa     16      2048    16      32      0.45    0.15
dmmha   16      2048    16      32      0.69    0.10
---------
gqa     16      2048    16      64      0.69    0.19
dmmha   16      2048    16      64      0.83    0.16
---------
gqa     16      2048    16      128     0.71    0.38
dmmha   16      2048    16      128     1.28    0.21
---------
gqa     16      2048    32      32      0.58    0.23
dmmha   16      2048    32      32      0.77    0.17
---------
gqa     16      2048    32      64      0.58    0.46
dmmha   16      2048    32      64      1.25    0.21
---------
gqa     16      2048    32      128     0.76    0.71
dmmha   16      2048    32      128     2.15    0.25
---------
gqa     16      2048    64      32      0.68    0.39
dmmha   16      2048    64      32      1.23    0.22
---------
gqa     16      2048    64      64      0.77    0.70
dmmha   16      2048    64      64      2.11    0.25
---------
gqa     16      2048    64      128     1.10    0.97
dmmha   16      2048    64      128     4.06    0.26
---------
gqa     16      2048    128     32      1.00    0.54
dmmha   16      2048    128     32      2.09    0.26
---------
gqa     16      2048    128     64      1.10    0.97
dmmha   16      2048    128     64      4.08    0.26
```


### Motivation and Context
As of now, this op is targeted for use in LLaMA models, as it supports
kv-caching and different numbers of heads for Q and KV (Grouped Query
Attention). We plan to add support for more platforms, input formats,
etc. in the future.

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-10-09 12:43:12 -07:00
Hector Li
385fab5bae
[QNN EP] Qnn cache improvement (#17757)
### Description
Improve the QNN context binary cache feature to reduce the memory
overhead and initialization time overhead.
Instead of dumping a QNN context binary file with metadata as a header, we
dump an ONNX-format file with the metadata inside an ONNX node.

### Motivation and Context
Reduce the memory overhead and initialization time overhead.
2023-10-06 15:56:33 -07:00
Adrian Lizarraga
dea425e7c1
[QNN/CPU EP] Add 16-bit Quantize/Dequantize contrib ops (#17015)
### Description
- Adds 16-bit integer support to:
- Quantization kernel implementations: Intel, Neon, and Power intrinsics
  - DequantizeLinear and QuantizeLinear contrib ops
  - QNN EP Quantize and Dequantize operators
  - Python quantization scripts
- Disables QDQ fusions for most 16-bit QDQ node groups (need to add
16-bit support to QLinear* ops)
- Retains support for dropping QDQ nodes from Split, Gather, Reshape,
Transpose, Squeeze, and Unsqueeze node groups.

Sample python code to generate QDQ model with 16-bit activations and
8-bit weights:
```python
    quantize_static(
        input_model_path,
        output_model_path,
        data_reader,
        quant_format=args.quant_format,
        per_channel=args.per_channel,
        activation_type=QuantType.QUInt16,
        weight_type=QuantType.QUInt8,
        extra_options={"DedicatedQDQPair": True, "ForceQuantizeNoInputCheck": True, "UseQDQContribOps": True},
    )
``` 

Note that enabling the `UseQDQContribOps` extra option is not strictly
necessary. If the 16-bit types are used without enabling
`UseQDQContribOps`, the QDQ ops' domains are overridden to
'com.microsoft', and a warning is printed to stdout.

### Automated Tests
MLAS/CPU EP:
- [x] 16-bit QuantizeLinear computation
- [x] 16-bit DequantizeLinear computation

Optimizer:
- [x] Transpose QDQ fusion
- [x] Gather QDQ fusion
- [x] Reshape QDQ fusion
- [x] Squeeze QDQ fusion
- [x] Unsqueeze QDQ fusion
- [x] Split drop QDQ
- [x] DoubleQDQPairRemover 
- [x] Transpose optimization
- [x] EnsureUniqueDQForNodeUnit
- [x] Common subexpression elimination (DQ not removed)
- [x] Constant folding

QNN EP:
- [x] Conv 16-bit activations, 8-bit weights
- [x] MatMul 16-bit activations, 8-bit weights
- [x] Unary 16-bit QDQ ops
- [x] Binary 16-bit QDQ ops

Quantization tool:
- [x] Test creation of 16-bit QDQ model
### Motivation and Context
Support mixed precision (8bit weights, 16bit activations) models.

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
2023-09-18 09:43:34 -07:00
Adrian Lizarraga
5a83a67f32
Support QDQ transformations with com.microsoft.Quantize/Dequantize ops (#17127)
### Description
- Enables int32 support for com.microsoft.DequantizeLinear (contrib op)
- Makes the `zero_point` input optional for Quantize/Dequantize contrib
ops
- Enables QDQ transformations with the Quantize/Dequantize contrib ops
- Update tests: EnsureUniqueDQForNodeUnitTests, QDQTransformerTests,
TransposeOptimizerTests

### Testing
List of tested graph transformations:
- [x] QDQSelectorActionTransformer
  - qdq_transformer_test.cc
- [x] QDQS8ToU8Transformer
  - qdq_transformer_test.cc
- [x] DoubleQDQPairsRemover
  - qdq_transformer_test.cc
- [x] IdenticalChildrenConsolidation
  - qdq_transformer_test.cc
- [x] QDQPropagation
  - qdq_transformer_test.cc
- [x] QDQFinalCleanup
  - qdq_transformer_test.cc
- [x] ClipQuantFusion
  - qdq_transformer_test.cc
- [x] ReluQuantFusion
  - qdq_transformer_test.cc
- [x] EnsureUniqueDQForNodeUnit 
  - ensure_unique_dq_for_node_unit_test.cc
- [x] TransposeOptimizer 
  - transpose_optimizer_test.cc
- [x] CommonSubexpressionElimination
  - graph_transform_test.cc
- [x] ConstantFolding
  - graph_transform_test.cc

### Motivation and Context
We need to [support mixed 16-bit/8-bit precision QDQ
models](https://github.com/microsoft/onnxruntime/pull/17015). This PR is
the first step in achieving this goal: we need to make QDQ contrib ops
work with our optimizations/transformations.

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
2023-08-25 09:57:51 -07:00
Chen Fu
3c10f027de
4b quantization for weights of LLMs (#16833)
### Description
Blockwise 4b quantization for LLMs. 
1. Introduces 4-bit block-wise quantization for linear layer weights.
2. Implements a matrix multiplication kernel for fp32 x int4.
3. Implements the special operator MatMulFpQ4.
4. Implements a quantization tool that converts a MatMul operator to
MatMulFpQ4 when the right-hand side is a 2D const tensor.
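Symmetric block-wise 4-bit quantization (the Q4Sym flavor benchmarked below) can be sketched roughly as follows; the helper names are hypothetical, and this is not the MatMulFpQ4 storage format:

```python
import numpy as np

def quantize_block_sym(w, bits=4):
    # symmetric: one scale per block, chosen so the block's absmax maps
    # to the largest positive code (7 for 4-bit)
    qmax = 2 ** (bits - 1) - 1
    absmax = np.abs(w).max()
    scale = absmax / qmax if absmax > 0 else np.float32(1.0)
    q = np.clip(np.round(w / scale), -qmax - 1, qmax).astype(np.int8)
    return q, np.float32(scale)

def dequantize_block(q, scale):
    return q.astype(np.float32) * scale
```

In the real tool, such blocks would be packed two codes per byte along the K dimension of the constant right-hand-side weight.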


### Motivation and Context
Compress and accelerate LLMs

|Benchmark | Time(ns)|
|-------------|----------|
|Q4GEMM/Q4Sym/M:1/N:4096/K:4096/Threads:8| 218054|
|Q4GEMM/Q4Sym/M:1024/N:4096/K:4096/Threads:8| 35830155|
|Q4GEMM/Q4Sym/M:2048/N:4096/K:4096/Threads:8| 73479790|
|Q4GEMM/Q4Zp8/M:1/N:4096/K:4096/Threads:8| 270152|
|Q4GEMM/Q4Zp8/M:1024/N:4096/K:4096/Threads:8| 35826721|
|Q4GEMM/Q4Zp8/M:2048/N:4096/K:4096/Threads:8| 73021200|
|Q4GEMM/Q4Sym128/M:1/N:4096/K:4096/Threads:8| 213832|
|Q4GEMM/Q4Sym128/M:1024/N:4096/K:4096/Threads:8| 36749874|
|Q4GEMM/Q4Sym128/M:2048/N:4096/K:4096/Threads:8| 72618120|


|Benchmark | Time(ns)|
|-------------|----------|
|SGEMM/LLM/M:1/N:4096/K:4096/Threads:8|   522610|
|SGEMM/LLM/M:1024/N:4096/K:4096/Threads:8| 39237689|
|SGEMM/LLM/M:2048/N:4096/K:4096/Threads:8| 75983467|

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
2023-08-07 12:23:55 -07:00
Khalia Spear
4e6ea730d6
Broadcasting for SLN for CPU and CUDA (#16510)
### Description
Enhanced SkipLayerNorm by implementing broadcasting for both CPU and
CUDA



### Motivation and Context
The input and skip tensors no longer have to be the same size: the skip
shape can match the input shape, be {1, sequence_length, hidden_size},
or be {sequence_length, hidden_size}.
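The accepted skip shapes follow ordinary broadcasting against the input's (batch_size, sequence_length, hidden_size); a quick numpy sanity check:

```python
import numpy as np

# Each accepted skip shape broadcasts to the input shape before the add
# inside SkipLayerNorm.
x = np.zeros((2, 8, 16), dtype=np.float32)  # (batch, seq_len, hidden)
for skip_shape in [(2, 8, 16), (1, 8, 16), (8, 16)]:
    skip = np.ones(skip_shape, dtype=np.float32)
    assert np.broadcast_shapes(x.shape, skip.shape) == x.shape
    y = x + skip  # the element-wise add that precedes the LayerNorm
    assert y.shape == x.shape
```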

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
2023-08-07 09:55:42 -07:00
Tianlei Wu
50bf310dea
[CUDA] RelativePositionBias supports input with padding removed (#16923)
update RelativePositionBias to support input with padding removed.
- [x] add bias transpose kernel
- [x] add test
- [x] update operator document
2023-08-01 16:39:09 -07:00
Tianlei Wu
1fbd1ed179
[CUDA] PackedMultiHeadAttention support Bias and separated Q, K and V inputs (#16913)
### Description
Follow-up change for PackedMultiHeadAttention added in
https://github.com/microsoft/onnxruntime/pull/16779:
- [x] Add Bias input
- [x] Add CUDA kernels to support separated query, key and values
inputs.
- [x] Update operator documents
- [x] Add unit tests
2023-08-01 15:30:41 -07:00
Tianlei Wu
742edec5e8
[CUDA] Add PackedMultiHeadAttention operator (#16779)
### Description
Add a new operator for MultiHeadAttention whose inputs have padding
removed. This only supports the packed QKV format.
2023-07-28 16:35:38 -07:00
Alexey Kamenev
7c05f7bab1
Fix IRFFT contrib op output dimension calculation (#15662)
### Description
Fixes the issue with IRFFT output dimension calculation as described in
#13236

### Motivation and Context
Please refer to #13236 for detailed description.

Specifically, [this code](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/math/fft_ops.cc#L103) computes the output dimension as:
```
out_dim = in_dim * 2 - 1
```
while it should be this instead:
```
out_dim = 2 * (in_dim - 1)
```
(assuming the original signal has an even number of samples, of course).

For example, if the original signal has 4 samples, then the round trip should look something like:
```
4 -> (one-sided RFFT) -> 3 (complex) -> (one-sided IRFFT) -> 4
```
with the current code the output will be a signal with 5 points.
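The round trip can be checked with numpy's one-sided FFTs (a sanity check independent of the ORT kernel): n real samples produce n/2 + 1 complex bins, and the inverse recovers 2 * (bins - 1) samples.

```python
import numpy as np

signal = np.array([0.0, 1.0, 2.0, 3.0])
spectrum = np.fft.rfft(signal)     # 4 real samples -> 3 complex bins
assert spectrum.shape[0] == 3
restored = np.fft.irfft(spectrum)  # default n = 2 * (3 - 1) = 4
assert restored.shape[0] == 4      # not 5, as the buggy formula would give
assert np.allclose(restored, signal)
```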

---------

Co-authored-by: Alexey Kamenev <akamenev@nvidia.com>
Co-authored-by: Nick Geneva <nicholasgeneva@gmail.com>
2023-07-28 15:52:37 -07:00
Patrice Vignola
649930142f
[DML EP] Add NCHW and float16 gamma/beta support for GroupNorm (#16814)
This will remove transposes that are not needed in the DML kernel. To
keep backward compatibility, the default behavior is NHWC when no
attribute is set.
2023-07-25 21:43:29 -07:00
Ye Wang
dd7d721f3c
support rotary embeddings in decoder masked self-attention (#16556)
### Description

This PR adds support for rotary embeddings in decoder masked
self-attention

### Motivation and Context

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
2023-07-12 13:48:48 -07:00
Zhang Lei
0f8e66d905
optimization for whisper model with decoder masked multihead attention (#15827)
* graph tools update
* cuda kernel update
* operator spec update and implementation update
* greedy search bug fix on a wrong assumption about cross/self attention
input length
* avoid use of empty ("") names in value info when loading graphs, which
historically appear in many models
2023-05-18 15:38:31 -07:00
stevenlix
270c09a37f
Add timestamp logits processor for whisper (#15853)
Enable timestamp estimation and logits processing for Whisper model.
2023-05-16 21:40:00 -07:00
kunal-vaishnavi
5b663d6797
Whisper Multitask and Multilingual (#15936)
### Description
This PR enables Whisper's multitask format and allows a user to use
Whisper for multiple tasks (e.g. transcription, translation) and for
multilingual purposes (e.g. English, Spanish). This PR also removes
`attention_mask` as a required input for Whisper with beam search.

### Usage
Here is an example of how you can use Whisper for English transcription.
```python
import numpy as np
import onnxruntime as ort

from datasets import load_dataset
from transformers import AutoConfig, AutoProcessor

model = "openai/whisper-tiny"
config = AutoConfig.from_pretrained(model)
processor = AutoProcessor.from_pretrained(model)

forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
# forced_decoder_ids is of the format [(1, 50259), (2, 50359), (3, 50363)] and needs to be 
# of the format [50258, 50259, 50359, 50363] where 50258 is the start token id
forced_decoder_ids = [config.decoder_start_token_id] + list(map(lambda token: token[1], forced_decoder_ids))

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features

inputs = {
  "input_features": np.float32(input_features),
  "max_length": np.array([26], dtype=np.int32),
  "min_length": np.array([1], dtype=np.int32),
  "num_beams": np.array([2], dtype=np.int32),
  "num_return_sequences": np.array([1], dtype=np.int32),
  "length_penalty": np.array([1.0], dtype=np.float32),
  "repetition_penalty": np.array([1.0], dtype=np.float32),
  "decoder_input_ids": np.array([forced_decoder_ids], dtype=np.int32),
}
sess = ort.InferenceSession("whisper-tiny_beamsearch.onnx", providers=["CPUExecutionProvider"])
outputs = sess.run(None, inputs)

# Print tokens and decoded output
print(outputs[0][0][0])
print(processor.decode(outputs[0][0][0]))
```

If you don't want to provide specific decoder input ids or you want
Whisper to predict the output language and task, you can set
`forced_decoder_ids = [config.decoder_start_token_id]` instead.

### Motivation and Context

As seen in the figure below from the [OpenAI Whisper
paper](https://cdn.openai.com/papers/whisper.pdf), Whisper can be used
for multiple tasks and languages.

![Screenshot 2023-05-12
165215](https://github.com/microsoft/onnxruntime/assets/115581922/49335e39-a79c-4f78-92e9-89b034405f65)
2023-05-15 14:36:33 -07:00
Ye Wang
3418ca28a8
pack qkv in t5 decoder (#15801)
### Description

V100, b_4_s_128, max_output_len=64, beam=4

before:
t5_small: 101.28ms
t5_base:  200.07ms

after:
t5_small: 87.65ms
t5_base: 174.44ms



### Motivation and Context

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
2023-05-15 13:45:39 -07:00
kunal-vaishnavi
39d6d7050d
Change EmbedLayerNormalization mask index output to optional (#15526)
### Description
This PR changes an EmbedLayerNormalization node's mask index output to
be an optional output if a mask input is not provided.



### Motivation and Context
The documentation for EmbedLayerNormalization states 
```
The last input mask is optional. If mask is provided, mask index (that is position of first 0 in mask, or number of words) will be calculated.
```
However, if the mask input is not provided, the mask index output is
still calculated and required.
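The mask index as documented above (position of the first 0 in a right-padded mask, i.e. the number of words) can be illustrated with a small numpy helper; `mask_index` is a hypothetical name, not the ORT implementation:

```python
import numpy as np

def mask_index(mask):
    # position of the first 0 in each row of a right-padded 0/1 mask,
    # i.e. the number of valid tokens; rows with no 0 get the full length
    has_zero = (mask == 0).any(axis=1)
    first_zero = (mask == 0).argmax(axis=1)  # argmax is 0 for all-ones rows
    return np.where(has_zero, first_zero, mask.shape[1]).astype(np.int32)
```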
2023-04-27 16:32:42 -07:00
Patrice Vignola
3be5bfe363
[DML EP] Add MatMul + SoftMax fusion (#15240) 2023-04-11 08:31:04 -07:00
stevenlix
6d126f8996
Add FP16 support for Whisper model (#15427)
Current ORT can only run inference for Whisper FP32 model. This PR adds
FP16 support.
2023-04-08 21:36:10 -07:00
Chen Fu
8dce83a818
Fuse 'Add' operator into FP16 Conv (#15213)
### Description
Adding 'Add' functionality to FP16 Conv operator. It takes a tensor that
has the same shape of the output tensor, and add it to the result
tensor.


### Motivation and Context
Needed to run ResNet-50.
2023-04-07 09:51:03 -07:00
petermcaughan
1251964f96
Petermca/beamsearch whisper (#15339)
### Description
Adjust various code paths to allow Whisper model to function with
BeamSearch op.

Approach: Add a new kModelType enum value in IGenerationParameters as
follows:
#### Old: 0 = GPT2, 1 = T5
#### New: 0 = GPT2, 1 = T5, 2 = Whisper

When the user sets this attribute to 2, various shape and type
checks are adjusted to accommodate Whisper inputs.


### Motivation and Context
BeamSearch is currently designed to function with BERT-based models with
inputs as vocab tokens, and needs changes to function with Whisper
inputs (3-D float values processed from audio data).

---------

Co-authored-by: Peter McAughan <petermca@microsoft.com>
2023-04-04 09:09:10 -07:00
Ye Wang
fbfe92f66a
DecoderMaskedMultiHeadAttention enhancement (#15292) 2023-04-02 21:53:03 -07:00
Yufeng Li
c08d6b42e8
Add tool to support packing mode for BERT model (#15283)
### Description
Add a tool to convert fused BERT like model to packing mode


### Motivation and Context
2023-03-31 08:46:47 -07:00
Ye Wang
44ba23e0f5
Rename DecoderMaskedMHA to DecoderMaskedSelfAttn (#15166)
### Description

As synced offline, rename this op; another op will be created for MHA
that supports both self and cross attention.

### Motivation and Context

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
2023-03-23 12:31:38 -07:00
Ye Wang
2ee822d483
Extend memory efficient attention coverage in Attention/MHA cuda op (#15064)
### Description

1. Upgrade cutlass to 3.0, which contains attn_bias support.
2. Extend Attention/MHA to use memory efficient attention when a
rel_pos_bias of shape [1, num_head, s, s*] and a 1D mask with [2 * batch_size
+ 1] elements are present.

New mask format: MASK_1D_KEY_SEQ_LEN_START,
a 1D tensor of [3 * batch_size + 2] elements laid out as [key_len[0], ...,
key_len[batch_size - 1], query_start[0], ..., query_start[batch_size - 1],
query_end[batch_size - 1], key_start[0], ..., key_start[batch_size - 1],
key_end[batch_size - 1]].

E.g., the 2D mask [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]] converts to the
1D mask [3, 5, 0, 6, 12, 0, 6, 12]
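A hypothetical reconstruction of that conversion for a right-padded 2D mask (the actual ORT helper is not shown here), checked against the example above:

```python
import numpy as np

def to_1d_mask(mask_2d):
    # mask_2d: right-padded 0/1 mask of shape (batch, max_seq_len);
    # output has 3 * batch + 2 entries: key lengths, then query start
    # offsets plus a shared end, then key start offsets plus a shared end
    mask_2d = np.asarray(mask_2d)
    batch, max_len = mask_2d.shape
    key_len = mask_2d.sum(axis=1).tolist()        # valid tokens per sequence
    starts = [i * max_len for i in range(batch)]  # padded start offsets
    end = [batch * max_len]                       # single shared end offset
    return key_len + starts + end + starts + end

print(to_1d_mask([[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]]))
# -> [3, 5, 0, 6, 12, 0, 6, 12]
```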


### Motivation and Context

It potentially benefits tnlrv6 and t5(encoder)

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com>
Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
2023-03-23 11:05:17 -07:00
Hariharan Seshadri
7033346605 Support mask_filter_value attribute in DecoderMaskedMultiheadAttention (#15158) 2023-03-23 11:00:09 -07:00
Yufeng Li
c7ced7a5e9
Add PackedAttention for packing mode (#14858)
### Description
Transformer models can handle a batch of inputs at once. However,
sequences in a batch usually have different lengths, so we have to pad
the short ones to the same length as the longest. This is not efficient,
especially for a large batch with high length variance.

This PR introduces a PackedAttention operator which can take in packed
sequences (no padding) and also produces output in packing mode.

There will be another PR to use the PackedAttention to implement the
encoder in packing mode.
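A sketch of what packing mode means for the inputs: keep only the real tokens, plus cumulative sequence lengths so a kernel can find each sequence's boundaries. The `pack` helper and its bookkeeping are illustrative, not the PackedAttention API.

```python
import numpy as np

def pack(hidden, mask):
    # hidden: (batch, seq_len, hidden_size); mask: (batch, seq_len) of 0/1
    packed = hidden[mask.astype(bool)]  # (total_real_tokens, hidden_size)
    cum_seq_len = np.concatenate([[0], mask.sum(axis=1).cumsum()])
    return packed, cum_seq_len

hidden = np.arange(2 * 3 * 2, dtype=np.float32).reshape(2, 3, 2)
mask = np.array([[1, 1, 0], [1, 0, 0]])  # right-padded: 2 + 1 real tokens
packed, cum = pack(hidden, mask)
```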

### Motivation and Context
2023-03-21 12:59:29 -07:00
Hariharan Seshadri
ed7ab1660d
[CUDA] Add option to use DecoderMaskedMultiheadAttention in BeamSearch (#14990) 2023-03-15 17:16:32 -07:00
Ye Wang
538d64891a
[t5 optimization] kernel changes to t5 (#14928)
### Description

1. support optional bias in Attention op (used in T5 encoder)
2. support broadcasting rel_pos_bias in attention_softmax.h
3. add scale to MHA op's attributes
4. support past_key/past_value and present_key/present_value in MHA
5. UT and parity tests are added
6. fix an issue: https://github.com/microsoft/onnxruntime/issues/14920

note: the fusions will be in another PR since mt5 needs to be tested and
an issue from github will be investigated.

Future works:
1. support shared buffer for past/present
2. enable trt kernels when possible and investigate (trt/cutlass) kernels
with rel_pos_bias
3. support KV/QKV packing with past/present

### Motivation and Context

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
2023-03-13 14:29:16 -07:00
Hariharan Seshadri
112a4d215a
[CUDA] Support decoding multihead self-attention implementation (#14848) 2023-03-08 09:17:54 -08:00
Ye Wang
58da3cacdf
support NeoX-style rotary embedding (#14785)
### Description



### Motivation and Context

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
2023-02-22 18:21:34 -08:00
Tianlei Wu
eb2ac72fa9
Stable Diffusion CUDA Optimizations Part 4 (#14680)
(1) Support packed QKV format in MultiHeadAttention. This format avoids
the add-bias transpose when the TRT fused kernel is used.
(2) Add a cache for cumulative sequence length computation. For SD, it
only needs to be computed once since the sequence length is fixed.
(3) Do not allocate qkv workspace to save memory for packed KV or QKV.
(4) Add unit tests for packed kv and packed qkv format in
MultiHeadAttention
(5) Mark some fusion options for SD only

Performance tests show slight improvement in T4. Average latency reduced
0.15 seconds (from 5.25s to 5.10s) for 512x512 in 50 steps for SD 1.5
models. Memory usage drops from 5.1GB to 4.8GB.
2023-02-15 14:55:42 -08:00
Tianlei Wu
f638c5a2ae
Stable Diffusion CUDA Optimizations Part 3 (#14646)
The third part for stable diffusion CUDA optimizations
(1) Add a BiasAdd operator to replace two Adds (bias and residual); add
fusion for BiasAdd
(2) Add Attention fusion for VAE decoder.
(3) Update float16 conversion to handle Resize and GroupNorm. This can
remove two Cast nodes for each Resize op in the fp16 model.
(4) Force inputs and outputs to be float16 to avoid data casts in the
pipeline.
(5) Add options --force_fp32_ops, --inspect, etc. to the optimization
script so that the user can force some operators to run in float32 to
potentially get better image quality (at the cost of performance).
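
Numerically, the BiasAdd fusion in item (1) just collapses the two Add nodes into one kernel; a minimal NumPy sketch of the equivalence:

```python
import numpy as np

def bias_add(x, bias, residual):
    """Fused BiasAdd: out = x + bias + residual, replacing the separate
    per-channel bias Add and residual skip-connection Add."""
    return x + bias + residual
```

The win is not the arithmetic but the single kernel launch and one fewer intermediate tensor.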

Performance tests show a slight improvement on T4: average latency reduced
by 0.1 seconds (from 5.35s to 5.25s) for 512x512 in 50 steps.
2023-02-14 12:46:50 -08:00
Ye Wang
b539c364ee
Some kernel changes for TULR (#14517)
### Description
1. fix a bug in relative position bias kernel where seq_len > 32
2. rename extra_add_qk to relative_position_bias
3. support relative_position_bias in multihead attention (B, N, S, S*)
4. gru_gate support by Lei
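
Item 3's relative_position_bias of shape (B, N, S, S*) is applied additively to the attention scores before softmax; a NumPy sketch under those assumed shapes:

```python
import numpy as np

def attention_probs_with_bias(q, k, rel_pos_bias):
    """q: [B, N, S, H], k: [B, N, S*, H], rel_pos_bias: [B, N, S, S*].
    The bias is added to the scaled QK^T scores, then softmax is applied."""
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    scores = scores + rel_pos_bias
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
```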



---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
Co-authored-by: Lei Zhang <zhang.huanning@hotmail.com>
2023-02-07 11:51:06 -08:00
Tianlei Wu
a6c5ba0185
Stable Diffusion CUDA Optimizations (#14428)
### Description

Add stable diffusion CUDA kernel optimizations.

The following are included:
(1) GroupNorm operator. This kernel is from TensorRT 8.5.
(2) BiasSplitGelu operator. This kernel is modified from SplitGelu of
TensorRT 8.5. We added bias to the SplitGelu.
(3) NhwcConv operator. This adds support for the NHWC format (the ONNX
Conv operator uses NCHW).
(4) Update MultiHeadAttention (packed KV and no bias) for cross
attention. This avoids transposing KV for the TRT fused cross
attention kernel.
(5) Optimization and benchmark scripts
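
The NhwcConv layout difference amounts to a transpose: moving tensors from ONNX's NCHW to NHWC and back (a sketch of the layout conversion, not the graph conversion script noted as not included below):

```python
import numpy as np

def nchw_to_nhwc(x):
    """NCHW (ONNX Conv layout) -> NHWC (NhwcConv layout)."""
    return np.transpose(x, (0, 2, 3, 1))

def nhwc_to_nchw(x):
    """NHWC -> NCHW, the inverse transpose."""
    return np.transpose(x, (0, 3, 1, 2))
```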

Not included:
(1) Script to convert Conv to NhwcConv in onnx graph.
(2) Update symbolic shape inference for NhwcConv.
(3) Add SeqLen2Spatial operator
(4) Documents

Limitations: the GroupNorm, BiasSplitGelu and NhwcConv kernels are
implemented based on stable diffusion usage. They might not be
applicable to arbitrary input sizes or dimensions. For example, BiasSplitGelu
requires the hidden size to be 2560, 5120 or 10240, and NhwcConv assumes 4D
input/weight.
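
For reference, BiasSplitGelu behaves roughly as below. This is only a sketch: which half is GELU-gated is an assumption, and the tanh approximation stands in for the exact GELU used by the kernel:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (sufficient for a sketch)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def bias_split_gelu(x, bias):
    """BiasSplitGelu sketch: add bias, split the hidden dim in half, and
    gate one half with GELU of the other. Output hidden size is half the
    input's, which is why specific sizes (2560/5120/10240) are required."""
    y = x + bias
    a, b = np.split(y, 2, axis=-1)
    return a * gelu(b)
```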

There is a minor increase in binary size. For SM=75 only, the Python
package wheel size grows by (33757K - 33640K) = 117 KB. It is possible to
move NHWC from a template parameter to the constructor to reduce binary size
(with a slight cost in performance).

Note: for RTX 4090/4080/4070 Ti, need build with CUDA 11.8 and latest
cuDNN to get best performance.
2023-02-02 23:43:51 -08:00
Ye Wang
de7a868d5f
Update quantization_defs.cc (#14380)
2023-01-20 15:03:50 -08:00
Ye Wang
668586e8f8
Support muP in Attention (#14348)
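
For context, muP (maximal update parametrization) changes the attention logit scaling from the standard 1/sqrt(d_head) to 1/d_head. A sketch of what a configurable scale enables; whether ORT's Attention exposes exactly this mechanism is an assumption here:

```python
import numpy as np

def attention_scores(q, k, scale=None):
    """Attention logits with a configurable scale. Default is the standard
    1/sqrt(d_head); muP instead uses 1/d_head (Tensor Programs V)."""
    d = q.shape[-1]
    if scale is None:
        scale = 1.0 / np.sqrt(d)
    return (q @ np.swapaxes(k, -1, -2)) * scale
```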

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
2023-01-19 20:36:55 -08:00
Tianlei Wu
477cad3051
[CUDA] Add trt cross attention kernels (#14328)
Add TRT cross attention kernels for stable diffusion optimization.
2023-01-17 17:55:45 -08:00
Ye Wang
2db57a53a3
Add mask_filter in Attention related ops' attribute (#14274)
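
The mask filter value is the additive constant applied to masked-out positions before softmax, driving their probability toward zero (ORT's default is -10000). A sketch, where the mask convention (1 = keep, 0 = mask out) is an assumption:

```python
import numpy as np

def apply_mask(scores, mask, mask_filter_value=-10000.0):
    """Add mask_filter_value to positions where mask == 0, leaving
    kept positions (mask == 1) unchanged."""
    return scores + (1.0 - mask) * mask_filter_value
```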
### Motivation and Context

https://github.com/microsoft/onnxruntime/issues/12843

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
2023-01-17 12:28:11 -08:00
Zhang Lei
15141a40b4
Add present_past_share_buff to QAttention Defs to enable QAttention related tests. (#14297) 2023-01-14 09:19:06 -08:00