mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Reduce LLaMA memory usage (#18181)
### Description

This PR reduces the memory usage when exporting and benchmarking LLaMA.

### Motivation and Context

- Exporting: The PyTorch model is deleted from memory right after a successful export, instead of being kept in memory until the ONNX model has also been converted to the desired precision.
- Benchmarking: In the ONNX model with GroupQueryAttention, the KV cache inputs use the same GPU memory for both the prompt and token generation benchmarks.
This commit is contained in:
parent
2b95e74fa1
commit
d1b85f5fb4
4 changed files with 256 additions and 194 deletions
|
|
@ -11,9 +11,8 @@ import numpy as np
|
|||
import onnx
|
||||
import psutil
|
||||
import torch
|
||||
from benchmark_helper import setup_logger
|
||||
from llama_inputs import (
|
||||
convert_inputs_for_ort,
|
||||
add_io_bindings,
|
||||
get_merged_sample_with_past_kv_inputs,
|
||||
get_msft_sample_inputs,
|
||||
get_sample_inputs,
|
||||
|
|
@ -25,7 +24,7 @@ from tqdm import trange
|
|||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.transformers.benchmark_helper import measure_memory
|
||||
from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -48,9 +47,19 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
init_inputs, iter_inputs = None, None
|
||||
|
||||
# For past_present_share_buffer:
|
||||
# Set max_seq_len to 2048 for Hugging Face model since that is the default value
|
||||
# Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported
|
||||
max_seq_len = 2048
|
||||
# Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2)
|
||||
# Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value
|
||||
# Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
|
||||
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
|
||||
max_seq_len = (
|
||||
2048
|
||||
if args.benchmark_type == "ort-msft"
|
||||
else 16384
|
||||
if "codellama" in temp_name
|
||||
else 4096
|
||||
if "llama2" in temp_name
|
||||
else 2048
|
||||
)
|
||||
|
||||
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
|
||||
init_inputs = get_sample_inputs(
|
||||
|
|
@ -95,7 +104,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
args.batch_size,
|
||||
seq_len=args.sequence_length,
|
||||
past_seq_len=0,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
engine="pt",
|
||||
return_dict=True,
|
||||
)
|
||||
iter_inputs = get_merged_sample_with_past_kv_inputs(
|
||||
|
|
@ -104,7 +115,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
args.batch_size,
|
||||
seq_len=1,
|
||||
past_seq_len=args.sequence_length,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
engine="pt",
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
|
|
@ -116,7 +129,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
args.batch_size,
|
||||
seq_len=args.sequence_length,
|
||||
past_seq_len=0,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
engine="ort",
|
||||
return_dict=True,
|
||||
)
|
||||
iter_inputs = get_merged_sample_with_past_kv_inputs(
|
||||
|
|
@ -125,27 +140,11 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
args.batch_size,
|
||||
seq_len=1,
|
||||
past_seq_len=args.sequence_length,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
engine="ort",
|
||||
return_dict=True,
|
||||
)
|
||||
init_inputs = convert_inputs_for_ort(
|
||||
init_inputs,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.past_present_share_buffer,
|
||||
past_seq_len=0,
|
||||
max_seq_len=max_seq_len,
|
||||
device=args.device,
|
||||
device_id=args.device_id,
|
||||
)
|
||||
iter_inputs = convert_inputs_for_ort(
|
||||
iter_inputs,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.past_present_share_buffer,
|
||||
past_seq_len=args.sequence_length,
|
||||
max_seq_len=max_seq_len,
|
||||
device=args.device,
|
||||
device_id=args.device_id,
|
||||
)
|
||||
|
||||
elif args.benchmark_type == "ort-msft":
|
||||
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
|
||||
|
|
@ -156,6 +155,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
args.batch_size,
|
||||
past_seq_len=0,
|
||||
seq_len=args.sequence_length,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
split_kv=split_kv,
|
||||
)
|
||||
|
|
@ -164,27 +164,10 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
args.batch_size,
|
||||
past_seq_len=args.sequence_length,
|
||||
seq_len=1,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
split_kv=split_kv,
|
||||
)
|
||||
init_inputs = convert_inputs_for_ort(
|
||||
init_inputs,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.past_present_share_buffer,
|
||||
past_seq_len=0,
|
||||
max_seq_len=max_seq_len,
|
||||
device=args.device,
|
||||
device_id=args.device_id,
|
||||
)
|
||||
iter_inputs = convert_inputs_for_ort(
|
||||
iter_inputs,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.past_present_share_buffer,
|
||||
past_seq_len=args.sequence_length,
|
||||
max_seq_len=max_seq_len,
|
||||
device=args.device,
|
||||
device_id=args.device_id,
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception("Unable to auto-detect inputs for provided model")
|
||||
|
|
@ -449,7 +432,7 @@ def run_hf_inference(args, init_inputs, iter_inputs, model):
|
|||
|
||||
|
||||
def run_ort_inference(args, init_inputs, iter_inputs, model):
|
||||
def prepare_ort_inputs(inputs):
|
||||
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
|
||||
# Check that all model inputs will be provided
|
||||
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
|
||||
user_inputs = set(inputs.keys())
|
||||
|
|
@ -467,29 +450,13 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
|
|||
|
||||
# Add IO bindings for non-CPU execution providers
|
||||
if args.device != "cpu":
|
||||
io_binding = model.io_binding()
|
||||
|
||||
for k, v in inputs.items():
|
||||
if args.past_present_share_buffer:
|
||||
# Bind all OrtValue inputs to device
|
||||
io_binding.bind_ortvalue_input(k, v)
|
||||
else:
|
||||
io_binding.bind_cpu_input(k, v)
|
||||
|
||||
for output in model.get_outputs():
|
||||
name = output.name
|
||||
if args.past_present_share_buffer and ("out" in name or "present" in name):
|
||||
# Bind present KV cache outputs to OrtValue with buffer sharing
|
||||
io_binding.bind_ortvalue_output(
|
||||
name, inputs[name.replace("out", "cache").replace("present", "past_key_values")]
|
||||
)
|
||||
else:
|
||||
io_binding.bind_output(name, device_type=args.device, device_id=args.device_id)
|
||||
|
||||
io_binding, kv_cache_ortvalues = add_io_bindings(
|
||||
model, inputs, args.device, int(args.device_id), kv_cache_ortvalues
|
||||
)
|
||||
setattr(args, "io_binding", io_binding) # noqa: B010
|
||||
return io_binding
|
||||
return io_binding, kv_cache_ortvalues
|
||||
|
||||
return inputs
|
||||
return inputs, kv_cache_ortvalues
|
||||
|
||||
def with_io_binding(io_binding):
|
||||
# Inference pass with IO binding
|
||||
|
|
@ -501,9 +468,10 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
|
|||
return outputs
|
||||
|
||||
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
|
||||
kv_cache_ortvalues = {}
|
||||
|
||||
if args.profile:
|
||||
ort_init_inputs = prepare_ort_inputs(init_inputs)
|
||||
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
|
||||
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
|
||||
|
||||
# Turn profiling off to stop appending to log file
|
||||
|
|
@ -513,7 +481,7 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
|
|||
|
||||
# Re-initialize model for new log file instead of appending to old log file
|
||||
model = get_model(args)
|
||||
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
|
||||
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
|
||||
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
|
||||
|
||||
# Turn profiling off to stop appending to log
|
||||
|
|
@ -524,12 +492,12 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
|
|||
|
||||
# ORT evaluations
|
||||
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
|
||||
ort_init_inputs = prepare_ort_inputs(init_inputs)
|
||||
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
|
||||
time_fn(args, generate_fn, ort_init_inputs)
|
||||
measure_fn(args, generate_fn, ort_init_inputs)
|
||||
|
||||
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
|
||||
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
|
||||
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
|
||||
time_fn(args, generate_fn, ort_iter_inputs)
|
||||
measure_fn(args, generate_fn, ort_iter_inputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -716,6 +716,7 @@ def main():
|
|||
run_torchscript_separate_export(args, l_config, llama)
|
||||
else:
|
||||
run_torchscript_merged_export(args, l_config, llama)
|
||||
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
|
||||
|
||||
# Set model paths to store FP32 optimized model
|
||||
decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx")
|
||||
|
|
@ -811,7 +812,6 @@ def main():
|
|||
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
|
||||
remove_existing_model(fp_path)
|
||||
|
||||
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
|
||||
logger.info("Verifying parity on all ONNX models created")
|
||||
|
||||
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import numpy as np
|
|||
import torch
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from onnxruntime import OrtValue
|
||||
from onnxruntime import InferenceSession, OrtValue
|
||||
|
||||
|
||||
# Get position_ids from attention_mask
|
||||
|
|
@ -12,22 +12,36 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
|
|||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if use_past_kv:
|
||||
# Shape: (batch_size, 1)
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
# Shape: (batch_size, sequence_length)
|
||||
return position_ids
|
||||
|
||||
|
||||
# Inputs for first pass to get initial past_key_values
|
||||
# input_ids: (batch_size, sequence_length)
|
||||
# attention_mask: (batch_size, sequence_length)
|
||||
# position_ids: (batch_size, sequence_length)
|
||||
def get_sample_inputs(
|
||||
config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, return_dict: bool = False
|
||||
config: LlamaConfig,
|
||||
device: torch.device,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
engine: str = "pt",
|
||||
return_dict: bool = False,
|
||||
):
|
||||
input_ids = torch.randint(
|
||||
low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64
|
||||
)
|
||||
attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64)
|
||||
# position_ids is of shape (batch_size, seq_len)
|
||||
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
|
||||
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=False)
|
||||
|
||||
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
|
||||
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
|
||||
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
|
||||
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
|
||||
|
||||
if not return_dict:
|
||||
# For export
|
||||
return (input_ids, attention_mask, position_ids)
|
||||
|
||||
inputs = {
|
||||
|
|
@ -39,135 +53,131 @@ def get_sample_inputs(
|
|||
|
||||
|
||||
# Inputs for subsequent passes with past_key_values
|
||||
# input_ids: (batch_size, 1)
|
||||
# attention_mask: (batch_size, past_sequence_length + 1)
|
||||
# position_ids: (batch_size, 1)
|
||||
# past_key: (batch_size, num_heads, past_sequence_length, head_size)
|
||||
# past_value: (batch_size, num_heads, past_sequence_length, head_size)
|
||||
def get_sample_with_past_kv_inputs(
    config: LlamaConfig,
    device: torch.device,
    batch_size: int,
    past_seq_len: int,
    use_fp16: bool = False,
    engine: str = "pt",
    return_dict: bool = False,
):
    """Build random inputs for a decoding step that consumes a past KV cache.

    All tensors are created without a ``device`` argument and are then either
    converted to NumPy (``engine == "ort"``) or moved to ``device`` (any other
    engine value).

    Args:
        config: Model config; ``vocab_size`` is read here, the remaining
            fields are read by ``get_past_kv_inputs``.
        device: Target device for the PyTorch path.
        batch_size: Number of sequences in the batch.
        past_seq_len: Number of positions already held in the KV cache.
        use_fp16: Forwarded to ``get_past_kv_inputs`` to pick the cache dtype.
        engine: ``"ort"`` for NumPy inputs, anything else for torch tensors.
        return_dict: When false, return a tuple (used for export); when true,
            return a name -> value dict.

    Returns:
        ``(input_ids, attention_mask, position_ids, past_kv)`` when
        ``return_dict`` is false, otherwise a dict. For ``engine == "ort"`` the
        dict holds the flattened per-layer cache entries; otherwise it holds a
        single ``"past_key_values"`` list.

    Raises:
        AssertionError: If ``return_dict`` is false and ``engine == "ort"``,
            because the tuple form requires the cache to be a list.
    """
    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
    attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
    # position_ids is of shape (batch_size, 1)
    position_ids = get_position_ids(attention_mask, use_past_kv=True)
    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)

    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
    if engine == "ort":
        input_ids = input_ids.numpy()
        attention_mask = attention_mask.numpy()
        position_ids = position_ids.numpy()
        past_kv = flatten_past_kv_inputs(past_kv)
    else:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        position_ids = position_ids.to(device)
        past_kv = [(past_k.to(device), past_v.to(device)) for past_k, past_v in past_kv]

    if not return_dict:
        # For export
        assert isinstance(past_kv, list)
        return (input_ids, attention_mask, position_ids, past_kv)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    if engine == "ort":
        # The flattened cache entries become top-level inputs; no nested
        # "past_key_values" entry is added on this path.
        assert isinstance(past_kv, dict)
        inputs.update(past_kv)
    else:
        assert isinstance(past_kv, list)
        inputs["past_key_values"] = past_kv

    return inputs
|
||||
|
||||
|
||||
# Inputs for all passes with past_key_values
|
||||
# input_ids: (batch_size, sequence_length)
|
||||
# attention_mask: (batch_size, past_sequence_length + sequence_length)
|
||||
# position_ids: (batch_size, sequence_length)
|
||||
# past_key: (batch_size, num_heads, kv_sequence_length, head_size)
|
||||
# For models with GQA, kv_sequence_length = max_sequence_length
|
||||
# For models without GQA, kv_sequence_length = past_sequence_length
|
||||
# past_value: (batch_size, num_heads, kv_sequence_length, head_size)
|
||||
# For models with GQA, kv_sequence_length = max_sequence_length
|
||||
# For models without GQA, kv_sequence_length = past_sequence_length
|
||||
def get_merged_sample_with_past_kv_inputs(
    config: LlamaConfig,
    device: torch.device,
    batch_size: int,
    seq_len: int,
    past_seq_len: int,
    max_seq_len: int,
    use_fp16: bool = False,
    engine: str = "pt",
    return_dict: bool = False,
):
    """Build random inputs for the merged model (prompt or token generation).

    All tensors are created without a ``device`` argument and are then either
    converted to NumPy (``engine == "ort"``) or moved to ``device`` (any other
    engine value).

    Args:
        config: Model config; ``vocab_size`` is read here, the remaining
            fields are read by ``get_past_kv_inputs``.
        device: Target device for the PyTorch path.
        batch_size: Number of sequences in the batch.
        seq_len: Number of new tokens fed in this step.
        past_seq_len: Number of positions already held in the KV cache;
            ``0`` selects prompt-style position ids.
        max_seq_len: Cache length used when the KV buffers are padded for
            past-present buffer sharing.
        use_fp16: Selects the cache dtype. On the ORT dict path it is also
            treated as "model has GQA".
        engine: ``"ort"`` for NumPy inputs, anything else for torch tensors.
        return_dict: When false, return a tuple (used for export); when true,
            return a name -> value dict.

    Returns:
        ``(input_ids, attention_mask, position_ids, past_kv)`` when
        ``return_dict`` is false, otherwise a dict of model inputs.

    Raises:
        AssertionError: If ``return_dict`` is false and ``engine == "ort"``,
            because the tuple form requires the cache to be a list.
    """
    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
    attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
    # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
    position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)

    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
    if engine == "ort":
        input_ids = input_ids.numpy()
        attention_mask = attention_mask.numpy()
        position_ids = position_ids.numpy()
        past_kv = flatten_past_kv_inputs(past_kv)
    else:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        position_ids = position_ids.to(device)
        past_kv = [(past_k.to(device), past_v.to(device)) for past_k, past_v in past_kv]

    if not return_dict:
        # For export
        assert isinstance(past_kv, list)
        return (input_ids, attention_mask, position_ids, past_kv)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    if engine == "ort":
        # The flattened cache entries become top-level inputs; no nested
        # "past_key_values" entry is added on this path.
        assert isinstance(past_kv, dict)
        inputs.update(past_kv)

        if use_fp16:  # If model has GQA
            # Drop the mask, pass the past length explicitly, and pad the KV
            # buffers to max_seq_len so past and present can share memory.
            del inputs["attention_mask"]
            inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64)
            inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)

    else:
        assert isinstance(past_kv, list)
        inputs["past_key_values"] = past_kv

    return inputs
