Mistral Optimization & Benchmarking Support (#18225)

### Description
As a prerequisite for this model running correctly, two PRs need to be
merged:

- GQA Sliding Window Attention:
https://github.com/microsoft/onnxruntime/tree/aciddelgado/gqa_local
- MHA Fusion:
https://github.com/frankdongms/onnxruntime/tree/frdong/llama_70b

This PR adds optimization, quantization, and benchmarking support for
Mistral. The included README describes the steps to export, optimize, and
benchmark Mistral models, but these steps won't function correctly until
the two branches above have been merged.

---------

Co-authored-by: Peter McAughan <petermca@microsoft.com>
Co-authored-by: Abhishek Jindal <abjindal@microsoft.com>
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
This commit is contained in:
petermcaughan 2023-12-05 15:39:17 -08:00 committed by GitHub
parent c9e558cd36
commit 871c52977a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 111 additions and 7 deletions

View file

@ -1272,7 +1272,9 @@ def find_past_seq_len_usage(subg: GraphProto):
return tensor_names_to_rename, nodes_to_remove
def replace_mha_with_gqa(model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1):
def replace_mha_with_gqa(
model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
):
# Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
#
# attention_mask

View file

@ -1,3 +1,13 @@
# Contents
- [LLaMA-2](#llama-2)
- [Exporting LLaMA-2](#exporting-llama-2)
- [Benchmarking LLaMA-2](#benchmark-llama-2)
- [Mistral](#mistral)
- [Exporting Mistral](#exporting-mistral)
- [Optimizing and Quantizing Mistral](#optimizing-and-quantizing-mistral)
- [Benchmarking Mistral](#benchmark-mistral)
# LLaMA-2
## Prerequisites
@ -372,3 +382,58 @@ python3 -m models.llama.benchmark_all \
--num-runs 1000 \
--timeout 60 # number of minutes before moving to the next benchmark
```
# Mistral
## Introduction
The LLaMA-2 tools in this directory can also be used to optimize and quantize Mistral models for ONNX Runtime (ORT).
## Exporting Mistral
There is currently one supported way to export Mistral to ONNX format:
### [Hugging Face Optimum](https://github.com/huggingface/optimum)
The following command will export Mistral in full precision:
```
python -m optimum.exporters.onnx -m mistralai/Mistral-7B-v0.1 --library-name transformers /path/to/model/directory
```
## Optimizing and Quantizing Mistral
To quantize Mistral to FP16 and apply fusion optimizations, you can run the following command:
```
python -m models.llama.convert_to_onnx -i /path/to/model/directory -o /path/to/optimized_model/directory -p fp16 --optimize_optimum -m mistralai/Mistral-7B-v0.1
```
## Benchmark Mistral
The benchmarking scripts in the LLaMA directory support Mistral benchmarking. To benchmark the ORT version, you can run:
```
python -m models.llama.benchmark \
-bt ort-convert-to-onnx \
-p fp16 \
-m mistralai/Mistral-7B-v0.1 \
--ort-model-path /path/to/model.onnx
```
To benchmark the Hugging Face implementation without `torch.compile`:
```
python -m models.llama.benchmark \
-bt hf-pt-eager \
-p fp16 \
-m mistralai/Mistral-7B-v0.1
```
And to benchmark the Hugging Face implementation with `torch.compile`:
```
python -m models.llama.benchmark \
-bt hf-pt-compile \
-p fp16 \
-m mistralai/Mistral-7B-v0.1
```

View file

@ -79,7 +79,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
return_dict=True,
)
elif args.benchmark_type == "hf-ort":
elif args.benchmark_type in {"hf-ort"}:
if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids]
# Using split models in Optimum (e.g. created by Optimum export)
init_inputs = get_sample_inputs(
@ -529,7 +529,13 @@ def get_args(rank=0):
"--benchmark-type",
type=str,
required=True,
choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort-msft", "ort-convert-to-onnx"],
choices=[
"hf-pt-eager",
"hf-pt-compile",
"hf-ort",
"ort-msft",
"ort-convert-to-onnx",
],
)
parser.add_argument(
"-m",

View file

@ -391,7 +391,7 @@ def run_torchscript_merged_export(
# Optimize the model as FP32
def optimize_export(config: AutoConfig, input_path: str, output_path: str):
def optimize_export(config: AutoConfig, input_path: str, output_path: str, remove_model: bool = True):
from fusion_options import FusionOptions
optimization_options = FusionOptions("gpt2")
@ -407,7 +407,8 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str):
)
model_opt.save_model_to_file(output_path, use_external_data_format=True)
logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
remove_existing_model(input_path)
if remove_model:
remove_existing_model(input_path)
def convert_to_float16(
@ -438,7 +439,7 @@ def convert_to_float16(
return new_paths
def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1):
def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1, window_size: int = 0):
# Replace MultiHeadAttention with GroupQueryAttention
fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "attention_mask", config.num_key_value_heads, world_size)
fp16_model_opt.prune_graph()
@ -539,6 +540,23 @@ def remove_existing_files(output_path: str):
logger.warning(f"Removed {filepath}")
def optimize_optimum(config: AutoConfig, args: argparse.Namespace):
    """Optimize an ONNX model that was already exported (e.g. by Optimum).

    The model at ``args.input`` is optimized into a temporary file inside
    ``args.output``, reloaded, optionally converted to FP16 with
    GroupQueryAttention fusion, saved as ``<args.model_name>.onnx`` and the
    temporary file is then removed.

    Args:
        config: Model config; ``sliding_window`` is read from it when present.
        args: Parsed command-line arguments. Reads ``input``, ``output``,
            ``model_name``, ``precision`` and ``world_size``.
    """
    tmp_file = os.path.join(args.output, args.model_name + ".tmp.onnx")
    output_file = os.path.join(args.output, args.model_name + ".onnx")

    # remove_model=False: the input is the caller's existing export, so it must
    # stay on disk after optimization.
    optimize_export(config, args.input, tmp_file, remove_model=False)
    logger.info(f"Model successfully optimized to {tmp_file}")

    opt_model = OnnxModel(onnx.load_model(tmp_file, load_external_data=True))
    if args.precision == Precision.FLOAT16:
        opt_model.convert_float_to_float16(keep_io_types=False)
        # A config may lack `sliding_window` or define it as None; both cases
        # fall back to 0 so an int is always forwarded.
        window_size = getattr(config, "sliding_window", None) or 0
        opt_model = use_group_query_attention(config, opt_model, args.world_size, window_size)
        logger.info("Model successfully fused and quantized to FP16!")

    opt_model.save_model_to_file(output_file, use_external_data_format=True)
    logger.info(f"Output model successfully saved to {output_file}")

    # The temporary optimized model is only an intermediate artifact.
    logger.info(f"Removing {tmp_file}")
    remove_existing_model(tmp_file)
def get_args():
parser = argparse.ArgumentParser()
@ -554,7 +572,7 @@ def get_args():
"--input",
required=False,
default=os.path.join("."),
help="Directory path to PyTorch model and associated files if saved on disk",
help="Directory path to PyTorch model and associated files if saved on disk, or ONNX model file location if optimize_optimum is passed.",
)
parser.add_argument(
@ -720,6 +738,13 @@ def get_args():
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
)
parser.add_argument(
"--optimize_optimum",
action="store_true",
help="Avoid exporting model, only apply quantizations and optimizations to existing model exported from optimum.",
)
parser.set_defaults(optimize_optimum=False)
args = parser.parse_args()
return args
@ -740,6 +765,7 @@ def main():
world_size = get_size()
rank = get_rank()
args.world_size = world_size
# Load model and config
use_auth_token = args.input == os.path.join(".")
@ -754,6 +780,11 @@ def main():
location = args.original_model_name if use_auth_token else args.input
if args.optimize_optimum:
config = AutoConfig.from_pretrained(args.original_model_name)
optimize_optimum(config, args)
return
# Use CUDA for LLaMA-2-70B to speed up export and CPU for other models
l_config, llama = setup_torch_model(
args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None