mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Mistral Optimization & Benchmarking Support (#18225)
### Description As a prerequisite for this model running correctly, two PRs need to be merged: - GQA Sliding Window Attention: https://github.com/microsoft/onnxruntime/tree/aciddelgado/gqa_local - MHA Fusion: https://github.com/frankdongms/onnxruntime/tree/frdong/llama_70b This PR adds optimization, quantization, and benchmarking support for Mistral. The README included describes steps to export, optimize, and benchmark Mistral models, but won't function correctly without the two above branches being merged first. --------- Co-authored-by: Peter McAughan <petermca@microsoft.com> Co-authored-by: Abhishek Jindal <abjindal@microsoft.com> Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
This commit is contained in:
parent
c9e558cd36
commit
871c52977a
4 changed files with 111 additions and 7 deletions
|
|
@ -1272,7 +1272,9 @@ def find_past_seq_len_usage(subg: GraphProto):
|
|||
return tensor_names_to_rename, nodes_to_remove
|
||||
|
||||
|
||||
def replace_mha_with_gqa(model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1):
|
||||
def replace_mha_with_gqa(
|
||||
model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
|
||||
):
|
||||
# Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
|
||||
#
|
||||
# attention_mask
|
||||
|
|
|
|||
|
|
@ -1,3 +1,13 @@
|
|||
# Contents
|
||||
- [LLaMA-2](#llama-2)
|
||||
- [Exporting LLaMA-2](#exporting-llama-2)
|
||||
- [Benchmarking LLaMA-2](#benchmark-llama-2)
|
||||
- [Mistral](#mistral)
|
||||
- [Exporting Mistral](#exporting-mistral)
|
||||
- [Optimizing and Quantizing Mistral](#optimizing-and-quantizing-mistral)
|
||||
- [Benchmarking Mistral](#benchmark-mistral)
|
||||
|
||||
|
||||
# LLaMA-2
|
||||
|
||||
## Prerequisites
|
||||
|
|
@ -372,3 +382,58 @@ python3 -m models.llama.benchmark_all \
|
|||
--num-runs 1000 \
|
||||
--timeout 60 # number of minutes before moving to the next benchmark
|
||||
```
|
||||
|
||||
# Mistral
|
||||
|
||||
## Introduction
|
||||
|
||||
These tools for LLaMA-2 also allow the quantization and optimization of Mistral in ORT.
|
||||
|
||||
## Exporting Mistral
|
||||
|
||||
There is currently one supported way to export Mistral to ONNX format:
|
||||
|
||||
### [Hugging Face Optimum](https://github.com/huggingface/optimum)
|
||||
|
||||
|
||||
The following command will export Mistral in full precision:
|
||||
```
|
||||
python -m optimum.exporters.onnx -m mistralai/Mistral-7B-v0.1 --library-name transformers /path/to/model/directory
|
||||
```
|
||||
|
||||
## Optimizing and Quantizing Mistral
|
||||
|
||||
To quantize Mistral to FP16 and apply fusion optimizations, you can run the following command:
|
||||
```
|
||||
python -m models.llama.convert_to_onnx -i /path/to/model/directory -o /path/to/optimized_model/directory -p fp16 --optimize_optimum -m mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
## Benchmark Mistral
|
||||
The benchmarking scripts in the LLaMA directory support Mistral benchmarking. To benchmark the ORT version, you can run:
|
||||
|
||||
```
|
||||
python -m models.llama.benchmark \
|
||||
-bt ort-convert-to-onnx \
|
||||
-p fp16 \
|
||||
-m mistralai/Mistral-7B-v0.1 \
|
||||
--ort-model-path /path/to/model.onnx
|
||||
```
|
||||
|
||||
To benchmark the Hugging Face implementation without `torch.compile`:
|
||||
|
||||
```
|
||||
python -m models.llama.benchmark \
|
||||
-bt hf-pt-eager \
|
||||
-p fp16 \
|
||||
-m mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
And to benchmark the Hugging Face implementation with `torch.compile`:
|
||||
|
||||
```
|
||||
python -m models.llama.benchmark \
|
||||
-bt hf-pt-compile \
|
||||
-p fp16 \
|
||||
-m mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
return_dict=True,
|
||||
)
|
||||
|
||||
elif args.benchmark_type == "hf-ort":
|
||||
elif args.benchmark_type in {"hf-ort"}:
|
||||
if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids]
|
||||
# Using split models in Optimum (e.g. created by Optimum export)
|
||||
init_inputs = get_sample_inputs(
|
||||
|
|
@ -529,7 +529,13 @@ def get_args(rank=0):
|
|||
"--benchmark-type",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort-msft", "ort-convert-to-onnx"],
|
||||
choices=[
|
||||
"hf-pt-eager",
|
||||
"hf-pt-compile",
|
||||
"hf-ort",
|
||||
"ort-msft",
|
||||
"ort-convert-to-onnx",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
|
|
|
|||
|
|
@ -391,7 +391,7 @@ def run_torchscript_merged_export(
|
|||
|
||||
|
||||
# Optimize the model as FP32
|
||||
def optimize_export(config: AutoConfig, input_path: str, output_path: str):
|
||||
def optimize_export(config: AutoConfig, input_path: str, output_path: str, remove_model: bool = True):
|
||||
from fusion_options import FusionOptions
|
||||
|
||||
optimization_options = FusionOptions("gpt2")
|
||||
|
|
@ -407,7 +407,8 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str):
|
|||
)
|
||||
model_opt.save_model_to_file(output_path, use_external_data_format=True)
|
||||
logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
|
||||
remove_existing_model(input_path)
|
||||
if remove_model:
|
||||
remove_existing_model(input_path)
|
||||
|
||||
|
||||
def convert_to_float16(
|
||||
|
|
@ -438,7 +439,7 @@ def convert_to_float16(
|
|||
return new_paths
|
||||
|
||||
|
||||
def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1):
|
||||
def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1, window_size: int = 0):
|
||||
# Replace MultiHeadAttention with GroupQueryAttention
|
||||
fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "attention_mask", config.num_key_value_heads, world_size)
|
||||
fp16_model_opt.prune_graph()
|
||||
|
|
@ -539,6 +540,23 @@ def remove_existing_files(output_path: str):
|
|||
logger.warning(f"Removed {filepath}")
|
||||
|
||||
|
||||
def optimize_optimum(config: AutoConfig, args: argparse.Namespace):
|
||||
tmp_file = os.path.join(args.output, args.model_name + ".tmp.onnx")
|
||||
output_file = os.path.join(args.output, args.model_name + ".onnx")
|
||||
optimize_export(config, args.input, tmp_file, remove_model=False)
|
||||
logger.info(f"Model successfully optimized to {tmp_file}")
|
||||
opt_model = OnnxModel(onnx.load_model(tmp_file, load_external_data=True))
|
||||
if args.precision == Precision.FLOAT16:
|
||||
opt_model.convert_float_to_float16(keep_io_types=False)
|
||||
window_size = 0 if not hasattr(config, "sliding_window") else config.sliding_window
|
||||
opt_model = use_group_query_attention(config, opt_model, args.world_size, window_size)
|
||||
logger.info("Model successfully fused and quantized to FP16!")
|
||||
opt_model.save_model_to_file(output_file, use_external_data_format=True)
|
||||
logger.info(f"Output model successfully saved to {output_file}")
|
||||
logger.info(f"Removing {tmp_file}")
|
||||
remove_existing_model(tmp_file)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
|
|
@ -554,7 +572,7 @@ def get_args():
|
|||
"--input",
|
||||
required=False,
|
||||
default=os.path.join("."),
|
||||
help="Directory path to PyTorch model and associated files if saved on disk",
|
||||
help="Directory path to PyTorch model and associated files if saved on disk, or ONNX model file location if optimize_optimum is passed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -720,6 +738,13 @@ def get_args():
|
|||
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--optimize_optimum",
|
||||
action="store_true",
|
||||
help="Avoid exporting model, only apply quantizations and optimizations to existing model exported from optimum.",
|
||||
)
|
||||
parser.set_defaults(optimize_optimum=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
|
@ -740,6 +765,7 @@ def main():
|
|||
|
||||
world_size = get_size()
|
||||
rank = get_rank()
|
||||
args.world_size = world_size
|
||||
|
||||
# Load model and config
|
||||
use_auth_token = args.input == os.path.join(".")
|
||||
|
|
@ -754,6 +780,11 @@ def main():
|
|||
|
||||
location = args.original_model_name if use_auth_token else args.input
|
||||
|
||||
if args.optimize_optimum:
|
||||
config = AutoConfig.from_pretrained(args.original_model_name)
|
||||
optimize_optimum(config, args)
|
||||
return
|
||||
|
||||
# Use CUDA for LLaMA-2-70B to speed up export and CPU for other models
|
||||
l_config, llama = setup_torch_model(
|
||||
args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None
|
||||
|
|
|
|||
Loading…
Reference in a new issue