onnxruntime/onnxruntime/test/python/transformers/test_gqa_cpu.py
aciddelgado 94c69f55d4
GQA 4 CPU (#20299)
### Description
Support GQA operator on CPU with FP32.



### Motivation and Context
Right now, models generated for CPU and GPU must be different. GQA CPU
allows these models to be the same.
2024-04-22 19:57:05 -07:00

1884 lines
65 KiB
Python

# --------------------------------------------------------------------------
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import math
import random
import unittest
import numpy
import torch
from bert_padding import pad_input, unpad_input
try:
from colorama import Fore, init
init(autoreset=True)
except ImportError:
print("colorama is not installed, please install it to get prettier output")
Fore = None
from einops import rearrange, repeat
from onnx import TensorProto, helper
from onnxruntime import InferenceSession, OrtValue, SessionOptions
# Seed torch's global RNG once at import time so the random tensors built by the
# helpers in this file are reproducible from run to run.
torch.manual_seed(0)
# Module-level switch for test volume.
pipeline_mode = True  # Reduces number of tests so pipeline doesn't time out
class Formats:
    """Integer tags naming the axis order of the past/present KV cache tensors."""

    # The graph builders in this file lay the cache out as
    # (batch, sequence, kv_heads, head_size) for BSNH and
    # (batch, kv_heads, sequence, head_size) for BNSH.
    BSNH, BNSH = 0, 1
class Config:
    """Problem sizes consumed by the past-KV (token generation) graph builder and runner."""

    # Class-level defaults; __init__ overrides every one of them per instance.
    batch_size = sequence_length = kv_sequence_length = past_sequence_length = 0
    num_heads = kv_num_heads = head_size = 0

    def __init__(self, b, s, s2, sp, n, n2, h):
        """Store the sizes.

        Args:
            b: batch size.
            s: sequence length of the query.
            s2: sequence length of the key/value cache.
            sp: past sequence length.
            n: number of query heads.
            n2: number of key/value heads.
            h: size of each head.
        """
        (
            self.batch_size,
            self.sequence_length,
            self.kv_sequence_length,
            self.past_sequence_length,
            self.num_heads,
            self.kv_num_heads,
            self.head_size,
        ) = (b, s, s2, sp, n, n2, h)
class PromptConfig:
    """Problem sizes consumed by the prompt-phase graph builder and runner."""

    # Class-level defaults; __init__ overrides every one of them per instance.
    batch_size = q_sequence_length = kv_sequence_length = buffer_sequence_length = 0
    num_heads = kv_num_heads = head_size = 0

    def __init__(self, b, sq, skv, sb, n, n2, h):
        """Store the sizes.

        Args:
            b: batch size.
            sq: sequence length of the query.
            skv: sequence length of the new key/value.
            sb: sequence length of the pre-allocated KV buffer.
            n: number of query heads.
            n2: number of key/value heads.
            h: size of each head.
        """
        (
            self.batch_size,
            self.q_sequence_length,
            self.kv_sequence_length,
            self.buffer_sequence_length,
            self.num_heads,
            self.kv_num_heads,
            self.head_size,
        ) = (b, sq, skv, sb, n, n2, h)
# LLaMA Microsoft model
class LlamaMSRotaryEmbedding(torch.nn.Module):
    """Reference rotary position embedding used to build expected tensors for parity checks."""

    def __init__(self):
        super().__init__()

    def rotate_tensor(
        self,
        x: torch.Tensor,  # BxSxNxH
        cos: torch.Tensor,  # 1xSx1x(H/2)
        sin: torch.Tensor,  # 1xSx1x(H/2)
        pos: torch.Tensor,
        interleaved: bool,
    ):
        """Rotate the leading ``2 * cos.shape[3]`` channels of ``x``; pass the rest through.

        Args:
            x: tensor indexed as [batch_size, seq_len, n_heads, head_dim].
            cos: cosine table, 4-D, indexed along dim 1 by position.
            sin: sine table with the same layout as ``cos``.
            pos: per-batch position; only read when ``x.shape[1] == 1``.
            interleaved: pair channels as (even, odd) when True, otherwise as
                (first half, second half).

        Returns:
            The rotated channels concatenated with the untouched trailing channels.

        Note:
            With ``interleaved=True`` the rotated values are written through a
            view of ``x``, so the caller's tensor is modified in place.
        """
        rot_dim = 2 * cos.shape[3]
        # Dolly requires partial rotation
        rotated = x[:, :, :, :rot_dim]
        if interleaved:
            first = rotated[:, :, :, 0::2]
            second = rotated[:, :, :, 1::2]
        else:
            half = rotated.shape[-1] // 2
            first = x[:, :, :, 0:half]
            second = x[:, :, :, half : 2 * half]

        seq_len = x.shape[1]
        if seq_len == 1:
            # Single-token step: look up each batch entry's row of the tables by `pos`.
            batch_size = x.shape[0]
            gather_idx = pos.unsqueeze(1).unsqueeze(2).unsqueeze(3).long()
            cos_x = cos.expand(batch_size, -1, -1, -1).gather(
                1, gather_idx.expand(-1, -1, cos.shape[2], cos.shape[3])
            )
            sin_x = sin.expand(batch_size, -1, -1, -1).gather(
                1, gather_idx.expand(-1, -1, sin.shape[2], sin.shape[3])
            )
        else:
            # Multi-token input: use the first seq_len rows of the tables.
            cos_x = cos[:, 0:seq_len, :, :]
            sin_x = sin[:, 0:seq_len, :, :]

        # Both operands are fully evaluated before anything is written back below.
        real = cos_x * first - sin_x * second
        imag = sin_x * first + cos_x * second
        if interleaved:
            rotated[:, :, :, 0::2] = real
            rotated[:, :, :, 1::2] = imag
        else:
            rotated = torch.cat((real, imag), dim=-1)
        return torch.cat((rotated, x[:, :, :, rot_dim:]), dim=-1)

    def forward(self, x, cos, sin, pos, interleaved):
        """Module entry point; forwards all arguments to ``rotate_tensor``."""
        return self.rotate_tensor(x, cos, sin, pos, interleaved)
def create_group_query_attention_graph_prompt(
    config,
    past_kv_format=Formats.BSNH,
    share_buffer=True,
    local_window_size=-1,
    rotary=False,
    rotary_interleaved=False,
    packed=False,
):
    """Build a serialized ONNX model holding one prompt-phase GroupQueryAttention node.

    Args:
        config: object exposing ``batch_size``, ``q_sequence_length``,
            ``kv_sequence_length``, ``buffer_sequence_length``, ``num_heads``,
            ``kv_num_heads`` and ``head_size`` (see ``PromptConfig``).
        past_kv_format: ``Formats.BSNH`` or ``Formats.BNSH``; selects the axis
            order of the past/present tensors.
        share_buffer: when True the graph takes ``past_key``/``past_value``
            inputs sized to ``buffer_sequence_length`` and the present outputs
            use that same length; otherwise there are no past inputs and the
            present outputs use ``kv_sequence_length``.
        local_window_size: forwarded to the node's ``local_window_size`` attribute.
        rotary: when True, ``cos_cache``/``sin_cache`` inputs are added and
            ``do_rotary`` is set on the node.
        rotary_interleaved: forwarded to the node's ``rotary_interleaved`` attribute.
        packed: when True the ``key``/``value`` inputs are omitted and the
            ``query`` input's last dimension is widened to hold Q, K and V.

    Returns:
        The model serialized to bytes.
    """
    past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0
    present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length

    def kv_cache_shape(seqlen):
        # BSNH puts the sequence axis before the head axis; any other format swaps them.
        if past_kv_format == Formats.BSNH:
            return [config.batch_size, seqlen, config.kv_num_heads, config.head_size]
        return [config.batch_size, config.kv_num_heads, seqlen, config.head_size]

    q_hidden = config.num_heads * config.head_size
    kv_hidden = config.kv_num_heads * config.head_size

    nodes = [
        helper.make_node(
            "GroupQueryAttention",
            [
                "query",
                "key" if not packed else "",
                "value" if not packed else "",
                "past_key" if share_buffer else "",
                "past_value" if share_buffer else "",
                "seqlens_k",
                "total_sequence_length",
                "cos_cache" if rotary else "",
                "sin_cache" if rotary else "",
            ],
            ["output", "present_key", "present_value"],
            "GroupQueryAttention_0",
            num_heads=config.num_heads,
            kv_num_heads=config.kv_num_heads,
            local_window_size=local_window_size,
            do_rotary=rotary,
            rotary_interleaved=rotary_interleaved,
            # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0,
            # kv_share_buffer=1 if share_buffer else 0,
            domain="com.microsoft",
        ),
    ]

    graph_input = [
        helper.make_tensor_value_info(
            "query",
            TensorProto.FLOAT,
            [
                config.batch_size,
                config.q_sequence_length,
                # Packed QKV carries K and V behind Q along the hidden axis.
                q_hidden if not packed else q_hidden + 2 * kv_hidden,
            ],
        ),
        helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [config.batch_size]),
        helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]),
    ]
    if not packed:
        graph_input += [
            helper.make_tensor_value_info(
                "key",
                TensorProto.FLOAT,
                [config.batch_size, config.kv_sequence_length, kv_hidden],
            ),
            helper.make_tensor_value_info(
                "value",
                TensorProto.FLOAT,
                [config.batch_size, config.kv_sequence_length, kv_hidden],
            ),
        ]
    if share_buffer:
        graph_input += [
            helper.make_tensor_value_info("past_key", TensorProto.FLOAT, kv_cache_shape(past_kv_seqlen)),
            helper.make_tensor_value_info("past_value", TensorProto.FLOAT, kv_cache_shape(past_kv_seqlen)),
        ]
    if rotary:
        # The rotary dimension is head_size rounded down to a multiple of 16;
        # each cache stores half of it per position.
        rotary_cache_shape = [present_kv_seqlen, (math.floor(config.head_size / 16) * 16) // 2]
        graph_input += [
            helper.make_tensor_value_info("cos_cache", TensorProto.FLOAT, rotary_cache_shape),
            helper.make_tensor_value_info("sin_cache", TensorProto.FLOAT, rotary_cache_shape),
        ]

    # Each output is declared exactly once. The previous version listed
    # present_key/present_value a second time with kv_sequence_length-based
    # shapes, which duplicated the output names and disagreed with
    # present_kv_seqlen whenever the shared buffer is longer than the new KV.
    graph_output = [
        helper.make_tensor_value_info(
            "output",
            TensorProto.FLOAT,
            [config.batch_size, config.q_sequence_length, q_hidden],
        ),
        helper.make_tensor_value_info("present_key", TensorProto.FLOAT, kv_cache_shape(present_kv_seqlen)),
        helper.make_tensor_value_info("present_value", TensorProto.FLOAT, kv_cache_shape(present_kv_seqlen)),
    ]

    graph = helper.make_graph(
        nodes,
        "GroupQueryAttention_Graph",
        graph_input,
        graph_output,
    )
    model = helper.make_model(graph)
    return model.SerializeToString()
def create_group_query_attention_graph_past(
    config,
    past_kv_format=Formats.BSNH,
    share_buffer=True,
    local_window_size=-1,
    rotary=False,
    rotary_interleaved=False,
    packed=False,
):
    """Build a serialized ONNX model holding one GroupQueryAttention node that reads a past KV cache.

    Args:
        config: object exposing ``batch_size``, ``sequence_length``,
            ``kv_sequence_length``, ``num_heads``, ``kv_num_heads`` and
            ``head_size`` (see ``Config``).
        past_kv_format: ``Formats.BSNH`` or ``Formats.BNSH``; selects the axis
            order of the past/present tensors.
        share_buffer: when True the present outputs have the same sequence
            length as the past inputs (``kv_sequence_length``); otherwise they
            are ``kv_sequence_length + sequence_length`` long.
        local_window_size: forwarded to the node's ``local_window_size`` attribute.
        rotary: when True, ``cos_cache``/``sin_cache`` inputs are added and
            ``do_rotary`` is set on the node.
        rotary_interleaved: forwarded to the node's ``rotary_interleaved`` attribute.
        packed: when True the ``key``/``value`` inputs are omitted and the
            ``query`` input's last dimension is widened to hold Q, K and V.

    Returns:
        The model serialized to bytes.
    """
    bsnh = past_kv_format == Formats.BSNH
    q_hidden = config.num_heads * config.head_size
    kv_hidden = config.kv_num_heads * config.head_size

    past_kv_seqlen = config.kv_sequence_length
    if share_buffer:
        present_kv_seqlen = config.kv_sequence_length
    else:
        present_kv_seqlen = config.kv_sequence_length + config.sequence_length

    def cache_dims(seqlen):
        # Sequence axis precedes the head axis for BSNH and follows it otherwise.
        middle = [seqlen, config.kv_num_heads] if bsnh else [config.kv_num_heads, seqlen]
        return [config.batch_size, *middle, config.head_size]

    def float_info(name, dims):
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)

    def int32_info(name, dims):
        return helper.make_tensor_value_info(name, TensorProto.INT32, dims)

    # Optional inputs are passed as empty names when absent.
    node_inputs = [
        "query",
        "" if packed else "key",
        "" if packed else "value",
        "past_key",
        "past_value",
        "seqlens_k",
        "total_sequence_length",
        "cos_cache" if rotary else "",
        "sin_cache" if rotary else "",
    ]
    nodes = [
        helper.make_node(
            "GroupQueryAttention",
            node_inputs,
            ["output", "present_key", "present_value"],
            "GroupQueryAttention_0",
            num_heads=config.num_heads,
            kv_num_heads=config.kv_num_heads,
            local_window_size=local_window_size,
            do_rotary=rotary,
            rotary_interleaved=rotary_interleaved,
            # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0,
            # kv_share_buffer=1 if share_buffer else 0,
            domain="com.microsoft",
        ),
    ]

    # Packed QKV carries K and V behind Q along the hidden axis.
    query_hidden = q_hidden + 2 * kv_hidden if packed else q_hidden
    graph_input = [
        float_info("query", [config.batch_size, config.sequence_length, query_hidden]),
        float_info("past_key", cache_dims(past_kv_seqlen)),
        float_info("past_value", cache_dims(past_kv_seqlen)),
        int32_info("seqlens_k", [config.batch_size]),
        int32_info("total_sequence_length", [1]),
    ]
    if not packed:
        new_kv_dims = [config.batch_size, config.sequence_length, kv_hidden]
        graph_input.append(float_info("key", list(new_kv_dims)))
        graph_input.append(float_info("value", list(new_kv_dims)))
    if rotary:
        # One table row per position of the present cache; the rotary dimension is
        # head_size rounded down to a multiple of 16, and each table stores half of it.
        rotary_dims = [present_kv_seqlen, (math.floor(config.head_size / 16) * 16) // 2]
        graph_input.append(float_info("cos_cache", list(rotary_dims)))
        graph_input.append(float_info("sin_cache", list(rotary_dims)))

    graph_output = [
        float_info("output", [config.batch_size, config.sequence_length, q_hidden]),
        float_info("present_key", cache_dims(present_kv_seqlen)),
        float_info("present_value", cache_dims(present_kv_seqlen)),
    ]

    graph = helper.make_graph(nodes, "GroupQueryAttention_Graph", graph_input, graph_output)
    return helper.make_model(graph).SerializeToString()
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    """Return a (batch_size, max_seqlen) bool mask that is True on each row's valid prefix.

    Args:
        max_seqlen: number of columns in the mask.
        batch_size: number of rows in the mask.
        device: torch device on which the tensors are created.
        mode: ``"full"`` makes every row max_seqlen long; ``"random"`` draws each
            length from ``[max(1, max_seqlen - 20), max_seqlen)``; ``"third"``
            draws from ``[max_seqlen // 3, max_seqlen)``.
    """
    assert mode in ["full", "random", "third"]
    shape = (batch_size, 1)
    if mode == "full":
        lengths = torch.full(shape, max_seqlen, device=device, dtype=torch.int32)
    else:
        # torch.randint excludes the upper bound, so sampled lengths stay below max_seqlen.
        low = max(1, max_seqlen - 20) if mode == "random" else max_seqlen // 3
        lengths = torch.randint(low, max_seqlen, shape, device=device)
    positions = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size)
    return positions < lengths