|
||||
|
||||
|
||||
# Create past_key_values
|
||||
def get_sample_past_kv_inputs(
    config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool
):
    """Create random past key/value tensors on ``device``, one pair per layer.

    Args:
        config: Model config providing ``num_key_value_heads``,
            ``num_attention_heads``, ``hidden_size`` and ``num_hidden_layers``.
        device: Device the tensors are allocated on.
        batch_size: Number of sequences in the batch.
        past_seq_len: Number of cached positions.
        use_fp16: Use ``torch.float16`` when true, ``torch.float32`` otherwise.

    Returns:
        A list with ``config.num_hidden_layers`` ``(key, value)`` tuples, each
        tensor of shape ``(batch_size, num_key_value_heads, past_seq_len, head_size)``.
    """
    num_heads = config.num_key_value_heads
    # The per-head width is set by the number of attention (query) heads.
    # Dividing by num_key_value_heads is only right when the two counts match.
    head_size = config.hidden_size // config.num_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    kv_shape = (batch_size, num_heads, past_seq_len, head_size)
    past_kv = [
        (
            torch.rand(kv_shape, device=device, dtype=torch_dtype),
            torch.rand(kv_shape, device=device, dtype=torch_dtype),
        )
        for _ in range(config.num_hidden_layers)
    ]
    return past_kv
