Transformer Model Optimization Tool Dev Guide

The transformer model optimization tool supports BERT, GPT-2, and some variants (such as RoBERTa and DistilBERT). However, it cannot cover every case, especially new architectures coming out of academic research. This guide introduces how the graph transformation works and how to optimize a custom transformer-based model with limited code changes to the graph fusion logic and kernel implementations.

The objective of the Dev Guide is to enable more transformer-based models to take advantage of ONNXRuntime optimized kernels.

Contributions are welcome!

Prerequisite

Rule Of Thumb

Graph fusion transforms a specific graph structure (a subgraph) into a single fused node. The kernel wrapped by the fused node is executed by the runtime engine and must be strictly computationally equivalent to the subgraph it replaces. In other words, the candidate subgraph must implement exactly the same logic as the fused node's kernel. We suggest getting familiar with the target optimized kernel implementation before working on the fusion logic.
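To make the equivalence requirement concrete, here is a minimal pure-Python sketch. It assumes an erf-based Gelu as the pattern of interest; the function names `gelu_unfused`, `gelu_fused`, and `tanh_approx` are hypothetical and exist only for this illustration. The first function mimics a chain of primitive ops, the second stands in for a fused kernel, and the third shows a look-alike formula that is not equivalent and therefore should not be mapped to the same kernel.

```python
import math


def gelu_unfused(x):
    """Evaluate Gelu one primitive op at a time, as an exported graph might."""
    div = x / math.sqrt(2.0)  # Div
    erf = math.erf(div)       # Erf
    add = erf + 1.0           # Add
    mul = x * add             # Mul
    return mul * 0.5          # Mul


def gelu_fused(x):
    """Stand-in for a fused kernel: the same formula computed in one call."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def tanh_approx(x):
    """Tanh-based approximation: similar shape, but a different computation."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


samples = [-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0]

# The op chain and the fused formula agree to floating-point precision.
exact_match = all(
    math.isclose(gelu_unfused(v), gelu_fused(v), rel_tol=1e-12, abs_tol=1e-12)
    for v in samples
)

# The approximation deviates measurably, so it needs its own kernel.
approx_gap = max(abs(tanh_approx(v) - gelu_fused(v)) for v in samples)
```

If a candidate subgraph behaves like `tanh_approx` while the target kernel behaves like `gelu_fused`, fusing them would silently change model outputs, which is why reading the kernel first is worthwhile.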

Kernel Implementation

ONNXRuntime provides optimized kernels as contrib operators for both the CPU and CUDA execution providers.

For instance, the entry point of the Attention CPU kernel is the Compute() function. Similarly, the entry point of the EmbedLayerNorm CUDA kernel is the ComputeInternal() function.

Graph Fusion

Graph fusion is the main part of the transformer optimizer. The current BERT optimization implementation supports several fusions, which are executed in order. Each fusion is a subclass of Fusion that implements the fuse() method; the fuse() method of the attention fusion is one example.
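The subclass-and-fuse() pattern can be sketched with a toy, self-contained graph. Everything below (`Node`, `ToyGraph`, `ToyFusion`, `ToySkipLayerNormFusion`) is a hypothetical stand-in written for this illustration and is not the tool's actual classes or signatures. The example fuses an Add followed by a LayerNormalization into a single SkipLayerNormalization node.

```python
from dataclasses import dataclass


@dataclass
class Node:
    op_type: str
    inputs: list
    outputs: list


class ToyGraph:
    """Minimal graph: an ordered list of nodes connected by tensor names."""

    def __init__(self, nodes):
        self.nodes = list(nodes)

    def producer(self, tensor_name):
        for node in self.nodes:
            if tensor_name in node.outputs:
                return node
        return None

    def consumers(self, tensor_name):
        return [n for n in self.nodes if tensor_name in n.inputs]


class ToyFusion:
    """Base class: visit each node of the searched type and call fuse() on it."""

    def __init__(self, graph, fused_op_type, search_op_type):
        self.graph = graph
        self.fused_op_type = fused_op_type
        self.search_op_type = search_op_type
        self.nodes_to_remove = []
        self.nodes_to_add = []

    def fuse(self, node):
        raise NotImplementedError

    def apply(self):
        for node in list(self.graph.nodes):
            if node.op_type == self.search_op_type:
                self.fuse(node)
        # Compare by identity so distinct nodes with equal fields are kept.
        removed = {id(n) for n in self.nodes_to_remove}
        kept = [n for n in self.graph.nodes if id(n) not in removed]
        # Fused nodes are appended; a real tool would re-sort topologically.
        self.graph.nodes = kept + self.nodes_to_add


class ToySkipLayerNormFusion(ToyFusion):
    """Fuse Add -> LayerNormalization into one SkipLayerNormalization node."""

    def __init__(self, graph):
        super().__init__(graph, "SkipLayerNormalization", "LayerNormalization")

    def fuse(self, node):
        add = self.graph.producer(node.inputs[0])
        if add is None or add.op_type != "Add":
            return  # pattern does not match; leave the graph untouched
        # Only fuse when nothing else reads the Add output.
        if len(self.graph.consumers(add.outputs[0])) != 1:
            return
        fused = Node(self.fused_op_type, add.inputs + node.inputs[1:], node.outputs)
        self.nodes_to_remove.extend([add, node])
        self.nodes_to_add.append(fused)


graph = ToyGraph([
    Node("MatMul", ["x", "w"], ["mm_out"]),
    Node("Add", ["mm_out", "residual"], ["sum"]),
    Node("LayerNormalization", ["sum", "gamma", "beta"], ["y"]),
])
ToySkipLayerNormFusion(graph).apply()
op_types = [n.op_type for n in graph.nodes]
```

The early returns in `fuse()` are the important design point: a fusion should leave the graph unchanged whenever the subgraph does not match the kernel's expectations exactly.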

The onnx_model class provides many useful functions for modifying an ONNX graph, including but not limited to:

Fusion process

After fusing the graph, check parity between the optimized ONNX model and the original one by feeding the same inputs to both models and comparing their outputs.
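A parity check along these lines might look like the sketch below. The helper names `run_onnx` and `outputs_match` are hypothetical, and the tolerances `rtol=1e-3` and `atol=1e-4` are assumed defaults for illustration, not values prescribed by the tool. `run_onnx` uses the `onnxruntime.InferenceSession` API and imports it lazily, so the comparison logic can be exercised without the package installed.

```python
import math


def run_onnx(model_path, feeds):
    """Run a model on CPU and return each output as a flat list of numbers.

    `feeds` maps input names to numpy arrays. Assumes all outputs are tensors.
    """
    import onnxruntime as ort  # lazy import; only needed for real models

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    return [out.ravel().tolist() for out in session.run(None, feeds)]


def outputs_match(original, optimized, rtol=1e-3, atol=1e-4):
    """Return True when both output lists have equal shapes and close values."""
    if len(original) != len(optimized):
        return False
    for ref, opt in zip(original, optimized):
        if len(ref) != len(opt):
            return False
        for r, o in zip(ref, opt):
            if not math.isclose(r, o, rel_tol=rtol, abs_tol=atol):
                return False
    return True


# Stand-in outputs; with real models these would come from run_onnx(...)
# called on the original and the optimized model with the same feeds.
reference = [[1.0, 2.0, 3.0]]
close_enough = outputs_match(reference, [[1.0001, 2.0, 3.0]])
too_far = outputs_match(reference, [[1.1, 2.0, 3.0]])
```

Fused kernels can reorder floating-point operations, so small differences are expected; a tolerance-based comparison is generally more appropriate than exact equality.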

A Concrete Case

Contribution

Coding Conventions and Standards