Extend gpt-fast LLM dashboard to support torchao autoquant (#140627)
Summary: We want to test autoquant on relevant LLM models. Right now this covers only llama2 and mixtral, but we want to extend it to more models such as https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models

Test Plan (average tokens/sec):
```
                     Llama-2-7b-chat-hf    Mixtral-8x7B-v0.1
gpt-fast int8        112.98                147.92
torchao autoquant    87.41                 85.90
torchao autoquantv2  131.12                79.59
```

Dashboard: https://hud.pytorch.org/benchmark/llms?repoName=pytorch%2Fpytorch

In pytorch/benchmarks/gpt_fast:
```
python benchmark.py
```

Output:
```
Loading model Llama-2-7b-chat-hf
Using int8 weight-only quantization!
Time to load model: 2.80 seconds
Compilation time: 170.24 seconds
Average tokens/sec: 112.98 tokens/sec
Average bandwidth achieved: 746.86 GB/s
Memory used: 7.95 GB

Loading model Mixtral-8x7B-v0.1
Using int8 weight-only quantization!
Time to load model: 0.24 seconds
Compilation time: 181.81 seconds
Average tokens/sec: 147.92 tokens/sec
Average bandwidth achieved: 953.06 GB/s
Memory used: 32.45 GB

Loading model Llama-2-7b-chat-hf
Time to load model: 0.11 seconds
Using autoquant
Compilation time: 109.31 seconds
Average tokens/sec: 87.17 tokens/sec
Average bandwidth achieved: 1151.86 GB/s
Memory used: 32.45 GB

Loading model Llama-2-7b-chat-hf
Time to load model: 0.11 seconds
Compilation time: 48.08 seconds
Average tokens/sec: 87.41 tokens/sec
Average bandwidth achieved: 1155.05 GB/s
Memory used: 36.86 GB

Loading model Mixtral-8x7B-v0.1
Time to load model: 0.20 seconds
Using autoquant
Compilation time: 47.32 seconds
Average tokens/sec: 85.90 tokens/sec
Average bandwidth achieved: 1106.37 GB/s
Memory used: 66.81 GB

local test (autoquant v2):
Loading model Mixtral-8x7B-v0.1
Compilation time: 124.40 seconds
Average tokens/sec: 90.41 tokens/sec
Average bandwidth achieved: 1164.47 GB/s
Memory used: 53.91 GB

Loading model Llama-2-7b-chat-hf
TODO
```

gpt_fast_benchmark.csv:
```
name,metric,target,actual,dtype,device,arch,is_model
Llama-2-7b-chat-hf,token_per_sec,144,112.98,int8,cuda,NVIDIA PG509-210,True
Llama-2-7b-chat-hf,memory_bandwidth(GB/s),957,746.86,int8,cuda,NVIDIA PG509-210,True
Llama-2-7b-chat-hf,compilation_time(s),136,170.24,int8,cuda,NVIDIA PG509-210,True
Mixtral-8x7B-v0.1,token_per_sec,175,147.92,int8,cuda,NVIDIA PG509-210,True
Mixtral-8x7B-v0.1,memory_bandwidth(GB/s),1130,953.06,int8,cuda,NVIDIA PG509-210,True
Mixtral-8x7B-v0.1,compilation_time(s),133,181.81,int8,cuda,NVIDIA PG509-210,True
gemv,memory_bandwidth(GB/s),870,867.06,int8,cuda,NVIDIA PG509-210,False
gemv,memory_bandwidth(GB/s),990,1092.43,bfloat16,cuda,NVIDIA PG509-210,False
layer_norm,memory_bandwidth(GB/s),950,573.57,bfloat16,cuda,NVIDIA PG509-210,False
Llama-2-7b-chat-hf,token_per_sec,144,87.17,autoquant,cuda,NVIDIA PG509-210,True
Llama-2-7b-chat-hf,memory_bandwidth(GB/s),957,1151.86,autoquant,cuda,NVIDIA PG509-210,True
Llama-2-7b-chat-hf,compilation_time(s),136,109.31,autoquant,cuda,NVIDIA PG509-210,True
gather_gemv,memory_bandwidth(GB/s),990,945.38,int8,cuda,NVIDIA PG509-210,False
gather_gemv,memory_bandwidth(GB/s),1060,1188.29,bfloat16,cuda,NVIDIA PG509-210,False
mlp_layer_norm_gelu,flops_utilization,0.8,0.82,bfloat16,cuda,NVIDIA PG509-210,False
Llama-2-7b-chat-hf,token_per_sec,94,87.41,bfloat16,cuda,NVIDIA PG509-210,True
Llama-2-7b-chat-hf,memory_bandwidth(GB/s),1253,1155.05,bfloat16,cuda,NVIDIA PG509-210,True
Llama-2-7b-chat-hf,compilation_time(s),133,48.08,bfloat16,cuda,NVIDIA PG509-210,True
Mixtral-8x7B-v0.1,token_per_sec,175,85.90,autoquant,cuda,NVIDIA PG509-210,True
Mixtral-8x7B-v0.1,memory_bandwidth(GB/s),1130,1106.37,autoquant,cuda,NVIDIA PG509-210,True
Mixtral-8x7B-v0.1,compilation_time(s),133,47.32,autoquant,cuda,NVIDIA PG509-210,True
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140627
Approved by: https://github.com/huydhn
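For reference, the autoquant path this PR benchmarks uses torchao's manual autoquant API; the same calls appear in the run_experiment hunk below. A minimal sketch, with a toy module standing in for the gpt-fast Llama/Mixtral models and generation loop:

```python
import torch
import torchao

# Toy stand-in for the gpt-fast model; the real benchmark quantizes Llama/Mixtral.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
example = torch.randn(8, 64, device="cuda")

# Same call pattern as run_experiment below: wrap the model, run representative
# work so autoquant can profile candidate kernels, then finalize the per-layer choices.
model = torchao.autoquant(model, manual=True, error_on_unseen=False)
model(example)
model.finalize_autoquant()
```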
parent 30ab10247d
commit a962ae511d

5 changed files with 267 additions and 2 deletions
```
@@ -241,6 +241,12 @@ function checkout_install_torchbench() {
  popd
}

function install_torchao() {
  local commit
  commit=$(get_pinned_commit torchao)
  pip_install --no-use-pep517 --user "git+https://github.com/pytorch/ao.git@${commit}"
}

function print_sccache_stats() {
  echo 'PyTorch Build Statistics'
  sccache --show-stats
```
```
@@ -615,6 +615,11 @@ test_single_dynamo_benchmark() {
}

test_inductor_micro_benchmark() {
  # torchao requires cuda 8.0 or above for bfloat16 support
  if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;8.6"
  fi
  install_torchao
  TEST_REPORTS_DIR=$(pwd)/test/test-reports
  if [[ "${TEST_CONFIG}" == *cpu* ]]; then
    test_inductor_set_cpu_affinity
```
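The CI comment above notes that torchao needs CUDA compute capability 8.0 or above for bfloat16. A minimal, hypothetical check along those lines (not part of this PR; shown only to illustrate the requirement):

```python
import torch

def supports_bf16_autoquant() -> bool:
    # bfloat16 autoquant benchmarks need an Ampere-or-newer GPU (sm80+),
    # matching the TORCH_CUDA_ARCH_LIST="8.0;8.6" set above.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

print("bf16 autoquant supported:", supports_bf16_autoquant())
```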
.github/ci_commit_pins/torchao.txt (vendored, new file, 1 line)
```
@@ -0,0 +1 @@
51c87b6ead6b7e098ada95d6a7609ee873b854cf
```
```
@@ -265,9 +265,16 @@ DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"

all_experiments = {
    # A list of GPT models: LlaMa, Mixtral, etc.
    # waiting for A100-80G machine to be available in CI
    # https://github.com/pytorch/pytorch/actions/runs/12018005803/job/33503683582?pr=140627
    # before we can turn on autoquant
    # or alternatively, we can save the model after autoquant and just load it here to track
    # the performance
    # run_llama2_7b_autoquant,
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
    # run_mixtral_8x7b_autoquant,
    # A list of micro-benchmarks.
    run_mlp_layer_norm_gelu,
    run_layer_norm,
```
```
@@ -286,6 +293,7 @@ def main(output_file=DEFAULT_OUTPUT_FILE):
            # This happens when torch is compiled with CUDA turned off completely
            device = "cpu"

        torch.compiler.cudagraph_mark_step_begin()
        lst = func(device)
        for x in lst:
            results.append(dataclasses.astuple(x))
```
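For context on the hunk above: each experiment function returns Experiment rows that main() flattens with dataclasses.astuple and writes to gpt_fast_benchmark.csv. A simplified sketch of that flow; the field names are taken from the CSV in the commit message, and write_results is an illustrative helper name, not the benchmark's actual code:

```python
import csv
import dataclasses

@dataclasses.dataclass
class Experiment:
    name: str
    metric: str
    target: float
    actual: str
    dtype: str
    device: str
    arch: str
    is_model: bool

def write_results(rows, output_file="gpt_fast_benchmark.csv"):
    # One CSV line per Experiment, mirroring the header shown in the commit message.
    with open(output_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([f.name for f in dataclasses.fields(Experiment)])
        for row in rows:
            writer.writerow(dataclasses.astuple(row))

write_results([
    Experiment("Llama-2-7b-chat-hf", "token_per_sec", 144, "87.17",
               "autoquant", "cuda", "NVIDIA PG509-210", True),
])
```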
```
@@ -4,6 +4,7 @@ import platform
import time
from typing import Optional, Tuple

import torchao
from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
from mixtral_moe_quantize import (
    ConditionalFeedForwardInt8,
```
```
@@ -21,6 +22,8 @@ torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
torch._inductor.config.assert_indirect_indexing = False

compiled = False


@dataclasses.dataclass
class GPTModelConfig:
```
```
@@ -31,6 +34,7 @@ class GPTModelConfig:
    token_per_sec: float
    memory_bandwidth: float
    compilation_time: float
    batch_size: Optional[int] = None


def device_sync(device):
```
```
@@ -74,7 +78,6 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    return idx_next, probs


@torch.compile(fullgraph=True)
def prefill(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
```
```
@@ -83,7 +86,6 @@ def prefill(
    return sample(logits, **sampling_kwargs)[0]


@torch.compile(fullgraph=True, mode="reduce-overhead")
def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
```
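The two hunks above drop the @torch.compile decorators from prefill and decode_one_token. As the run_experiment hunk below shows, they are instead compiled explicitly and exactly once behind a module-level flag, presumably so autoquant can exercise the model in eager mode first and so repeated experiments in one process do not recompile. A condensed sketch of that pattern; compile_once is an illustrative name, in the diff this logic is inlined in run_experiment:

```python
import torch

compiled = False

def prefill(model, x, input_pos, **sampling_kwargs):
    ...  # eager prefill implementation, as in gpt-fast

def decode_one_token(model, x, input_pos, **sampling_kwargs):
    ...  # eager single-token decode implementation, as in gpt-fast

def compile_once():
    # Rebind the module-level functions to compiled versions exactly once.
    global prefill, decode_one_token, compiled
    if not compiled:
        compiled = True
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )
        prefill = torch.compile(prefill, fullgraph=True)
```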
```
@@ -223,9 +225,48 @@ def run_experiment(
    start = -1
    compilation_time = None

    if x.mode == "autoquant":
        print("Using autoquant")
        model = torchao.autoquant(model, manual=True, error_on_unseen=False)
        generate(model, prompt, max_new_tokens, temperature=temperature, top_k=top_k)
        model.finalize_autoquant()

    if x.mode == "autoquant_v2":
        print("Using autoquant_v2")
        from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

        p = prompt.view(1, -1)
        T = prompt.size(0)
        T_new = T + max_new_tokens
        max_seq_length = min(T_new, model.config.block_size)
        input_pos = torch.arange(0, T, device=device)
        example_input = (p, input_pos)

        with torch.device(device):
            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
        model = autoquant_v2(
            model,
            manual=True,
            error_on_unseen=False,
            example_input=example_input,
            batch_size=x.batch_size,
        )
        torch.compiler.cudagraph_mark_step_begin()
        generate(model, prompt, max_new_tokens, temperature=temperature, top_k=top_k)
        model.finalize_autoquant()

    global decode_one_token, prefill, compiled
    if not compiled:
        compiled = True
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )
        prefill = torch.compile(prefill, fullgraph=True)

    for i in range(start, num_samples):
        device_sync(device=device)  # MKG

        torch.compiler.cudagraph_mark_step_begin()
        t0 = time.perf_counter()
        y = generate(
            model, prompt, max_new_tokens, temperature=temperature, top_k=top_k
```
```
@@ -402,3 +443,207 @@ def run_mixtral_8x7b_int8(device: str = "cuda"):
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_autoquant(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "autoquant",
        None,
        144,
        957,
        136,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_autoquant(device: str = "cuda"):
    from benchmark import Experiment

    # We reduced the original number of layers from 32 to 16 to fit within the CI memory limitation.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "autoquant",
        None,
        175,
        1130,
        133,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_autoquant_v2(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "autoquant_v2",
        None,
        144,
        957,
        136,
        6,  # batch_size
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_autoquant_v2(device: str = "cuda"):
    from benchmark import Experiment

    # We reduced the original number of layers from 32 to 16 to fit within the CI memory limitation.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "autoquant_v2",
        None,
        175,
        1130,
        133,
        6,  # batch_size
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]
```
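The four experiment functions above are near-identical. A hypothetical local invocation of one of them would look like the following; the defining module's file name is not shown in this extract, so the import path is an assumption:

```python
# Hypothetical usage: run the new autoquant experiment directly on a CUDA machine
# with the gpt-fast checkpoints set up, instead of going through benchmark.py's
# all_experiments registry (where it is still commented out in this PR).
from generate import run_llama2_7b_autoquant  # module name assumed, not shown in this diff

for experiment in run_llama2_7b_autoquant(device="cuda"):
    print(experiment)
```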