|
||||
|
||||
|
||||
# Convert list of past_kv to dict of past_key and past_value
|
||||
def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool):
    """Turn a per-layer list of (key, value) tensors into a flat NumPy dict.

    Each tensor is detached, copied to the CPU, converted to NumPy and cast to
    ``np.float16`` when ``use_fp16`` is true, ``np.float32`` otherwise.

    Returns:
        A dict with the keys ``past_key_values.<i>.key`` and
        ``past_key_values.<i>.value`` for every layer index ``i``.
    """
    target_dtype = np.float32
    if use_fp16:
        target_dtype = np.float16

    flattened = {}
    for layer_idx, layer_kv in enumerate(past_key_values):
        key_tensor, value_tensor = layer_kv
        flattened[f"past_key_values.{layer_idx}.key"] = key_tensor.detach().cpu().numpy().astype(target_dtype)
        flattened[f"past_key_values.{layer_idx}.value"] = value_tensor.detach().cpu().numpy().astype(target_dtype)
    return flattened
|
||||
|
||||
|
||||
# Format PyTorch inputs to ONNX Runtime inputs
|
||||
def convert_inputs_for_ort(
    pt_inputs: dict,
    use_fp16: bool,
    use_buffer_share: bool = False,
    past_seq_len: int = 0,
    max_seq_len: int = 2048,
    device: str = "",
    device_id: int = -1,
):
    """Format PyTorch inputs as ONNX Runtime inputs.

    NumPy arrays pass through unchanged, ``"past_key_values"`` is flattened
    into per-layer entries, and every other value is detached and converted to
    NumPy. When both ``use_fp16`` and ``use_buffer_share`` are set,
    ``"attention_mask"`` is replaced by a ``"past_sequence_length"`` input.

    When buffer sharing is requested for a non-CPU device with a valid
    ``device_id``, every input is turned into an ``OrtValue`` on that device,
    and the KV cache inputs are first zero-padded to ``max_seq_len``.

    Returns:
        A dict mapping input names to NumPy arrays or ``OrtValue`` objects.
    """
    converted = {}
    for name, value in pt_inputs.items():
        if isinstance(value, np.ndarray):
            converted[name] = value
            continue
        if name == "past_key_values":
            converted.update(flatten_past_kv_inputs(value, use_fp16))
            continue
        if name == "attention_mask" and use_fp16 and use_buffer_share:
            # Skip because FP16 model has GroupQueryAttention, uses buffer sharing,
            # and GQA supports a causal mask by default.
            # Instead, add the past sequence length input for GQA.
            converted["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64)
            continue
        converted[name] = value.detach().cpu().numpy()

    # Enable past-present-share-buffer by using device memory directly
    share_on_device = use_buffer_share and device not in ("", "cpu") and device_id > -1
    if not share_on_device:
        return converted

    for name, array in converted.items():
        device_source = array
        if "cache" in name or "past_key_values" in name:
            # Allocate a (batch, heads, max_seq_len, head_size) buffer for GQA
            # and copy the existing cache into its first past_seq_len positions.
            batch_size, num_heads, _, head_size = array.shape
            device_source = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=array.dtype)
            device_source[:, :, :past_seq_len, :] = array
        converted[name] = OrtValue.ortvalue_from_numpy(device_source, device_type=device, device_id=device_id)

    return converted