def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False):
    """Produce unpadded and padded views of Q/K/V plus functions that re-pad results.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
        kvpacked: stack K and V into one tensor in the returned tuple.
        qkvpacked: stack Q, K and V into one tensor in the returned tuple;
            requires equal head counts and matching padding masks. Mutually
            exclusive with ``kvpacked``.

    Returns:
        A tuple whose layout depends on the packing mode:
            qkvpacked: (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv,
                output_pad_fn, dqkv_pad_fn)
            kvpacked: (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k,
                max_seqlen_q, max_seqlen_k, q, kv, output_pad_fn, dq_pad_fn,
                dkv_pad_fn)
            otherwise: (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
                max_seqlen_q, max_seqlen_k, q, k, v, output_pad_fn, dq_pad_fn,
                dk_pad_fn)
        The returned tensors are detached and have ``requires_grad`` set.
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)

        def output_pad_fn(output_unpad):
            return pad_input(output_unpad, indices_q, batch_size, seqlen_q)

    else:
        # No mask: every row is full length, so cumulative lengths are multiples of seqlen_q.
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q

        def output_pad_fn(output_unpad):
            return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        # The masks must agree. `None == None` is a plain bool without `.all()`,
        # so the no-mask case (which the branch below supports) is checked separately.
        if query_padding_mask is None or key_padding_mask is None:
            assert query_padding_mask is None and key_padding_mask is None
        else:
            assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:

            def dqkv_pad_fn(dqkv_unpad):
                return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)

        else:

            def dqkv_pad_fn(dqkv_unpad):
                return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)

        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:

            def dkv_pad_fn(dkv_unpad):
                return pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)

        else:

            def dkv_pad_fn(dkv_unpad):
                return rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)

        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:

            def dk_pad_fn(dk_unpad):
                return pad_input(dk_unpad, indices_k, batch_size, seqlen_k)

        else:

            def dk_pad_fn(dk_unpad):
                return rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)

        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
def create_inputs(config: Config, kv_packed=False, qkv_packed=True):
    """Create a random packed QKV tensor with a random padding mask and its unpadded form.

    Args:
        config: supplies ``batch_size``, ``sequence_length``, ``num_heads`` and ``head_size``.
        kv_packed: forwarded to ``generate_qkv`` as ``kvpacked``.
        qkv_packed: forwarded to ``generate_qkv`` as ``qkvpacked``.

    Returns:
        The six values ``generate_qkv`` returns followed by the padding mask.
    """
    packed_shape = (config.batch_size, config.sequence_length, 3, config.num_heads, config.head_size)
    qkv = torch.randn(*packed_shape, device="cpu", dtype=torch.float32, requires_grad=False)
    key_padding_mask = generate_random_padding_mask(
        config.sequence_length, config.batch_size, device="cpu", mode="random"
    )
    # The same mask is used for queries and keys.
    query, key, value = qkv.unbind(dim=2)
    unpadded = generate_qkv(query, key, value, key_padding_mask, key_padding_mask, kv_packed, qkv_packed)
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = unpadded
    return qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn, key_padding_mask
def generate_token_offset(cu_seqlens, max_seqlen):
    """List flat indices of real tokens first, then of padding tokens, for a padded batch.

    Args:
        cu_seqlens: cumulative sequence lengths; consecutive entries delimit one batch row.
        max_seqlen: padded length of every row.

    Returns:
        A 1-D ``numpy.int32`` array: indices of valid tokens in batch order
        followed by the indices of the padding slots.
    """
    valid_indices = []
    padding_indices = []  # These are the indices that contain padding tokens
    for row, (begin, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        row_start = row * max_seqlen
        first_pad = row_start + (end - begin)
        valid_indices.extend(range(row_start, first_pad))
        padding_indices.extend(range(first_pad, (row + 1) * max_seqlen))
    return numpy.asarray(valid_indices + padding_indices, dtype=numpy.int32)
def gqa_prompt_func(
    q,
    k,
    v,
    config,
    new_k,
    new_v,
    cos=None,
    sin=None,
    seqlens_k=None,
    window_size=-1,
    past_kv_format=Formats.BSNH,
    share_buffer=True,
    rotary_interleaved=False,
):
    """Run the prompt-phase GroupQueryAttention graph on the CPU execution provider.

    Args:
        q: query tensor; reshaped to (batch_size, q_sequence_length, -1). In
            packed mode it already contains K and V.
        k, v: KV buffers; cloned and bound as past inputs and present outputs
            when ``share_buffer`` is True, otherwise not read.
        config: prompt configuration (see ``PromptConfig``).
        new_k, new_v: new key/value tensors, or None to build a packed-QKV graph.
        cos, sin: rotary caches, or None to disable rotary embedding.
        seqlens_k: per-batch lengths passed as the ``seqlens_k`` input.
        window_size: forwarded as the node's ``local_window_size``.
        past_kv_format: layout of the past/present tensors.
        share_buffer: whether past and present share one pre-allocated buffer.
        rotary_interleaved: forwarded to the node.

    Returns:
        ``(output, present_k, present_v)`` where ``output`` is a torch tensor and
        the present tensors are whatever ``copy_outputs_to_cpu`` yields.
    """
    onnx_model_str = create_group_query_attention_graph_prompt(
        config,
        past_kv_format,
        share_buffer,
        local_window_size=window_size,
        rotary=cos is not None,
        rotary_interleaved=rotary_interleaved,
        packed=new_k is None,
    )
    q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1))
    # Kept in locals for the whole call: the OrtValues created below are built
    # from numpy views of these clones.
    past_k = k.clone() if share_buffer else None
    past_v = v.clone() if share_buffer else None
    if new_k is not None:
        flat_kv = (config.batch_size, config.kv_sequence_length, -1)
        new_k = torch.reshape(new_k, flat_kv)
        new_v = torch.reshape(new_v, flat_kv)

    ort_inputs = {"query": q.detach().cpu().numpy()}
    if share_buffer:
        ort_inputs["past_key"] = OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cpu", 0)
        ort_inputs["past_value"] = OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cpu", 0)
    ort_inputs["seqlens_k"] = seqlens_k.detach().cpu().numpy().astype(numpy.int32)
    ort_inputs["total_sequence_length"] = (
        torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy()
    )

    sess_options = SessionOptions()
    ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
    io_binding = ort_session.io_binding()

    # Optional inputs first, in the same order for both buffer modes.
    optional_inputs = {}
    if new_k is not None:
        optional_inputs["key"] = new_k
        optional_inputs["value"] = new_v
    if cos is not None:
        optional_inputs["cos_cache"] = cos
        optional_inputs["sin_cache"] = sin
    for input_name, tensor in optional_inputs.items():
        ort_inputs[input_name] = tensor.detach().cpu().numpy()
        io_binding.bind_cpu_input(input_name, ort_inputs[input_name])

    # TODO: do we need io binding for cpu input?
    io_binding.bind_cpu_input("query", ort_inputs["query"])
    if share_buffer:
        # Past tensors are bound by raw pointer so the present outputs can alias them.
        for cache_name in ("past_key", "past_value"):
            cache = ort_inputs[cache_name]
            io_binding.bind_input(cache_name, "cpu", 0, numpy.float32, cache.shape(), cache.data_ptr())
    io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"])
    io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"])

    io_binding.bind_output("output")
    if share_buffer:
        io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"])
        io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"])
    else:
        io_binding.bind_output("present_key")
        io_binding.bind_output("present_value")

    ort_session.run_with_iobinding(io_binding)
    ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
    output = torch.tensor(numpy.array(ort_output))
    return output, present_k, present_v
def gqa_past_func(
    q,
    k,
    v,
    config,
    new_k,
    new_v,
    cos=None,
    sin=None,
    seqlens_k=None,
    past_kv_format=Formats.BSNH,
    share_buffer=True,
    window_size=-1,
    rotary_interleaved=False,
):
    """Run the past-KV GroupQueryAttention graph on the CPU execution provider.

    Args:
        q: query tensor; reshaped to (batch_size, sequence_length, -1). In
            packed mode it already contains K and V.
        k, v: past key/value caches; cloned before being handed to the session.
        config: configuration (see ``Config``).
        new_k, new_v: new key/value tensors, or None to build a packed-QKV graph.
        cos, sin: rotary caches, or None to disable rotary embedding.
        seqlens_k: per-batch lengths passed as the ``seqlens_k`` input.
        past_kv_format: layout of the past/present tensors.
        share_buffer: when True the past tensors are bound by pointer and also
            bound as the present outputs; when False they are bound as ordinary
            CPU inputs and the present outputs are allocated by the session.
        window_size: forwarded as the node's ``local_window_size``.
        rotary_interleaved: forwarded to the node.

    Returns:
        ``(output, present_k, present_v)`` where ``output`` is a torch tensor and
        the present tensors are whatever ``copy_outputs_to_cpu`` yields.
    """
    onnx_model_str = create_group_query_attention_graph_past(
        config,
        past_kv_format,
        share_buffer,
        local_window_size=window_size,
        rotary=cos is not None,
        rotary_interleaved=rotary_interleaved,
        packed=new_k is None,
    )
    q = torch.reshape(q, (config.batch_size, config.sequence_length, -1))
    # Kept in locals for the whole call: the session inputs created below are
    # built from numpy views of these clones.
    past_k = k.clone()
    past_v = v.clone()
    if new_k is not None:
        flat_kv = (config.batch_size, config.sequence_length, -1)
        new_k = torch.reshape(new_k, flat_kv)
        new_v = torch.reshape(new_v, flat_kv)

    past_k_np = past_k.detach().cpu().numpy()
    past_v_np = past_v.detach().cpu().numpy()
    if share_buffer:
        # The buffer already has room for the new tokens.
        total_seqlen = config.kv_sequence_length
        past_key_input = OrtValue.ortvalue_from_numpy(past_k_np, "cpu", 0)
        past_value_input = OrtValue.ortvalue_from_numpy(past_v_np, "cpu", 0)
    else:
        # The new tokens are appended after the past ones.
        total_seqlen = config.kv_sequence_length + config.sequence_length
        past_key_input = past_k_np
        past_value_input = past_v_np

    ort_inputs = {
        "query": q.detach().cpu().numpy(),
        "past_key": past_key_input,
        "past_value": past_value_input,
        "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32),
        "total_sequence_length": torch.tensor([total_seqlen], dtype=torch.int32).detach().cpu().numpy(),
    }

    sess_options = SessionOptions()
    ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
    io_binding = ort_session.io_binding()

    # Optional inputs first, in the same order for both buffer modes.
    optional_inputs = {}
    if new_k is not None:
        optional_inputs["key"] = new_k
        optional_inputs["value"] = new_v
    if cos is not None:
        optional_inputs["cos_cache"] = cos
        optional_inputs["sin_cache"] = sin
    for input_name, tensor in optional_inputs.items():
        ort_inputs[input_name] = tensor.detach().cpu().numpy()
        io_binding.bind_cpu_input(input_name, ort_inputs[input_name])

    io_binding.bind_cpu_input("query", ort_inputs["query"])
    for cache_name in ("past_key", "past_value"):
        cache = ort_inputs[cache_name]
        if share_buffer:
            # Bound by raw pointer so the present outputs can alias the same memory.
            io_binding.bind_input(cache_name, "cpu", 0, numpy.float32, cache.shape(), cache.data_ptr())
        else:
            io_binding.bind_cpu_input(cache_name, cache)
    io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"])
    io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"])

    io_binding.bind_output("output")
    if share_buffer:
        io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"])
        io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"])
    else:
        io_binding.bind_output("present_key")
        io_binding.bind_output("present_value")

    ort_session.run_with_iobinding(io_binding)
    ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
    output = torch.tensor(numpy.array(ort_output))
    return output, present_k, present_v
def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None):
    """Return a bool mask that is True where a key column lies after the query row's causal limit.

    When a padding mask is given, the corresponding length is taken per batch
    from the number of True entries in that mask instead of the padded length.
    """
    rows = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    cols = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_padding_mask is None:
        sk = seqlen_k
    else:
        sk = rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    if query_padding_mask is None:
        sq = seqlen_q
    else:
        sq = rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    # Rows are aligned to the end of the keys: row r may see columns up to r + sk - sq.
    return cols > rows + sk - sq
def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
):
    """Return a bool mask that is True for key columns outside each query row's window.

    Args:
        seqlen_q: padded query length.
        seqlen_k: padded key length.
        window_size: ``(left, right)`` extents; a negative ``left`` disables the left bound.
        query_padding_mask: optional bool mask; its per-batch True count replaces ``seqlen_q``.
        key_padding_mask: optional bool mask; its per-batch True count replaces ``seqlen_k``.
        device: torch device for the index tensors.
    """
    left, right = window_size[0], window_size[1]
    rows = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    cols = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_padding_mask is None:
        sk = seqlen_k
    else:
        sk = rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    if query_padding_mask is None:
        sq = seqlen_q
    else:
        sq = rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")

    if left < 0:
        # Only the right bound applies.
        return cols > rows + sk - sq + right

    if key_padding_mask is None:
        # torch.minimum needs a tensor operand, so promote the scalar length.
        sk = torch.full_like(cols, seqlen_k)
    beyond_right = cols > torch.minimum(rows + sk - sq + right, sk)
    before_left = cols < rows + sk - sq - left
    return torch.logical_or(beyond_right, before_left)
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
    reorder_ops=False,
):
    """Reference attention in plain PyTorch.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to the input dtype.
        reorder_ops: whether to change the order of operations (scaling k instead of
            scaling q) without changing the math. This is to estimate the numerical
            error from operation reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax before dropout
    """
    if causal:
        # Causal attention is a window with no right context.
        window_size = (window_size[0], 0)
    windowed = window_size[0] >= 0 or window_size[1] >= 0

    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]

    # Expand the KV heads so that every query head has a matching KV head.
    num_q_heads = q.shape[2]
    k = repeat(k, "b s h d -> b s (h g) d", g=num_q_heads // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=num_q_heads // v.shape[2])

    scale = math.sqrt(q.shape[-1])
    if reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q, k / scale)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q / scale, k)

    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if windowed:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))

    attention = torch.softmax(scores, dim=-1)
    if windowed:
        # Some rows might be completely masked out so we fill them with zero instead of NaN
        fully_masked_rows = torch.all(local_mask, dim=-1, keepdim=True)
        attention = attention.masked_fill(fully_masked_rows, 0.0)
    if query_padding_mask is not None:
        # We want to mask here so that the attention matrix doesn't have any NaNs
        # Otherwise we'll get NaN in dV
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)

    dropout_scaling = 1.0 / (1 - dropout_p)
    attention_drop = attention if dropout_mask is None else attention.masked_fill(~dropout_mask, 0.0)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
def attention_qkvpacked_ref(
    qkv, key_padding_mask=None, dropout_p=0.0, dropout_mask=None, causal=False, upcast=True, reorder_ops=False
):
    """Run ``attention_ref`` on a packed QKV tensor.

    Q, K and V are read from indices 0, 1 and 2 of dim 2 of ``qkv``. The single
    ``key_padding_mask`` is used for both queries and keys.
    """
    query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
    return attention_ref(
        query,
        key,
        value,
        key_padding_mask,
        key_padding_mask,
        dropout_p,
        dropout_mask,
        causal=causal,
        upcast=upcast,
        reorder_ops=reorder_ops,
    )
def parity_check_gqa_prompt(
    config,
    causal=True,
    local=False,
    past_format=Formats.BSNH,
    rotary=False,
    rotary_interleaved=False,
    packed=False,
    rtol=1e-3,
    atol=1e-3,
):
    """Compare the prompt-phase GQA operator (shared KV buffer) against the PyTorch reference.

    Random Q, KV buffers and new K/V are generated. The expected cache and
    attention output are computed with ``attention_ref``, then the ONNX graph is
    run through ``gqa_prompt_func`` with ``share_buffer=True``.

    The present key/value tensors are asserted to match the reference cache.
    The attention output comparison is only printed, not asserted.

    Args:
        config: prompt configuration (see ``PromptConfig``).
        causal: selects the ``(-1, 0)`` reference window when ``local`` is False;
            the reference call itself always passes ``causal=True``.
        local: use a random left window size in ``[1, kv_sequence_length]``.
        past_format: layout of the KV buffers handed to the operator.
        rotary: apply rotary embedding to Q and the new K.
        rotary_interleaved: use the interleaved rotary layout.
        packed: feed Q, new K and new V concatenated as a single input.
        rtol, atol: tolerances for ``numpy.allclose``.
    """
    # NOTE: the order of the random draws below is kept fixed so results stay
    # reproducible under the module-level seed.
    q = torch.randn(
        config.batch_size,
        config.q_sequence_length,
        config.num_heads,
        config.head_size,
        device="cpu",
        dtype=torch.float32,
        requires_grad=False,
    )
    k = torch.randn(
        config.batch_size,
        config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads,
        config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length,
        config.head_size,
        device="cpu",
        dtype=torch.float32,
        requires_grad=False,
    )
    v = torch.randn(
        config.batch_size,
        config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads,
        config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length,
        config.head_size,
        device="cpu",
        dtype=torch.float32,
        requires_grad=False,
    )
    new_k = torch.randn(
        config.batch_size,
        config.kv_sequence_length,
        config.kv_num_heads,
        config.head_size,
        device="cpu",
        dtype=torch.float32,
        requires_grad=False,
    )
    new_v = torch.randn(
        config.batch_size,
        config.kv_sequence_length,
        config.kv_num_heads,
        config.head_size,
        device="cpu",
        dtype=torch.float32,
        requires_grad=False,
    )

    window_size = (-1, -1)
    left_window_size = -1
    if local:
        left_window_size = random.randint(1, config.kv_sequence_length)
        window_size = (left_window_size, 0)
    elif causal:
        left_window_size = -1
        window_size = (-1, 0)

    # Pytorch to compare
    k_cache_ref = k.clone()
    v_cache_ref = v.clone()
    if past_format == Formats.BNSH:
        # The reference math below works in BSNH.
        k_cache_ref = k_cache_ref.transpose(1, 2)
        v_cache_ref = v_cache_ref.transpose(1, 2)
    cache_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size)
    rotary_seqlens = torch.tensor([0], device="cpu").repeat(config.batch_size)

    if rotary:
        rotary_fraction = 1.0
        rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
        angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
        cos = torch.cos(angle).to(dtype=torch.float32)
        sin = torch.sin(angle).to(dtype=torch.float32)
        rot = LlamaMSRotaryEmbedding()
        q_ro = rot(
            q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved
        )
        k_ro = rot(
            new_k.clone(),
            cos.unsqueeze(0).unsqueeze(2),
            sin.unsqueeze(0).unsqueeze(2),
            rotary_seqlens,
            rotary_interleaved,
        )
    else:
        cos, sin = None, None
        q_ro, k_ro = q, new_k

    # Write the new K/V into the first kv_sequence_length slots of the reference cache.
    arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size)
    kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1")
    update_mask = arange < kv_seqlens_expanded
    k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
    v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
    key_padding_mask = arange < cache_seqlens_expanded
    out_ref, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
    )
    out_ref = out_ref.detach().cpu().numpy()
    if past_format == Formats.BNSH:
        # Back to the operator's layout for the cache comparison.
        k_cache_ref = k_cache_ref.transpose(1, 2)
        v_cache_ref = v_cache_ref.transpose(1, 2)

    # Flash function
    if packed:
        # Packed mode: Q, K and V travel in the query input and key/value are omitted.
        query_input = torch.concatenate([q, new_k, new_v], dim=2)
        key_input, value_input = None, None
    else:
        query_input, key_input, value_input = q, new_k, new_v
    out, present_k, present_v = gqa_prompt_func(
        query_input,
        k,
        v,
        config,
        key_input,
        value_input,
        cos,
        sin,
        cache_seqlens,
        left_window_size,
        past_format,
        True,
        rotary_interleaved,
    )
    out = torch.squeeze(out, 0)
    out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
    out = out.detach().cpu().numpy()

    # Make sure past-present buffer updating correctly
    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)

    # Compare results
    all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
    if Fore is not None:
        correct = Fore.GREEN + "True" if all_close else Fore.RED + "False"
    else:
        correct = "True" if all_close else "False"
    print(
        "KV-buffer",
        " packed:",
        packed,
        " causal:",
        causal,
        " local:",
        local,
        " rotary:",
        rotary,
        " rotary_interleaved:",
        rotary_interleaved,
        "past kv format:",
        "BSNH" if past_format == Formats.BSNH else "BNSH",
        " B:",
        config.batch_size,
        " S:",
        config.q_sequence_length,
        " kv S:",
        config.kv_sequence_length,
        " N:",
        config.num_heads,
        " kv N:",
        config.kv_num_heads,
        " h:",
        config.head_size,
        " Mean Error:",
        numpy.mean(numpy.abs(out - out_ref)),
        correct,
    )
def parity_check_gqa_prompt_no_buff(
    config,
    causal=True,
    local=False,
    past_format=Formats.BSNH,
    rotary=False,
    rotary_interleaved=False,
    packed=False,
    rtol=1e-3,
    atol=1e-3,
):
    """Check the GQA prompt path (no past key/value passed in) against a PyTorch reference.

    Random float32 ``q``, ``new_k`` and ``new_v`` tensors are created on CPU. The reference result comes from
    ``attention_ref``; the result under test comes from ``gqa_prompt_func``, which is called with ``None`` for
    the past key/value. A one-line summary is printed for every configuration.

    Args:
        config: Object providing ``batch_size``, ``q_sequence_length``, ``kv_sequence_length``, ``num_heads``,
            ``kv_num_heads`` and ``head_size``.
        causal: Only selects the reference window ``(-1, 0)`` versus ``(-1, -1)`` when ``local`` is False;
            ``attention_ref`` itself is always called with ``causal=True``.
        local: If True, a random left window size in ``[1, kv_sequence_length]`` is used.
        past_format: ``Formats.BSNH`` or ``Formats.BNSH``. For BNSH the reference K/V are transposed to
            (batch, heads, sequence, head_size) before being compared with the returned present K/V.
        rotary: If True, rotary embedding is applied to the reference q/k and cos/sin are handed to the
            function under test; otherwise cos/sin are ``None``.
        rotary_interleaved: Forwarded to the rotary embedding and to ``gqa_prompt_func``.
        packed: If True, q, new_k and new_v are concatenated along the head axis into a single input.
        rtol: Relative tolerance for the comparisons.
        atol: Absolute tolerance for the comparisons.

    Raises:
        AssertionError: If the present key/value or the attention output differ from the reference.
    """

    def rand_tensor(*shape):
        return torch.randn(*shape, device="cpu", dtype=torch.float32, requires_grad=False)

    # Inputs are drawn in a fixed order (q, new_k, new_v) so the RNG stream stays reproducible.
    q = rand_tensor(config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)
    new_k = rand_tensor(config.batch_size, config.kv_sequence_length, config.kv_num_heads, config.head_size)
    new_v = rand_tensor(config.batch_size, config.kv_sequence_length, config.kv_num_heads, config.head_size)

    left_window_size = -1
    if local:
        left_window_size = random.randint(1, config.kv_sequence_length)
        window_size = (left_window_size, 0)
    elif causal:
        window_size = (-1, 0)
    else:
        window_size = (-1, -1)

    # PyTorch reference
    k_cache_ref = new_k.clone()
    v_cache_ref = new_v.clone()
    cache_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size)
    # Rotary positions start at 0 for every batch entry in the prompt case.
    rotary_seqlens = torch.tensor([0], device="cpu").repeat(config.batch_size)
    if rotary:
        rotary_fraction = 1.0
        rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
        angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
        cos = torch.cos(angle).to(dtype=torch.float32)
        sin = torch.sin(angle).to(dtype=torch.float32)
        rot = LlamaMSRotaryEmbedding()
        cos_b = cos.unsqueeze(0).unsqueeze(2)
        sin_b = sin.unsqueeze(0).unsqueeze(2)
        q_ro = rot(q.clone(), cos_b, sin_b, rotary_seqlens, rotary_interleaved)
        k_ro = rot(k_cache_ref.clone(), cos_b, sin_b, rotary_seqlens, rotary_interleaved)
    else:
        cos, sin = None, None
        q_ro, k_ro = q, k_cache_ref
    k_cache_ref = k_ro

    brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    new_mask = brange < cache_seqlens_expanded
    # Repeat each KV head so the reference sees as many KV heads as query heads.
    group_size = config.num_heads // config.kv_num_heads
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=group_size)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=group_size)
    out_ref, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size
    )
    out_ref = out_ref.detach().cpu().numpy()
    if past_format == Formats.BNSH:
        k_cache_ref = k_cache_ref.transpose(1, 2)
        v_cache_ref = v_cache_ref.transpose(1, 2)

    # Function under test: in packed mode K and V travel inside the query input.
    if packed:
        query_input = torch.concatenate([q, new_k, new_v], dim=2)
        key_input, value_input = None, None
    else:
        query_input, key_input, value_input = q, new_k, new_v
    out, present_k, present_v = gqa_prompt_func(
        query_input,
        None,
        None,
        config,
        key_input,
        value_input,
        cos,
        sin,
        cache_seqlens,
        left_window_size,
        past_format,
        False,
        rotary_interleaved,
    )
    out = torch.squeeze(out, 0)
    out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
    out = out.detach().cpu().numpy()

    # Make sure past-present buffer updating correctly
    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)

    # Compare results
    all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
    mean_error = numpy.mean(numpy.abs(out - out_ref))
    if Fore is not None:
        correct = Fore.GREEN + "True" if all_close else Fore.RED + "False"
    else:
        correct = "True" if all_close else "False"
    print(
        "No buff",
        " packed:",
        packed,
        " causal:",
        causal,
        " local:",
        local,
        " rotary:",
        rotary,
        " rotary_interleaved:",
        rotary_interleaved,
        "past kv format:",
        "BSNH" if past_format == Formats.BSNH else "BNSH",
        " B:",
        config.batch_size,
        " S:",
        config.q_sequence_length,
        " kv S:",
        config.kv_sequence_length,
        " N:",
        config.num_heads,
        " kv N:",
        config.kv_num_heads,
        " h:",
        config.head_size,
        " Mean Error:",
        mean_error,
        correct,
    )
    # Fail on an output mismatch instead of only reporting it in the summary line above.
    assert all_close, f"GQA prompt (no buffer) output differs from reference; mean abs error {mean_error}"
def parity_check_gqa_past(
    config,
    causal=True,
    local=False,
    past_format=Formats.BSNH,
    rotary=False,
    rotary_interleaved=False,
    packed=False,
    rtol=1e-3,
    atol=1e-3,
):
    """Check the GQA token-generation path with a KV cache buffer against a PyTorch reference.

    Random float32 ``q``, cache ``k``/``v`` and ``new_k``/``new_v`` tensors are created on CPU, together with a
    random per-batch past length. The reference writes the new keys/values into a copy of the cache at
    positions ``[past_len, past_len + sequence_length)`` and runs ``attention_ref``. The result under test comes
    from ``gqa_past_func``. A one-line summary is printed for every configuration.

    Args:
        config: Object providing ``batch_size``, ``sequence_length``, ``kv_sequence_length``, ``num_heads``,
            ``kv_num_heads`` and ``head_size``.
        causal: Only selects the reference window ``(-1, 0)`` versus ``(-1, -1)`` when ``local`` is False;
            ``attention_ref`` itself is always called with ``causal=True``.
        local: If True, a random left window size in ``[1, kv_sequence_length]`` is used.
        past_format: ``Formats.BSNH`` gives caches shaped (batch, kv_sequence, kv_heads, head_size); any other
            value gives (batch, kv_heads, kv_sequence, head_size).
        rotary: If True, rotary embedding is applied to the reference q/new_k and cos/sin are handed to the
            function under test; otherwise cos/sin are ``None``.
        rotary_interleaved: Forwarded to the rotary embedding and to ``gqa_past_func``.
        packed: If True, q, new_k and new_v are concatenated along the head axis into a single input.
        rtol: Relative tolerance for the comparisons.
        atol: Absolute tolerance for the comparisons.

    Raises:
        AssertionError: If the present key/value or the attention output differ from the reference.
    """

    def rand_tensor(*shape):
        return torch.randn(*shape, device="cpu", dtype=torch.float32, requires_grad=False)

    if past_format == Formats.BSNH:
        cache_shape = (config.batch_size, config.kv_sequence_length, config.kv_num_heads, config.head_size)
    else:
        cache_shape = (config.batch_size, config.kv_num_heads, config.kv_sequence_length, config.head_size)
    new_shape = (config.batch_size, config.sequence_length, config.kv_num_heads, config.head_size)

    # Inputs are drawn in a fixed order (q, k, v, new_k, new_v) so the RNG stream stays reproducible.
    q = rand_tensor(config.batch_size, config.sequence_length, config.num_heads, config.head_size)
    k = rand_tensor(*cache_shape)
    v = rand_tensor(*cache_shape)
    new_k = rand_tensor(*new_shape)
    new_v = rand_tensor(*new_shape)

    left_window_size = -1
    if local:
        left_window_size = random.randint(1, config.kv_sequence_length)
        window_size = (left_window_size, 0)
    elif causal:
        window_size = (-1, 0)
    else:
        window_size = (-1, -1)

    # PyTorch reference; it works on a (batch, sequence, heads, head_size) view of the cache.
    k_cache_ref = k.clone()
    v_cache_ref = v.clone()
    if past_format == Formats.BNSH:
        k_cache_ref = k_cache_ref.transpose(1, 2)
        v_cache_ref = v_cache_ref.transpose(1, 2)
    # Per-batch past length, bounded so that past plus new tokens still fit in the cache.
    cache_seqlens = torch.randint(
        0,
        config.kv_sequence_length - config.sequence_length + 1,
        (config.batch_size,),
        dtype=torch.int32,
        device="cpu",
    )
    if rotary:
        rotary_fraction = 1.0
        rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
        angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
        cos = torch.cos(angle).to(dtype=torch.float32)
        sin = torch.sin(angle).to(dtype=torch.float32)
        rot = LlamaMSRotaryEmbedding()
        cos_b = cos.unsqueeze(0).unsqueeze(2)
        sin_b = sin.unsqueeze(0).unsqueeze(2)
        q_ro = rot(q.clone(), cos_b, sin_b, cache_seqlens, rotary_interleaved)
        k_ro = rot(new_k.clone(), cos_b, sin_b, cache_seqlens, rotary_interleaved)
    else:
        cos, sin = None, None
        q_ro, k_ro = q, new_k

    arange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    # Cache slots that receive the new tokens: [past_len, past_len + sequence_length) per batch entry.
    update_mask = torch.logical_and(
        cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length
    )
    k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
    v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
    # Repeat each KV head so the reference sees as many KV heads as query heads.
    group_size = config.num_heads // config.kv_num_heads
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=group_size)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=group_size)
    key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
    out_ref, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
    )
    out_ref = out_ref.detach().cpu().numpy()
    if past_format == Formats.BNSH:
        # Back to the cache layout so the reference can be compared with the returned present K/V.
        k_cache_ref = k_cache_ref.transpose(1, 2)
        v_cache_ref = v_cache_ref.transpose(1, 2)

    # ORT function: in packed mode the new K and V travel inside the query input.
    if packed:
        query_input = torch.concatenate([q, new_k, new_v], dim=2)
        key_input, value_input = None, None
    else:
        query_input, key_input, value_input = q, new_k, new_v
    out, present_k, present_v = gqa_past_func(
        query_input,
        k,
        v,
        config,
        key_input,
        value_input,
        cos,
        sin,
        cache_seqlens,
        past_format,
        True,
        left_window_size,
        rotary_interleaved,
    )
    out = torch.squeeze(out, 0)
    out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
    out = out.detach().cpu().numpy()

    # Make sure past-present buffer updating correctly
    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)

    # Compare results
    all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
    mean_error = numpy.mean(numpy.abs(out - out_ref))
    if Fore is not None:
        correct = Fore.GREEN + "True" if all_close else Fore.RED + "False"
    else:
        correct = "True" if all_close else "False"
    print(
        "KV-buffer",
        "past kv format:",
        "BSNH" if past_format == Formats.BSNH else "BNSH",
        " packed:",
        packed,
        " causal:",
        causal,
        " local:",
        local,
        " rotary:",
        rotary,
        " rotary_interleaved:",
        rotary_interleaved,
        " B:",
        config.batch_size,
        " S:",
        config.sequence_length,
        " kv S:",
        config.kv_sequence_length,
        " N:",
        config.num_heads,
        " kv N:",
        config.kv_num_heads,
        " h:",
        config.head_size,
        " Mean Error:",
        mean_error,
        correct,
    )
    # Fail on an output mismatch instead of only reporting it in the summary line above.
    assert all_close, f"GQA past (KV buffer) output differs from reference; mean abs error {mean_error}"
def parity_check_gqa_past_no_buff(
    config,
    causal=True,
    local=False,
    past_format=Formats.BSNH,
    rotary=False,
    rotary_interleaved=False,
    packed=False,
    rtol=1e-3,
    atol=1e-3,
):
    """Check the GQA token-generation path without a shared KV buffer against a PyTorch reference.

    The torch RNG is reseeded with 69 on every call. Random float32 ``q``, past ``k``/``v`` and
    ``new_k``/``new_v`` tensors are created on CPU. The reference concatenates past and new K/V along the
    sequence axis, overwrites positions ``[past_len, past_len + sequence_length)`` with the new tokens and runs
    ``attention_ref``. The result under test comes from ``gqa_past_func``. Only the attention output is
    compared; the returned present key/value are not checked here. A one-line summary is printed for every
    configuration.

    Args:
        config: Object providing ``batch_size``, ``sequence_length``, ``kv_sequence_length``, ``num_heads``,
            ``kv_num_heads`` and ``head_size``.
        causal: Only selects the reference window ``(-1, 0)`` versus ``(-1, -1)`` when ``local`` is False;
            ``attention_ref`` itself is always called with ``causal=True``.
        local: If True, a random left window size in ``[1, kv_sequence_length]`` is used.
        past_format: ``Formats.BSNH`` gives past K/V shaped (batch, kv_sequence, kv_heads, head_size); any
            other value gives (batch, kv_heads, kv_sequence, head_size).
        rotary: If True, rotary embedding is applied to the reference q/new_k and cos/sin are handed to the
            function under test; otherwise cos/sin are ``None``.
        rotary_interleaved: Forwarded to the rotary embedding and to ``gqa_past_func``.
        packed: If True, q, new_k and new_v are concatenated along the head axis into a single input.
        rtol: Relative tolerance for the output comparison.
        atol: Absolute tolerance for the output comparison.

    Raises:
        AssertionError: If the attention output differs from the reference.
    """
    torch.manual_seed(69)

    def rand_tensor(*shape):
        return torch.randn(*shape, device="cpu", dtype=torch.float32, requires_grad=False)

    if past_format == Formats.BSNH:
        past_shape = (config.batch_size, config.kv_sequence_length, config.kv_num_heads, config.head_size)
    else:
        past_shape = (config.batch_size, config.kv_num_heads, config.kv_sequence_length, config.head_size)
    new_shape = (config.batch_size, config.sequence_length, config.kv_num_heads, config.head_size)

    # Inputs are drawn in a fixed order (q, k, v, new_k, new_v) so the RNG stream stays reproducible.
    q = rand_tensor(config.batch_size, config.sequence_length, config.num_heads, config.head_size)
    k = rand_tensor(*past_shape)
    v = rand_tensor(*past_shape)
    new_k = rand_tensor(*new_shape)
    new_v = rand_tensor(*new_shape)

    left_window_size = -1
    if local:
        left_window_size = random.randint(1, config.kv_sequence_length)
        window_size = (left_window_size, 0)
    elif causal:
        window_size = (-1, 0)
    else:
        window_size = (-1, -1)

    # PyTorch reference; it works on a (batch, sequence, heads, head_size) view of the past K/V.
    k_cache_ref = k.clone()
    v_cache_ref = v.clone()
    if past_format == Formats.BNSH:
        k_cache_ref = k_cache_ref.transpose(1, 2)
        v_cache_ref = v_cache_ref.transpose(1, 2)
    # Extend the sequence axis by sequence_length so the new tokens always have room.
    k_cache_ref = torch.cat((k_cache_ref, new_k), 1)
    v_cache_ref = torch.cat((v_cache_ref, new_v), 1)
    cache_seqlens = torch.randint(
        0,
        config.kv_sequence_length,
        (config.batch_size,),
        dtype=torch.int32,
        device="cpu",
    )
    # Force one randomly chosen batch entry to use the full past length.
    cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length
    total_sequence_length = config.kv_sequence_length + config.sequence_length
    if rotary:
        rotary_fraction = 1.0
        rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
        angle = torch.rand(total_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
        cos = torch.cos(angle).to(dtype=torch.float32)
        sin = torch.sin(angle).to(dtype=torch.float32)
        rot = LlamaMSRotaryEmbedding()
        cos_b = cos.unsqueeze(0).unsqueeze(2)
        sin_b = sin.unsqueeze(0).unsqueeze(2)
        q_ro = rot(q.clone(), cos_b, sin_b, cache_seqlens, rotary_interleaved)
        k_ro = rot(new_k.clone(), cos_b, sin_b, cache_seqlens, rotary_interleaved)
    else:
        cos, sin = None, None
        q_ro, k_ro = q, new_k

    arange = rearrange(torch.arange(total_sequence_length, device="cpu"), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    # Slots that receive the new tokens: [past_len, past_len + sequence_length) per batch entry.
    update_mask = torch.logical_and(
        cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length
    )
    k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
    v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
    # Repeat each KV head so the reference sees as many KV heads as query heads.
    group_size = config.num_heads // config.kv_num_heads
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=group_size)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=group_size)
    key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
    out_ref, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
    )
    out_ref = out_ref.detach().cpu().numpy()

    # Function under test: in packed mode the new K and V travel inside the query input.
    if packed:
        query_input = torch.concatenate([q, new_k, new_v], dim=2)
        key_input, value_input = None, None
    else:
        query_input, key_input, value_input = q, new_k, new_v
    out, _, _ = gqa_past_func(
        query_input,
        k,
        v,
        config,
        key_input,
        value_input,
        cos,
        sin,
        cache_seqlens,
        past_format,
        False,
        window_size=left_window_size,
        rotary_interleaved=rotary_interleaved,
    )
    out = torch.squeeze(out, 0)
    out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
    out = out.detach().cpu().numpy()

    # Compare results
    all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
    mean_error = numpy.mean(numpy.abs(out - out_ref))
    if Fore is not None:
        correct = Fore.GREEN + "True" if all_close else Fore.RED + "False"
    else:
        correct = "True" if all_close else "False"
    print(
        "NO buff",
        " packed:",
        packed,
        " causal:",
        causal,
        " local:",
        local,
        " rotary:",
        rotary,
        " rotary_interleaved:",
        rotary_interleaved,
        "past kv format:",
        "BSNH" if past_format == Formats.BSNH else "BNSH",
        " B:",
        config.batch_size,
        " S:",
        config.sequence_length,
        " kv S:",
        config.kv_sequence_length,
        " N:",
        config.num_heads,
        " kv N:",
        config.kv_num_heads,
        " h:",
        config.head_size,
        " Mean Error:",
        mean_error,
        correct,
    )
    # Fail on an output mismatch instead of only reporting it in the summary line above.
    assert all_close, (
        f"GQA past (no buffer) output differs from reference; mean abs error {mean_error}, "
        f"seqlens {cache_seqlens.tolist()}"
    )
class TestGQA(unittest.TestCase):
    """Parity sweeps for the GQA operator, covering the prompt case and the token-generation case.

    Each test iterates over batch sizes, sequence lengths, head counts, head sizes, local attention, rotary
    modes and packed QKV. The batch, head-count and head-size lists depend on the module-level
    ``pipeline_mode`` flag.
    """

    def test_gqa_no_past(self):
        """Run both prompt parity checks for every configuration of the sweep."""
        from itertools import product  # imported locally so this block does not depend on file-level imports

        # The parity checks draw from both torch's RNG (inputs) and `random` (local window sizes);
        # seed both so the sweep is reproducible regardless of which tests ran before.
        torch.manual_seed(69)
        random.seed(69)
        print("-------- TEST GQA NO PAST (PROMPT CASE) ---------")
        batches = [1, 3] if pipeline_mode else [1, 3, 5]
        # (q sequence length, kv sequence length); the same list is used with and without pipeline_mode.
        seqs = [(127, 127), (35, 35), (2000, 2000), (200, 200), (240, 240)]
        num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
        h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
        rotary_modes = [(False, False), (True, False), (True, True)]
        past_kv_format = Formats.BNSH
        # product() varies the right-most iterable fastest, i.e. the same order as nested for-loops.
        for b, (sq, skv), (n, n2), h, local, (rotary, rotary_interleaved), packed in product(
            batches, seqs, num_h, h_sizes, [False, True], rotary_modes, [False, True]
        ):
            config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
            options = {
                "local": local,
                "past_format": past_kv_format,
                "rotary": rotary,
                "rotary_interleaved": rotary_interleaved,
                "packed": packed,
            }
            parity_check_gqa_prompt(config, **options)
            parity_check_gqa_prompt_no_buff(config, **options)

    def test_gqa_past(self):
        """Run both token-generation parity checks for every configuration of the sweep."""
        from itertools import product  # imported locally so this block does not depend on file-level imports

        print("-------- TEST GQA PAST (TOKEN GEN) ---------")
        batches = [1, 3] if pipeline_mode else [1, 3, 5]
        # (new sequence length, kv sequence length). Larger cases such as (1, 128 * 512),
        # (16, 128 * 512) and (128, 128) are currently disabled.
        seqs = (
            [(1, 128), (1, 1024), (1, 2048)]
            if pipeline_mode
            else [
                (1, 128),
                (1, 339),
                (1, 1024),
                (1, 5000),
                (1, 800),
                (1, 256),
                (1, 799),
                (1, 2048),
            ]
        )
        num_h = [(16, 16), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
        h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
        rotary_modes = [(False, False), (True, False), (True, True)]
        past_kv_format = Formats.BNSH
        # Seed both RNGs used by the parity checks so the sweep does not depend on test execution order.
        torch.manual_seed(69)
        random.seed(69)
        # product() varies the right-most iterable fastest, i.e. the same order as nested for-loops.
        for b, (s, s2), (n, n2), h, local, (rotary, rotary_interleaved), packed in product(
            batches, seqs, num_h, h_sizes, [False, True], rotary_modes, [False, True]
        ):
            sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
            config = Config(b, s, s2, sp, n, n2, h)
            options = {
                "local": local,
                "past_format": past_kv_format,
                "rtol": 1e-3,
                "atol": 1e-3,
                "rotary": rotary,
                "rotary_interleaved": rotary_interleaved,
                "packed": packed,
            }
            parity_check_gqa_past(config, **options)
            parity_check_gqa_past_no_buff(config, **options)
# Run the unittest test runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()