Reduce LLaMA memory usage (#18181)

### Description
This PR reduces the memory usage when exporting and benchmarking LLaMA.



### Motivation and Context
- Exporting: The PyTorch model is deleted from memory after a successful
export instead of deleting it from memory after exporting + converting
the ONNX model to the desired precision.
- Benchmarking: In the ONNX model with GroupQueryAttention, the KV cache
inputs use the same GPU memory for both the prompt and token generation
benchmarks.
This commit is contained in:
kunal-vaishnavi 2023-10-31 17:53:52 -07:00 committed by GitHub
parent 2b95e74fa1
commit d1b85f5fb4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 256 additions and 194 deletions

View file

@ -11,9 +11,8 @@ import numpy as np
import onnx
import psutil
import torch
from benchmark_helper import setup_logger
from llama_inputs import (
convert_inputs_for_ort,
add_io_bindings,
get_merged_sample_with_past_kv_inputs,
get_msft_sample_inputs,
get_sample_inputs,
@ -25,7 +24,7 @@ from tqdm import trange
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
import onnxruntime as ort
from onnxruntime.transformers.benchmark_helper import measure_memory
from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger
logger = logging.getLogger(__name__)
@ -48,9 +47,19 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
init_inputs, iter_inputs = None, None
# For past_present_share_buffer:
# Set max_seq_len to 2048 for Hugging Face model since that is the default value
# Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported
max_seq_len = 2048
# Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2)
# Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value
# Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
max_seq_len = (
2048
if args.benchmark_type == "ort-msft"
else 16384
if "codellama" in temp_name
else 4096
if "llama2" in temp_name
else 2048
)
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
init_inputs = get_sample_inputs(
@ -95,7 +104,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="pt",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
@ -104,7 +115,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="pt",
return_dict=True,
)
@ -116,7 +129,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="ort",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
@ -125,27 +140,11 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="ort",
return_dict=True,
)
init_inputs = convert_inputs_for_ort(
init_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=0,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
)
iter_inputs = convert_inputs_for_ort(
iter_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
)
elif args.benchmark_type == "ort-msft":
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
@ -156,6 +155,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
past_seq_len=0,
seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
split_kv=split_kv,
)
@ -164,27 +164,10 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
past_seq_len=args.sequence_length,
seq_len=1,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
split_kv=split_kv,
)
init_inputs = convert_inputs_for_ort(
init_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=0,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
)
iter_inputs = convert_inputs_for_ort(
iter_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
)
else:
raise Exception("Unable to auto-detect inputs for provided model")
@ -449,7 +432,7 @@ def run_hf_inference(args, init_inputs, iter_inputs, model):
def run_ort_inference(args, init_inputs, iter_inputs, model):
def prepare_ort_inputs(inputs):
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
# Check that all model inputs will be provided
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
user_inputs = set(inputs.keys())
@ -467,29 +450,13 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding = model.io_binding()
for k, v in inputs.items():
if args.past_present_share_buffer:
# Bind all OrtValue inputs to device
io_binding.bind_ortvalue_input(k, v)
else:
io_binding.bind_cpu_input(k, v)
for output in model.get_outputs():
name = output.name
if args.past_present_share_buffer and ("out" in name or "present" in name):
# Bind present KV cache outputs to OrtValue with buffer sharing
io_binding.bind_ortvalue_output(
name, inputs[name.replace("out", "cache").replace("present", "past_key_values")]
)
else:
io_binding.bind_output(name, device_type=args.device, device_id=args.device_id)
io_binding, kv_cache_ortvalues = add_io_bindings(
model, inputs, args.device, int(args.device_id), kv_cache_ortvalues
)
setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding
return io_binding, kv_cache_ortvalues
return inputs
return inputs, kv_cache_ortvalues
def with_io_binding(io_binding):
# Inference pass with IO binding
@ -501,9 +468,10 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
return outputs
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
kv_cache_ortvalues = {}
if args.profile:
ort_init_inputs = prepare_ort_inputs(init_inputs)
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
# Turn profiling off to stop appending to log file
@ -513,7 +481,7 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
# Re-initialize model for new log file instead of appending to old log file
model = get_model(args)
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
# Turn profiling off to stop appending to log
@ -524,12 +492,12 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
# ORT evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
ort_init_inputs = prepare_ort_inputs(init_inputs)
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_init_inputs)
measure_fn(args, generate_fn, ort_init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_iter_inputs)
measure_fn(args, generate_fn, ort_iter_inputs)

