From 871c52977aa4297d783fd4d830eaa10c71cb2be6 Mon Sep 17 00:00:00 2001 From: petermcaughan Date: Tue, 5 Dec 2023 15:39:17 -0800 Subject: [PATCH] Mistral Optimization & Benchmarking Support (#18225) ### Description As a prerequisite for this model running correctly, two PRs need to be merged: - GQA Sliding Window Attention: https://github.com/microsoft/onnxruntime/tree/aciddelgado/gqa_local - MHA Fusion: https://github.com/frankdongms/onnxruntime/tree/frdong/llama_70b This PR adds optimization, quantization, and benchmarking support for Mistral. The README included describes steps to export, optimize, and benchmark Mistral models, but won't function correctly without the two above branches being merged first. --------- Co-authored-by: Peter McAughan Co-authored-by: Abhishek Jindal Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> --- .../tools/transformers/convert_generation.py | 4 +- .../tools/transformers/models/llama/README.md | 65 +++++++++++++++++++ .../transformers/models/llama/benchmark.py | 10 ++- .../models/llama/convert_to_onnx.py | 39 +++++++++-- 4 files changed, 111 insertions(+), 7 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index b59af41c49..17f0dd0bc6 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,7 +1272,9 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def replace_mha_with_gqa(model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1): +def replace_mha_with_gqa( + model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0 +): # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes # # attention_mask diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 44dea3cb73..0e34fb0e69 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -1,3 +1,13 @@ +# Contents + - [LLaMA-2](#llama-2) + - [Exporting LLaMA-2](#exporting-llama-2) + - [Benchmarking LLaMA-2](#benchmark-llama-2) + - [Mistral](#mistral) + - [Exporting Mistral](#exporting-mistral) + - [Optimizing and Quantizing Mistral](#optimizing-and-quantizing-mistral) + - [Benchmarking Mistral](#benchmark-mistral) + + # LLaMA-2 ## Prerequisites @@ -372,3 +382,58 @@ python3 -m models.llama.benchmark_all \ --num-runs 1000 \ --timeout 60 # number of minutes before moving to the next benchmark ``` + +# Mistral + +## Introduction + +These tools for LLaMA-2 also allow the quantization and optimization of Mistral in ORT. + +## Exporting Mistral + +There is currently one supported way to export Mistral to ONNX format: + +### [Hugging Face Optimum](https://github.com/huggingface/optimum) + + +The following command will export Mistral in full precision: +``` +python -m optimum.exporters.onnx -m mistralai/Mistral-7B-v0.1 --library-name transformers /path/to/model/directory +``` + +## Optimizing and Quantizing Mistral + +To quantize Mistral to FP16 and apply fusion optimizations, you can run the following command: +``` +python -m models.llama.convert_to_onnx -i /path/to/model/directory -o /path/to/optimized_model/directory -p fp16 --optimize_optimum -m mistralai/Mistral-7B-v0.1 +``` + +## Benchmark Mistral +The benchmarking scripts in the LLaMA directory support Mistral benchmarking. To benchmark the ORT version, you can run: + +``` +python -m models.llama.benchmark \ + -bt ort-convert-to-onnx \ + -p fp16 \ + -m mistralai/Mistral-7B-v0.1 \ + --ort-model-path /path/to/model.onnx +``` + +To benchmark the Hugging Face implementation without `torch.compile`: + +``` +python -m models.llama.benchmark \ + -bt hf-pt-eager \ + -p fp16 \ + -m mistralai/Mistral-7B-v0.1 +``` + +And to benchmark the Hugging Face implementation with `torch.compile`: + +``` +python -m models.llama.benchmark \ + -bt hf-pt-compile \ + -p fp16 \ + -m mistralai/Mistral-7B-v0.1 +``` + diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index 021b0dd03a..a53dead77d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -79,7 +79,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): return_dict=True, ) - elif args.benchmark_type == "hf-ort": + elif args.benchmark_type in {"hf-ort"}: if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids] # Using split models in Optimum (e.g. created by Optimum export) init_inputs = get_sample_inputs( @@ -529,7 +529,13 @@ def get_args(rank=0): "--benchmark-type", type=str, required=True, - choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort-msft", "ort-convert-to-onnx"], + choices=[ + "hf-pt-eager", + "hf-pt-compile", + "hf-ort", + "ort-msft", + "ort-convert-to-onnx", + ], ) parser.add_argument( "-m", diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index c9c7f4d39d..e694b5050c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -391,7 +391,7 @@ def run_torchscript_merged_export( # Optimize the model as FP32 -def optimize_export(config: AutoConfig, input_path: str, output_path: str): +def optimize_export(config: AutoConfig, input_path: str, output_path: str, remove_model: bool = True): from fusion_options import FusionOptions optimization_options = FusionOptions("gpt2") @@ -407,7 +407,8 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str): ) model_opt.save_model_to_file(output_path, use_external_data_format=True) logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!") - remove_existing_model(input_path) + if remove_model: + remove_existing_model(input_path) def convert_to_float16( @@ -438,7 +439,7 @@ def convert_to_float16( return new_paths -def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1): +def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1, window_size: int = 0): # Replace MultiHeadAttention with GroupQueryAttention fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "attention_mask", config.num_key_value_heads, world_size) fp16_model_opt.prune_graph() @@ -539,6 +540,23 @@ def remove_existing_files(output_path: str): logger.warning(f"Removed {filepath}") +def optimize_optimum(config: AutoConfig, args: argparse.Namespace): + tmp_file = os.path.join(args.output, args.model_name + ".tmp.onnx") + output_file = os.path.join(args.output, args.model_name + ".onnx") + optimize_export(config, args.input, tmp_file, remove_model=False) + logger.info(f"Model successfully optimized to {tmp_file}") + opt_model = OnnxModel(onnx.load_model(tmp_file, load_external_data=True)) + if args.precision == Precision.FLOAT16: + opt_model.convert_float_to_float16(keep_io_types=False) + window_size = 0 if not hasattr(config, "sliding_window") else config.sliding_window + opt_model = use_group_query_attention(config, opt_model, args.world_size, window_size) + logger.info("Model successfully fused and quantized to FP16!") + opt_model.save_model_to_file(output_file, use_external_data_format=True) + logger.info(f"Output model successfully saved to {output_file}") + logger.info(f"Removing {tmp_file}") + remove_existing_model(tmp_file) + + def get_args(): parser = argparse.ArgumentParser() @@ -554,7 +572,7 @@ def get_args(): "--input", required=False, default=os.path.join("."), - help="Directory path to PyTorch model and associated files if saved on disk", + help="Directory path to PyTorch model and associated files if saved on disk, or ONNX model file location if optimize_optimum is passed.", ) parser.add_argument( @@ -720,6 +738,13 @@ def get_args(): help="model cache dir to override default HF cache dir to avoid overflood the /home dir", ) + parser.add_argument( + "--optimize_optimum", + action="store_true", + help="Avoid exporting model, only apply quantizations and optimizations to existing model exported from optimum.", + ) + parser.set_defaults(optimize_optimum=False) + args = parser.parse_args() return args @@ -740,6 +765,7 @@ def main(): world_size = get_size() rank = get_rank() + args.world_size = world_size # Load model and config use_auth_token = args.input == os.path.join(".") @@ -754,6 +780,11 @@ def main(): location = args.original_model_name if use_auth_token else args.input + if args.optimize_optimum: + config = AutoConfig.from_pretrained(args.original_model_name) + optimize_optimum(config, args) + return + # Use CUDA for LLaMA-2-70B to speed up export and CPU for other models l_config, llama = setup_torch_model( args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None