|
||||
|
||||
|
||||
# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
|
||||
def get_msft_sample_inputs(
|
||||
config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool
|
||||
config: LlamaConfig,
|
||||
batch_size: int,
|
||||
past_seq_len: int,
|
||||
seq_len: int,
|
||||
max_seq_len: int,
|
||||
use_fp16: bool,
|
||||
split_kv: bool,
|
||||
):
|
||||
np_dtype = np.float16 if use_fp16 else np.float32
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
max_seq_len = 2048
|
||||
|
||||
if not split_kv:
|
||||
ort_inputs = {
|
||||
|
|
@ -201,4 +211,111 @@ def get_msft_sample_inputs(
|
|||
}
|
||||
)
|
||||
|
||||
if use_fp16: # If model has GQA
|
||||
del ort_inputs["attn_mask"]
|
||||
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
|
||||
|
||||
return ort_inputs
|
||||
|
||||
|
||||
# Create past_key_values
|
||||
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
|
||||
def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool):
    """Create random past key/value tensors, one pair per layer.

    No ``device`` is passed to ``torch.rand``; callers move or convert the
    tensors afterwards.

    Args:
        config: Model config providing ``num_key_value_heads``,
            ``num_attention_heads``, ``hidden_size`` and ``num_hidden_layers``.
        batch_size: Number of sequences in the batch.
        past_seq_len: Number of cached positions.
        use_fp16: Use ``torch.float16`` when true, ``torch.float32`` otherwise.

    Returns:
        A list with ``config.num_hidden_layers`` ``(key, value)`` tuples, each
        tensor of shape ``(batch_size, num_key_value_heads, past_seq_len, head_size)``.
    """
    num_heads = config.num_key_value_heads
    # The per-head width is set by the number of attention (query) heads.
    # Dividing by num_key_value_heads is only right when the two counts match.
    head_size = config.hidden_size // config.num_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    kv_shape = (batch_size, num_heads, past_seq_len, head_size)
    past_kv = [
        (
            torch.rand(kv_shape, dtype=torch_dtype),
            torch.rand(kv_shape, dtype=torch_dtype),
        )
        for _ in range(config.num_hidden_layers)
    ]
    return past_kv