View file

@ -716,6 +716,7 @@ def main():
run_torchscript_separate_export(args, l_config, llama)
else:
run_torchscript_merged_export(args, l_config, llama)
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
# Set model paths to store FP32 optimized model
decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx")
@ -811,7 +812,6 @@ def main():
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
remove_existing_model(fp_path)
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
logger.info("Verifying parity on all ONNX models created")
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models

View file

@ -4,7 +4,7 @@ import numpy as np
import torch
from transformers import LlamaConfig
from onnxruntime import OrtValue
from onnxruntime import InferenceSession, OrtValue
# Get position_ids from attention_mask
@ -12,22 +12,36 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if use_past_kv:
# Shape: (batch_size, 1)
position_ids = position_ids[:, -1].unsqueeze(-1)
# Shape: (batch_size, sequence_length)
return position_ids
# Inputs for first pass to get initial past_key_values
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, sequence_length)
# position_ids: (batch_size, sequence_length)
def get_sample_inputs(
config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, return_dict: bool = False
config: LlamaConfig,
device: torch.device,
batch_size: int,
seq_len: int,
engine: str = "pt",
return_dict: bool = False,
):
input_ids = torch.randint(
low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64
)
attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64)
# position_ids is of shape (batch_size, seq_len)
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
position_ids = get_position_ids(attention_mask, use_past_kv=False)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
if not return_dict:
# For export
return (input_ids, attention_mask, position_ids)
inputs = {
@ -39,135 +53,131 @@ def get_sample_inputs(
# Inputs for subsequent passes with past_key_values
# input_ids: (batch_size, 1)
# attention_mask: (batch_size, past_sequence_length + 1)
# position_ids: (batch_size, 1)
# past_key: (batch_size, num_heads, past_sequence_length, head_size)
# past_value: (batch_size, num_heads, past_sequence_length, head_size)
def get_sample_with_past_kv_inputs(
config: LlamaConfig,
device: torch.device,
batch_size: int,
past_seq_len: int,
use_fp16: bool = False,
engine: str = "pt",
return_dict: bool = False,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64)
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
# position_ids is of shape (batch_size, 1)
position_ids = get_position_ids(attention_mask, use_past_kv=True)
past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16)
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
past_kv = (
flatten_past_kv_inputs(past_kv)
if engine == "ort"
else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
)
if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, attention_mask, position_ids, past_kv)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_kv,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv
return inputs
# Inputs for all passes with past_key_values
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, past_sequence_length + sequence_length)
# position_ids: (batch_size, sequence_length)
# past_key: (batch_size, num_heads, kv_sequence_length, head_size)
# For models with GQA, kv_sequence_length = max_sequence_length
# For models without GQA, kv_sequence_length = past_sequence_length
# past_value: (batch_size, num_heads, kv_sequence_length, head_size)
# For models with GQA, kv_sequence_length = max_sequence_length
# For models without GQA, kv_sequence_length = past_sequence_length
def get_merged_sample_with_past_kv_inputs(
config: LlamaConfig,
device: torch.device,
batch_size: int,
seq_len: int,
past_seq_len: int,
max_seq_len: int,
use_fp16: bool = False,
engine: str = "pt",
return_dict: bool = False,
):
input_ids = torch.randint(
low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64
)
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64)
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
# position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16)
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
past_kv = (
flatten_past_kv_inputs(past_kv)
if engine == "ort"
else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
)
if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, attention_mask, position_ids, past_kv)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_kv,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
if use_fp16: # If model has GQA
del inputs["attention_mask"]
inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64)
inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv
return inputs
# Create past_key_values
def get_sample_past_kv_inputs(
config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool
):
num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
past_kv = [
(
torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
)
for _ in range(config.num_hidden_layers)
]
return past_kv
# Convert list of past_kv to dict of past_key and past_value
def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool):
past_kv = {}
np_dtype = np.float16 if use_fp16 else np.float32
for i, (past_k, past_v) in enumerate(past_key_values):
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy().astype(np_dtype)
past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy().astype(np_dtype)
return past_kv
# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
pt_inputs: dict,
use_fp16: bool,
use_buffer_share: bool = False,
past_seq_len: int = 0,
max_seq_len: int = 2048,
device: str = "",
device_id: int = -1,
):
ort_inputs = {}
for k, v in pt_inputs.items():
if isinstance(v, np.ndarray):
ort_inputs[k] = v
elif k == "past_key_values":
ort_inputs.update(flatten_past_kv_inputs(v, use_fp16))
elif k == "attention_mask" and use_fp16 and use_buffer_share:
# Skip because FP16 model has GroupQueryAttention, uses buffer sharing,
# and GQA supports a causal mask by default
# Instead, add the past sequence length input for GQA
ort_inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64)
else:
ort_inputs[k] = v.detach().cpu().numpy()
# Enable past-present-share-buffer by using device memory directly
if use_buffer_share and device != "" and device != "cpu" and device_id > -1:
for k, v in ort_inputs.items():
new_v = v
# Allocate new buffers with max_sequence_length for GQA
if "cache" in k or "past_key_values" in k:
# Copy v (BxSxPxH) into new_v (BxSxMxH)
batch_size, num_heads, _, head_size = v.shape
new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
ort_inputs[k] = OrtValue.ortvalue_from_numpy(new_v, device_type=device, device_id=device_id)
return ort_inputs
# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
def get_msft_sample_inputs(
config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool
config: LlamaConfig,
batch_size: int,
past_seq_len: int,
seq_len: int,
max_seq_len: int,
use_fp16: bool,
split_kv: bool,
):
np_dtype = np.float16 if use_fp16 else np.float32
head_size = config.hidden_size // config.num_attention_heads
max_seq_len = 2048
if not split_kv:
ort_inputs = {
@ -201,4 +211,111 @@ def get_msft_sample_inputs(
}
)
if use_fp16: # If model has GQA
del ort_inputs["attn_mask"]
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
return ort_inputs
# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool):
num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
past_kv = [
(
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
)
for _ in range(config.num_hidden_layers)
]
return past_kv
# Convert list of past_key_values to dict of past_key and past_value
def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]):
past_kv = {}
for i, (past_k, past_v) in enumerate(past_key_values):
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
return past_kv
# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
pt_inputs: dict,
use_fp16: bool,
use_buffer_share: bool = False,
past_seq_len: int = 0,
max_seq_len: int = 2048,
device: str = "",
device_id: int = -1,
):
ort_inputs = {}
for k, v in pt_inputs.items():
if isinstance(v, np.ndarray):
ort_inputs[k] = v
elif k == "past_key_values":
ort_inputs.update(flatten_past_kv_inputs(v))
elif k == "attention_mask" and use_fp16 and use_buffer_share:
# Skip because FP16 model has GroupQueryAttention, uses buffer sharing,
# and GQA supports a causal mask by default
# Instead, add the past sequence length input for GQA
ort_inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64)
else:
ort_inputs[k] = v.detach().cpu().numpy()
# Reshape kv caches if using past-present-share-buffer
if use_buffer_share and device != "" and device != "cpu" and device_id > -1:
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
return ort_inputs
def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
for k, v in ort_inputs.items():
# Allocate new buffers with max_sequence_length for GQA
if "cache" in k or "past_key_values" in k:
# Copy v (BxSxPxH) into new_v (BxSxMxH)
batch_size, num_heads, _, head_size = v.shape
new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
ort_inputs[k] = new_v
return ort_inputs
# Add IO bindings for execution providers
def add_io_bindings(model: InferenceSession, ort_inputs: dict, device: str, device_id: int, kv_cache_ortvalues: dict):
use_fp16 = False
io_binding = model.io_binding()
for k, v in ort_inputs.items():
# Detect if model is in FP16
if v.dtype == np.float16:
use_fp16 = True
# Bind OrtValue inputs to device
if use_fp16 and ("cache" in k or "past_key_values" in k):
if k not in kv_cache_ortvalues:
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
io_binding.bind_ortvalue_input(k, v_device)
kv_cache_ortvalues[k] = v_device
else:
kv_cache_ortvalues[k].update_inplace(v)
io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k])
else:
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
io_binding.bind_ortvalue_input(k, v_device)
for output in model.get_outputs():
name = output.name
if use_fp16 and ("out" in name or "present" in name):
# Bind present KV cache outputs to past KV cache inputs in order to buffer share
input_name = name.replace("out", "cache").replace("present", "past_key_values")
io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
else:
io_binding.bind_output(name, device_type=device, device_id=device_id)
return io_binding, kv_cache_ortvalues

