Mirror of https://github.com/saymrwulf/onnxruntime.git, synced 2026-05-16 21:00:14 +00:00.
### Description

Adds bfloat16 as a supported dtype for SimplifiedLayerNormFusion, which provides a speedup for Llama-v2 on A100 when using the bfloat16 numerical format.

_layernorm_optimized_training.onnx exported in bfloat16 vs. float16_

### Repro Instructions

```python
from torch import nn
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
import torch

dtype = torch.bfloat16
# Switch to float16 to compare the exported graphs:
# dtype = torch.float16

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10, dtype=dtype)
        self.layernorm = nn.LayerNorm([784], dtype=dtype)

    def forward(self, x):
        x = x.view(x.shape[0], -1)  # flatten (8, 28, 28) -> (8, 784)
        x = self.layernorm(x)
        x = self.fc(x)
        return x

model = Net()
# save_onnx=True writes the exported graphs (e.g. layernorm_optimized_training.onnx) to disk
model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='layernorm', log_level=LogLevel.INFO))
model.to("cuda")
images = torch.randn((8, 28, 28), dtype=dtype).to("cuda")
output = model(images)
```

### Motivation and Context

ONNX Runtime integration with the Llama-v2 family of LLMs.

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>


| Name |
|---|
| c_cxx |
| execution_providers/images |
| images |
| python |
| ABI_Dev_Notes.md |
| Android_testing.md |
| C_API_Guidelines.md |
| cmake_guideline.md |
| Coding_Conventions_and_Standards.md |
| ContribOperators.md |
| FAQ.md |
| How_To_Update_ONNX_Dev_Notes.md |
| Memory_Optimizer.md |
| Model_Test.md |
| NotesOnThreading.md |
| ONNX_Runtime_Server_Usage.md |
| onnxruntime_dependencies.dot |
| onnxruntime_dependencies.png |
| onnxruntime_extensions.md |
| OperatorKernels.md |
| ORT_Format_Update_in_1.13.md |
| ORT_Use_Trtion_Kernel.md |
| ORTMobilePackageOperatorTypeSupport.md |
| ORTModule_Convergence_Notes.md |
| ORTModule_ModuleWithLoss_Wrapper.md |
| ORTModule_PythonOp_Notes.md |
| ORTModule_Training_Guidelines.md |
| PR_Guidelines.md |
| Privacy.md |
| Python_Dev_Notes.md |
| Reduced_Operator_Kernel_build.md |
| ReleaseManagement.md |
| Roadmap.md |
| Server.md |
| TVM_EP.md |
| Versioning.md |
| WinML_principles.md |
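
The latest commit extends SimplifiedLayerNormFusion to bfloat16. As a rough illustration of what such a fusion computes, the sketch below compares an op-by-op normalization with a single fused formula. It assumes that ONNX Runtime's SimplifiedLayerNormalization is LayerNorm without mean subtraction (y = x / sqrt(mean(x²) + eps) · scale, the RMS-norm form used by Llama) and that the matched subgraph looks like Pow → ReduceMean → Add → Sqrt → Div → Mul. Both function names are hypothetical, and the arithmetic runs in Python floats rather than bfloat16.

```python
import math
from typing import List


def unfused_rms_norm(x: List[float], scale: List[float], eps: float = 1e-6) -> List[float]:
    """Hypothetical op-by-op form: Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul."""
    squared = [v ** 2 for v in x]                      # Pow(x, 2)
    mean_sq = sum(squared) / len(squared)              # ReduceMean over the last axis
    denom = math.sqrt(mean_sq + eps)                   # Add(eps), then Sqrt
    normalized = [v / denom for v in x]                # Div(x, denom)
    return [n * s for n, s in zip(normalized, scale)]  # Mul(scale)


def fused_simplified_layer_norm(x: List[float], scale: List[float], eps: float = 1e-6) -> List[float]:
    """Hypothetical fused form: one pass computing x * rsqrt(mean(x^2) + eps) * scale.

    Unlike standard LayerNorm, the mean is never subtracted and there is no bias.
    """
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * s for v, s in zip(x, scale)]


x = [1.0, 2.0, 3.0, 4.0]
scale = [0.5, 1.0, 1.5, 2.0]
print(unfused_rms_norm(x, scale))
print(fused_simplified_layer_norm(x, scale))
```

A fusion only pays off when the fused kernel matches the unfused graph numerically while replacing several kernel launches with one; registering bfloat16 for the fusion lets bfloat16 models such as Llama-v2 take that path instead of running the individual ops.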