|
||||
|
||||
|
||||
# Convert list of past_key_values to dict of past_key and past_value
|
||||
def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]):
    """Turn a per-layer list of (key, value) tensors into a flat NumPy dict.

    Each tensor is detached, copied to the CPU and converted to NumPy; the
    dtype is whatever the tensor already had.

    Returns:
        A dict with the keys ``past_key_values.<i>.key`` and
        ``past_key_values.<i>.value`` for every layer index ``i``.
    """
    flattened = {}
    for layer_idx, layer_kv in enumerate(past_key_values):
        key_tensor, value_tensor = layer_kv
        flattened[f"past_key_values.{layer_idx}.key"] = key_tensor.detach().cpu().numpy()
        flattened[f"past_key_values.{layer_idx}.value"] = value_tensor.detach().cpu().numpy()
    return flattened
|
||||
|
||||
|
||||
# Format PyTorch inputs to ONNX Runtime inputs
|
||||
def convert_inputs_for_ort(
    pt_inputs: dict,
    use_fp16: bool,
    use_buffer_share: bool = False,
    past_seq_len: int = 0,
    max_seq_len: int = 2048,
    device: str = "",
    device_id: int = -1,
):
    """Format PyTorch inputs as ONNX Runtime (NumPy) inputs.

    NumPy arrays pass through unchanged, ``"past_key_values"`` is flattened
    into per-layer entries, and every other value is detached and converted to
    NumPy. When both ``use_fp16`` and ``use_buffer_share`` are set,
    ``"attention_mask"`` is replaced by a ``"past_sequence_length"`` input.

    When buffer sharing is requested for a non-CPU device with a valid
    ``device_id``, the KV cache inputs are additionally padded to
    ``max_seq_len`` via ``enable_past_present_share_buffer``.

    Returns:
        A dict mapping input names to NumPy arrays.
    """
    converted = {}
    for name, value in pt_inputs.items():
        if isinstance(value, np.ndarray):
            converted[name] = value
            continue
        if name == "past_key_values":
            converted.update(flatten_past_kv_inputs(value))
            continue
        if name == "attention_mask" and use_fp16 and use_buffer_share:
            # Skip because FP16 model has GroupQueryAttention, uses buffer sharing,
            # and GQA supports a causal mask by default.
            # Instead, add the past sequence length input for GQA.
            converted["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64)
            continue
        converted[name] = value.detach().cpu().numpy()

    # Reshape kv caches if using past-present-share-buffer
    reshape_kv_caches = use_buffer_share and device not in ("", "cpu") and device_id > -1
    if reshape_kv_caches:
        converted = enable_past_present_share_buffer(converted, past_seq_len, max_seq_len)

    return converted
|
||||
|
||||
|
||||
def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
    """Pad every KV cache input to ``max_seq_len`` along its third axis.

    An input counts as a KV cache when its name contains ``"cache"`` or
    ``"past_key_values"``. Each such 4-D array is replaced by a zero-filled
    array of shape ``(batch, heads, max_seq_len, head_size)`` whose first
    ``past_seq_len`` positions hold the original data. Other inputs are left
    untouched.

    Args:
        ort_inputs: Name -> NumPy array dict; modified in place.
        past_seq_len: Number of leading positions the existing data occupies.
        max_seq_len: Length of the padded sequence axis.

    Returns:
        The same ``ort_inputs`` dict object.
    """
    kv_names = [name for name in ort_inputs if "cache" in name or "past_key_values" in name]
    for name in kv_names:
        past = ort_inputs[name]
        # Allocate new buffers with max_sequence_length for GQA
        batch_size, num_heads, _, head_size = past.shape
        padded = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=past.dtype)
        padded[:, :, :past_seq_len, :] = past
        ort_inputs[name] = padded
    return ort_inputs