View file

@ -8,6 +8,7 @@ import numpy as np
import torch
from benchmark_helper import setup_logger
from llama_inputs import (
add_io_bindings,
convert_inputs_for_ort,
get_merged_sample_with_past_kv_inputs,
get_sample_inputs,
@ -22,22 +23,24 @@ logger = logging.getLogger("")
def get_sequence_lengths(args: argparse.Namespace):
past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
max_sequence_length = 2048
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
return past_sequence_length, curr_sequence_length, max_sequence_length
def get_inputs(args: argparse.Namespace, config: LlamaConfig):
# Dummy values for parity
batch_size = 2
past_sequence_length, sequence_length, _ = get_sequence_lengths(args)
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args)
if args.merged:
inputs = get_merged_sample_with_past_kv_inputs(
config,
args.device,
batch_size,
sequence_length,
past_sequence_length,
seq_len=sequence_length,
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
@ -51,31 +54,7 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig):
return inputs
def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict):
# Add IO bindings for non-CPU execution providers
io_binding = model.io_binding()
for k, v in inputs.items():
if args.use_fp16:
# Bind all OrtValue inputs to device
io_binding.bind_ortvalue_input(k, v)
else:
io_binding.bind_cpu_input(k, v)
for output in model.get_outputs():
name = output.name
if args.use_fp16 and ("out" in name or "present" in name):
# Bind present KV cache outputs to OrtValue with buffer sharing
io_binding.bind_ortvalue_output(
name, inputs[name.replace("out", "cache").replace("present", "past_key_values")]
)
else:
io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id))
return io_binding
def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM):
def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM, kv_cache_ortvalues: dict):
inputs = get_inputs(args, config)
# Run inference with PyTorch
@ -111,7 +90,9 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
# Add IO bindings for non-CPU execution providers
if args.execution_provider != "cpu":
io_binding = add_io_bindings(args, ort_model, inputs)
io_binding, kv_cache_ortvalues = add_io_bindings(
ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues
)
io_binding.synchronize_inputs()
start_time = time.time()
@ -131,17 +112,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
logger.info(f"ONNX Runtime took {end_time - start_time} s")
# Compare PyTorch and ONNX Runtime accuracy
tol = (
2e1
if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path
else 1e-3
if args.precision == "fp32"
else 5e-1
)
tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
if not parity:
logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
return kv_cache_ortvalues
def get_args(argv: List[str]):
@ -250,16 +226,17 @@ def main(argv: List[str] = []): # noqa: B006
use_cache=True,
).to(args.device)
kv_cache_ortvalues = {}
if not args.merged:
verify_parity(args, config, llama)
verify_parity(args, config, llama, kv_cache_ortvalues)
else:
# Verify prompt generation in merged model (decoder_model.onnx)
args.use_past_kv = False
verify_parity(args, config, llama)
kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues)
# Verify token generation in merged model (decoder_with_past_model.onnx)
args.use_past_kv = True
verify_parity(args, config, llama)
verify_parity(args, config, llama, kv_cache_ortvalues)
if __name__ == "__main__":