Mirror of https://github.com/saymrwulf/onnxruntime.git, synced 2026-05-24 22:17:32 +00:00.
### Description

Adds bfloat16 as a supported data type for `SimplifiedLayerNormFusion`. This lets the fusion apply to Llama-v2 running in the bfloat16 numerical format on A100 GPUs, which provides a speedup there.

_layernorm_optimized_training.onnx exported in bfloat16 vs. float16._

### Repro Instructions

```python
import torch
from torch import nn
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

# Switch to torch.float16 to produce the float16 export for comparison.
dtype = torch.bfloat16
# dtype = torch.float16


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10, dtype=dtype)
        self.layernorm = nn.LayerNorm([784], dtype=dtype)

    def forward(self, x):
        x = x.view(x.shape[0], -1)  # flatten (N, 28, 28) -> (N, 784)
        x = self.layernorm(x)
        x = self.fc(x)
        return x


model = Net()
# save_onnx=True writes the exported and optimized ONNX graphs to disk
# using the "layernorm" filename prefix, so the fused graph can be inspected.
model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='layernorm', log_level=LogLevel.INFO))
model.to("cuda")

images = torch.randn((8, 28, 28), dtype=dtype).to("cuda")
output = model(images)
```

### Motivation and Context

Supports ONNX Runtime integration with the Llama-v2 family of LLMs.

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
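`SimplifiedLayerNormFusion` targets the RMS-style normalization used by Llama-v2. ONNX Runtime represents it with the `SimplifiedLayerNormalization` op. Unlike standard LayerNorm, it does no mean subtraction and adds no bias: it computes `x / sqrt(mean(x^2) + epsilon) * scale`.

The pure-Python sketch below is a reference for that math only, not ONNX Runtime's kernel. The following are assumptions made for illustration:

- The helper names `to_bfloat16` and `simplified_layer_norm` are hypothetical.
- The value `epsilon=1e-5` is an assumed default.
- bfloat16 is emulated by rounding a float32 bit pattern to its upper 16 bits (round half to even).
- The reduction runs in Python double precision, so it does not model any particular GPU kernel's accumulation precision.

```python
import math
import struct


def to_bfloat16(value: float) -> float:
    """Round a float to the nearest bfloat16 value (round half to even).

    bfloat16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa
    bits of an IEEE-754 float32, i.e. the upper 16 bits of its bit pattern.
    NaN and overflow handling are omitted to keep the sketch short.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Add 0x7FFF (just under half of the dropped 16-bit range), plus 1 when
    # the kept part is odd, so truncation rounds to nearest, ties to even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def simplified_layer_norm(x, scale, epsilon=1e-5, bf16=False):
    """Hypothetical reference: y = x / sqrt(mean(x^2) + epsilon) * scale.

    No mean subtraction and no bias, unlike standard LayerNorm.
    With bf16=True, inputs and outputs are rounded to bfloat16; the
    reduction itself runs in Python double precision.
    """
    q = to_bfloat16 if bf16 else (lambda v: v)
    x = [q(v) for v in x]
    scale = [q(s) for s in scale]
    mean_sq = sum(v * v for v in x) / len(x)  # mean(x^2)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [q(v * inv_rms * s) for v, s in zip(x, scale)]


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0, 4.0]
    gamma = [1.0] * len(x)
    print(simplified_layer_norm(x, gamma))             # double-precision reference
    print(simplified_layer_norm(x, gamma, bf16=True))  # bfloat16-rounded outputs
```

Comparing the two printed lists shows what bfloat16 trades away. It keeps float32's 8-bit exponent, and therefore its dynamic range, but has only 8 bits of significand precision (7 stored). float16, by contrast, has a 5-bit exponent and 11 bits of precision.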

| Name |
|---|
| orttraining |
| tools |