|
||||
|
||||
|
||||
# Add IO bindings for execution providers
|
||||
def add_io_bindings(model: InferenceSession, ort_inputs: dict, device: str, device_id: int, kv_cache_ortvalues: dict):
    """Bind all model inputs and outputs through an IO binding.

    Every input is bound as an ``OrtValue`` on ``device``/``device_id``. Once
    a float16 input has been seen, KV cache inputs (names containing
    ``"cache"`` or ``"past_key_values"``) are kept in ``kv_cache_ortvalues``:
    a name already present there is refreshed with ``update_inplace`` instead
    of being allocated again, and the matching present-KV outputs are bound to
    those same ``OrtValue`` objects.

    Args:
        model: Session whose inputs and outputs are bound.
        ort_inputs: Name -> NumPy array dict of model inputs.
        device: Device type passed to the ORT binding calls.
        device_id: Device index passed to the ORT binding calls.
        kv_cache_ortvalues: Name -> ``OrtValue`` dict of previously created KV
            cache buffers; new buffers are added to it in place.

    Returns:
        ``(io_binding, kv_cache_ortvalues)``.

    Raises:
        KeyError: If a present-KV output has no matching entry in
            ``kv_cache_ortvalues``.
    """
    is_fp16 = False
    binding = model.io_binding()

    for input_name, array in ort_inputs.items():
        # Detect if model is in FP16; the flag stays set for all later inputs
        if array.dtype == np.float16:
            is_fp16 = True

        is_kv_cache = "cache" in input_name or "past_key_values" in input_name
        if is_fp16 and is_kv_cache and input_name in kv_cache_ortvalues:
            # Reuse the buffer created by an earlier call and refresh its contents
            device_value = kv_cache_ortvalues[input_name]
            device_value.update_inplace(array)
        else:
            device_value = OrtValue.ortvalue_from_numpy(array, device_type=device, device_id=device_id)
            if is_fp16 and is_kv_cache:
                # Keep it so later calls and the output binding below can reuse it
                kv_cache_ortvalues[input_name] = device_value
        binding.bind_ortvalue_input(input_name, device_value)

    for output in model.get_outputs():
        output_name = output.name
        if not (is_fp16 and ("out" in output_name or "present" in output_name)):
            binding.bind_output(output_name, device_type=device, device_id=device_id)
            continue
        # Bind present KV cache outputs to past KV cache inputs in order to buffer share
        source_name = output_name.replace("out", "cache").replace("present", "past_key_values")
        binding.bind_ortvalue_output(output_name, kv_cache_ortvalues[source_name])

    return binding, kv_cache_ortvalues
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import numpy as np
|
|||
import torch
|
||||
from benchmark_helper import setup_logger
|
||||
from llama_inputs import (
|
||||
add_io_bindings,
|
||||
convert_inputs_for_ort,
|
||||
get_merged_sample_with_past_kv_inputs,
|
||||
get_sample_inputs,
|
||||
|
|
@ -22,22 +23,24 @@ logger = logging.getLogger("")
|
|||
|
||||
def get_sequence_lengths(args: argparse.Namespace):
    """Return the dummy sequence lengths used for the parity check.

    Args:
        args: Parsed arguments; ``use_past_kv`` and ``model_name`` are read.

    Returns:
        ``(past_sequence_length, curr_sequence_length, max_sequence_length)``.
        The first two are ``(8, 1)`` when ``args.use_past_kv`` is truthy and
        ``(0, 8)`` otherwise. The maximum is 16384 for CodeLLaMA names, 4096
        for LLaMA-2 names and 2048 for anything else.
    """
    past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)

    # Drop "-" and "_" and lowercase so spellings such as "Llama-2" or
    # "code_llama" match the substrings checked below.
    temp_name = args.model_name.lower().replace("-", "").replace("_", "")
    if "codellama" in temp_name:
        max_sequence_length = 16384
    elif "llama2" in temp_name:
        max_sequence_length = 4096
    else:
        max_sequence_length = 2048

    return past_sequence_length, curr_sequence_length, max_sequence_length
