mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Support large model export using multi-gpu (#17990)
### Description This PR is to implemente a exporter which works for large language models(LLM). It works for models like Llama2-70b or gpt-175. The main idea is to utilize multiple-GPU and dispatch differnet layers to different GPU, in short, it symply implemented auto pipeline parallelism. For example : to export Llama2-70b, you need 8x V100-32GB or 4x A100-80G or More GPU memories. It would expect to export decoder-only models. For encoder-decoder arch-like models, we didn't test it yet. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This commit is contained in:
parent
444a0eda30
commit
b7ae293be0
1 changed files with 385 additions and 0 deletions
385
onnxruntime/python/tools/transformers/large_model_exporter.py
Normal file
385
onnxruntime/python/tools/transformers/large_model_exporter.py
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Export LLM to onnx
|
||||
"""
|
||||
import argparse
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
import transformers
|
||||
from torch import nn
|
||||
|
||||
|
||||
def disable_huggingface_init():
|
||||
"""do not init model twice as it slow initialization"""
|
||||
|
||||
torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x
|
||||
torch.nn.init.uniform_ = lambda x, *args, **kwargs: x
|
||||
torch.nn.init.normal_ = lambda x, *args, **kwargs: x
|
||||
torch.nn.init.constant_ = lambda x, *args, **kwargs: x
|
||||
torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x
|
||||
torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x
|
||||
torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x
|
||||
torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x
|
||||
|
||||
|
||||
def get_model_parameter_size(model: nn.Module):
|
||||
"""to calculate how much memory this model needs"""
|
||||
param_size = 0
|
||||
param_sum = 0
|
||||
for param in model.parameters():
|
||||
param_size += param.nelement() * param.element_size()
|
||||
param_sum += param.nelement()
|
||||
buffer_size = 0
|
||||
buffer_sum = 0
|
||||
for buffer in model.buffers():
|
||||
buffer_size += buffer.nelement() * buffer.element_size()
|
||||
buffer_sum += buffer.nelement()
|
||||
all_size = (param_size + buffer_size) / 1024 / 1024
|
||||
return all_size
|
||||
|
||||
|
||||
def initialize_model_and_sample_inputs(hf_model: str, cache_dir: Optional[str], tokenizer=None):
|
||||
"""
|
||||
get the pretrained torch model from hugginface,
|
||||
and sample model-inputs
|
||||
"""
|
||||
|
||||
disable_huggingface_init()
|
||||
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained( # type: ignore
|
||||
hf_model, torch_dtype=torch.float16, cache_dir=cache_dir, trust_remote_code=True
|
||||
)
|
||||
if tokenizer is None:
|
||||
tokenizer = hf_model
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) # type: ignore
|
||||
|
||||
sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values())
|
||||
return model, sample_inputs
|
||||
|
||||
|
||||
def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple):
|
||||
"""Make the model executable across multiple GPUs."""
|
||||
|
||||
def input_gpu_device_hook(mod, inputs, kwargs):
|
||||
modifyed_inputs = []
|
||||
first_dev = None
|
||||
for layer_input in inputs:
|
||||
if type(layer_input) is not torch.Tensor:
|
||||
modifyed_inputs.append(layer_input)
|
||||
elif hasattr(mod, "weight"):
|
||||
modifyed_inputs.append(layer_input.to(mod.weight.device))
|
||||
elif hasattr(mod, "parameters"):
|
||||
device = next(mod.parameters(), layer_input).device
|
||||
modifyed_inputs.append(layer_input.to(device))
|
||||
elif hasattr(next(mod.children(), None), "weight"):
|
||||
modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device))
|
||||
elif first_dev is not None and layer_input.device != first_dev:
|
||||
modifyed_inputs.append(layer_input.to(first_dev))
|
||||
else:
|
||||
modifyed_inputs.append(layer_input)
|
||||
if first_dev is None:
|
||||
first_dev = modifyed_inputs[0].device
|
||||
for key, value in kwargs.items():
|
||||
if type(value) is torch.Tensor:
|
||||
kwargs[key] = value.to(first_dev)
|
||||
|
||||
return (tuple(modifyed_inputs), kwargs)
|
||||
|
||||
def move_layer_to_device_rurc(mod, dev):
|
||||
mod.to(dev)
|
||||
for layer in mod.named_children():
|
||||
move_layer_to_device_rurc(layer[1], dev)
|
||||
|
||||
model = model.half()
|
||||
all_hooks = []
|
||||
all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
|
||||
pre_fix = next(iter(model.named_children()))[0]
|
||||
for top_name, top_module in model.named_children():
|
||||
for name, module in top_module.named_children():
|
||||
all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
|
||||
if type(module) in [torch.nn.ModuleList]:
|
||||
num_layers_on_each_gpu = math.floor(len(module) / len(gpulist))
|
||||
for idx, attn_layer in enumerate(module):
|
||||
all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
|
||||
|
||||
to_dev = gpulist[min(idx // num_layers_on_each_gpu, len(gpulist))]
|
||||
attn_layer.to(to_dev)
|
||||
move_layer_to_device_rurc(attn_layer, to_dev)
|
||||
print(f"move {pre_fix}.{name}.{idx} to {to_dev}")
|
||||
else:
|
||||
module.to(gpulist[0])
|
||||
print(f"move {pre_fix}.{name} to {gpulist[0]}")
|
||||
if len(list(top_module.named_children())) == 0:
|
||||
top_module.to(gpulist[0])
|
||||
print(f"move {top_name} to {gpulist[0]}")
|
||||
|
||||
with torch.no_grad():
|
||||
model(sample_inputs[0], attention_mask=sample_inputs[1])
|
||||
return model
|
||||
|
||||
|
||||
def retrieve_onnx_inputs(model: nn.Module, sample_inputs: tuple, with_past: bool):
|
||||
"""
|
||||
auto retrieve onnx inputs from torch model as we can't enumlate all possibilities
|
||||
for all models
|
||||
"""
|
||||
user_inputs = []
|
||||
|
||||
def hook_for_inputs(_, inputs, kwargs):
|
||||
user_inputs.append((inputs, kwargs))
|
||||
return user_inputs[0]
|
||||
|
||||
hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True)
|
||||
|
||||
forward_params = inspect.signature(model.forward).parameters
|
||||
input_keys = list(forward_params.keys())
|
||||
default_values = [forward_params.get(key).default for key in input_keys]
|
||||
out = model(sample_inputs[0], attention_mask=sample_inputs[1])
|
||||
hook_handle.remove()
|
||||
user_inputs = user_inputs[0]
|
||||
onnx_inputs = default_values
|
||||
for idx, _val in enumerate(user_inputs[0]):
|
||||
onnx_inputs[idx] = user_inputs[0][idx]
|
||||
for key, value in user_inputs[1].items():
|
||||
idx = input_keys.index(key)
|
||||
onnx_inputs[idx] = value
|
||||
for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)):
|
||||
if type(value) is torch.Tensor:
|
||||
value.to(model.device)
|
||||
# Didn't touch past_key_value now, please change it if you want
|
||||
if "use_cache" in key:
|
||||
onnx_inputs[idx] = with_past
|
||||
|
||||
return input_keys, onnx_inputs, out.past_key_values
|
||||
|
||||
|
||||
def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
|
||||
"""
|
||||
According to the model size, we will upload it to
|
||||
CPU if has no GPU or enough GPU memory,
|
||||
Single GPU if has only one GPU in local or model size is enough to fit one GPU
|
||||
Multiple GPU if there is more than one gpu in local and model is too large
|
||||
"""
|
||||
total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024
|
||||
|
||||
print(f"Model_Size = {get_model_parameter_size(model)/1024} GB")
|
||||
print(f"total_mem_per_cpu = {total_mem_per_cpu/1024} GB")
|
||||
if get_model_parameter_size(model) > total_mem_per_cpu * 0.45:
|
||||
device_collection = [torch.device(i) for i in range(torch.cuda.device_count())]
|
||||
if len(device_collection) > 1:
|
||||
print(
|
||||
f"{len(device_collection)} GPUs are used to export onnx, \
|
||||
Please set CUDA_VISIBLE_DEVICES to use specific GPU group"
|
||||
)
|
||||
model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp)
|
||||
else:
|
||||
print("!!!! convert model to float and export onnx using CPU")
|
||||
model = model.cpu().float()
|
||||
else:
|
||||
print("Export model on a single GPU")
|
||||
model = model.cuda().half()
|
||||
return model
|
||||
|
||||
|
||||
def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple:
|
||||
"""move inputs to device"""
|
||||
sample_inputs_ = []
|
||||
for sample_int in sample_inputs:
|
||||
if isinstance(sample_int, torch.Tensor):
|
||||
sample_inputs_.append(sample_int.to(device))
|
||||
else:
|
||||
sample_inputs_.append(sample_int)
|
||||
return tuple(sample_inputs_)
|
||||
|
||||
|
||||
def fetch_onnx_inputs_outputs_name(
|
||||
model: nn.Module,
|
||||
onnx_inputs: list,
|
||||
torch_input_names: tuple,
|
||||
past_key_values: tuple,
|
||||
with_past: bool,
|
||||
input_with_past: bool,
|
||||
):
|
||||
"""fetch onnx inputs and outputs name"""
|
||||
num_of_past_key = 0
|
||||
kv_cache_axis = {0: "batch_size"}
|
||||
# try get num_of_past_key and shape of past_key_value
|
||||
if past_key_values is not None:
|
||||
num_of_past_key = len(past_key_values)
|
||||
seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1)
|
||||
assert seq_index.numel() == 1
|
||||
kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"}
|
||||
|
||||
if not num_of_past_key:
|
||||
num_of_past_key = model.config.num_hidden_layers
|
||||
|
||||
onnx_inp_names = ("input_ids", "attention_mask")
|
||||
onnx_out_names = ("logits",)
|
||||
onnx_dynamic_axes = {
|
||||
"input_ids": {0: "batch_size", 1: "seq_len"},
|
||||
"attention_mask": {0: "batch_size", 1: "seq_len"},
|
||||
}
|
||||
if input_with_past:
|
||||
for i in range(num_of_past_key):
|
||||
onnx_inp_names += (f"present_key.{i}",)
|
||||
onnx_inp_names += (f"present_values.{i}",)
|
||||
|
||||
onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis
|
||||
onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis
|
||||
|
||||
if with_past or input_with_past:
|
||||
for i in range(num_of_past_key):
|
||||
onnx_out_names += (f"past_key.{i}",)
|
||||
onnx_out_names += (f"past_values.{i}",)
|
||||
onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis
|
||||
onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis
|
||||
|
||||
for idx, name in enumerate(torch_input_names):
|
||||
if input_with_past:
|
||||
if name == "past_key_values":
|
||||
onnx_inputs[idx] = past_key_values
|
||||
elif name == "attention_mask":
|
||||
attn_mask = onnx_inputs[idx]
|
||||
onnx_inputs[idx] = torch.cat(
|
||||
(attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device)), dim=1
|
||||
)
|
||||
elif name == "input_ids":
|
||||
input_ids = onnx_inputs[idx]
|
||||
onnx_inputs[idx] = input_ids[:, -1:]
|
||||
|
||||
return onnx_inp_names, onnx_out_names, onnx_dynamic_axes
|
||||
|
||||
|
||||
def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int):
|
||||
"""do export with torch.onnx.export"""
|
||||
onnx_model_name = onnx_path.name
|
||||
onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple
|
||||
# two step to export onnx
|
||||
# 1. export onnx with lots of pieces of weights
|
||||
# 2. save all weights to external data
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmp_onnx = os.path.join(tmpdirname, "tmp.onnx")
|
||||
|
||||
torch.onnx.export(
|
||||
model=model,
|
||||
args=tuple(onnx_inputs),
|
||||
f=tmp_onnx,
|
||||
verbose=False,
|
||||
opset_version=opset,
|
||||
input_names=onnx_inp_names,
|
||||
output_names=onnx_out_names,
|
||||
dynamic_axes=onnx_dynamic_axes,
|
||||
)
|
||||
|
||||
onnx_path.unlink(missing_ok=True)
|
||||
(onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True)
|
||||
|
||||
onnx_model = onnx.load(str(tmp_onnx))
|
||||
onnx.save_model(
|
||||
onnx_model,
|
||||
str(onnx_path),
|
||||
save_as_external_data=(len(os.listdir(tmpdirname)) > 1),
|
||||
all_tensors_to_one_file=True,
|
||||
location=f"{onnx_model_name}_ext.data",
|
||||
size_threshold=1024,
|
||||
convert_attribute=False,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int):
|
||||
"""
|
||||
do export
|
||||
model: torch model
|
||||
onnx_path: where the onnx model saved to
|
||||
sample_inputs_tp: inputs for torch model
|
||||
"""
|
||||
model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)
|
||||
|
||||
model = move_to_approprate_device(model, sample_inputs_tp)
|
||||
|
||||
sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)
|
||||
|
||||
# input_keys would be usesful if the model has some special inputs
|
||||
input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past)
|
||||
|
||||
onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False)
|
||||
|
||||
onnx_model_name = "model.onnx"
|
||||
onnx_path: Path = Path(onnx_path_str).absolute()
|
||||
if onnx_path.suffix != ".onnx":
|
||||
onnx_path = onnx_path / onnx_model_name
|
||||
|
||||
do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
|
||||
if not with_past:
|
||||
return
|
||||
|
||||
onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True)
|
||||
|
||||
onnx_model_name = "model_with_past.onnx"
|
||||
onnx_path = onnx_path.parent / onnx_model_name
|
||||
|
||||
do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
"""arguments parsing."""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
required=True,
|
||||
type=str,
|
||||
default=["meta-llama/Llama-2-70b-hf"],
|
||||
help="Pre-trained models in huggingface model hub",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--saved_path",
|
||||
required=False,
|
||||
type=str,
|
||||
default="./onnx_models/",
|
||||
help="where the onnx model will be saved",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help=("cache directy of huggingface, by setting this to avoid useless downloading if you have one"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_past",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=("The tool will export onnx without past-key-value by default"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--opset",
|
||||
required=False,
|
||||
type=int,
|
||||
default=17,
|
||||
help=(
|
||||
"the opset to save onnx model, \
|
||||
try to increase it if this opset doens't have new features you want"
|
||||
),
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
|
||||
export_onnx(args.model, args.cache_dir, args.saved_path, args.with_past, args.opset)
|
||||
Loading…
Reference in a new issue