|
||||
|
||||
|
||||
def get_inputs(args: argparse.Namespace, config: LlamaConfig):
|
||||
# Dummy values for parity
|
||||
batch_size = 2
|
||||
past_sequence_length, sequence_length, _ = get_sequence_lengths(args)
|
||||
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args)
|
||||
|
||||
if args.merged:
|
||||
inputs = get_merged_sample_with_past_kv_inputs(
|
||||
config,
|
||||
args.device,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
past_sequence_length,
|
||||
seq_len=sequence_length,
|
||||
past_seq_len=past_sequence_length,
|
||||
max_seq_len=max_sequence_length,
|
||||
use_fp16=args.use_fp16,
|
||||
return_dict=True,
|
||||
)
|
||||
|
|
@ -51,31 +54,7 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig):
|
|||
return inputs
|
||||
|
||||
|
||||
def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict):
    """Bind all model inputs and outputs for a non-CPU execution provider.

    When ``args.use_fp16`` is truthy, inputs are bound as OrtValues and each
    present-KV output (name containing ``"out"`` or ``"present"``) is bound to
    the matching past-KV input so the two share a buffer. Otherwise inputs are
    bound from CPU memory. All remaining outputs are bound on
    ``args.execution_provider`` / ``args.device_id``.

    Returns:
        The populated IO binding object.
    """
    binding = model.io_binding()

    for input_name, value in inputs.items():
        if not args.use_fp16:
            binding.bind_cpu_input(input_name, value)
            continue
        # Bind all OrtValue inputs to device
        binding.bind_ortvalue_input(input_name, value)

    for output in model.get_outputs():
        output_name = output.name
        if not (args.use_fp16 and ("out" in output_name or "present" in output_name)):
            binding.bind_output(output_name, device_type=args.execution_provider, device_id=int(args.device_id))
            continue
        # Bind present KV cache outputs to OrtValue with buffer sharing
        source_name = output_name.replace("out", "cache").replace("present", "past_key_values")
        binding.bind_ortvalue_output(output_name, inputs[source_name])

    return binding
|
||||
|
||||
|
||||
def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM):
|
||||
def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM, kv_cache_ortvalues: dict):
|
||||
inputs = get_inputs(args, config)
|
||||
|
||||
# Run inference with PyTorch
|
||||
|
|
@ -111,7 +90,9 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
|
|||
|
||||
# Add IO bindings for non-CPU execution providers
|
||||
if args.execution_provider != "cpu":
|
||||
io_binding = add_io_bindings(args, ort_model, inputs)
|
||||
io_binding, kv_cache_ortvalues = add_io_bindings(
|
||||
ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues
|
||||
)
|
||||
|
||||
io_binding.synchronize_inputs()
|
||||
start_time = time.time()
|
||||
|
|
@ -131,17 +112,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
|
|||
logger.info(f"ONNX Runtime took {end_time - start_time} s")
|
||||
|
||||
# Compare PyTorch and ONNX Runtime accuracy
|
||||
tol = (
|
||||
2e1
|
||||
if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path
|
||||
else 1e-3
|
||||
if args.precision == "fp32"
|
||||
else 5e-1
|
||||
)
|
||||
tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
|
||||
parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
|
||||
logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
|
||||
if not parity:
|
||||
logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
|
||||
return kv_cache_ortvalues
|
||||
|
||||
|
||||
def get_args(argv: List[str]):
|
||||
|
|
@ -250,16 +226,17 @@ def main(argv: List[str] = []): # noqa: B006
|
|||
use_cache=True,
|
||||
).to(args.device)
|
||||
|
||||
kv_cache_ortvalues = {}
|
||||
if not args.merged:
|
||||
verify_parity(args, config, llama)
|
||||
verify_parity(args, config, llama, kv_cache_ortvalues)
|
||||
else:
|
||||
# Verify prompt generation in merged model (decoder_model.onnx)
|
||||
args.use_past_kv = False
|
||||
verify_parity(args, config, llama)
|
||||
kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues)
|
||||
|
||||
# Verify token generation in merged model (decoder_with_past_model.onnx)
|
||||
args.use_past_kv = True
|
||||
verify_parity(args, config, llama)
|
||||
verify_parity(args, config, llama, kv_cache_ortvalues)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue