Cherry pick LLaMA to rel-1.16.2 (round 2) (#18245)

2nd round of cherry pick LLaMA related changes to 1.16.2 release.
---------

Co-authored-by: aciddelgado <139922440+aciddelgado@users.noreply.github.com>
Co-authored-by: Frank Dong <123416088+frank-dong-ms@users.noreply.github.com>
This commit is contained in:
Tianlei Wu 2023-11-02 17:16:35 -07:00 committed by GitHub
parent 2f57f1e4d7
commit 70b8cda979
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 2088 additions and 550 deletions

View file

@ -113,6 +113,11 @@ set(contrib_ops_excluded_files
"cuda_contrib_kernels.h"
"inverse.cc"
"fused_conv.cc"
"bert/group_query_attention_helper.h"
"bert/group_query_attention.h"
"bert/group_query_attention.cc"
"bert/group_query_attention_impl.h"
"bert/group_query_attention_impl.cu"
)
if (NOT onnxruntime_ENABLE_ATEN)

View file

@ -2265,14 +2265,14 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>When buffered past_key and past_value are used (present_key uses same tensor as past_key), it is required to specify past_sequence_length (could be 0). Otherwise, past_sequence_length is inferred from past_key.</dd>
</dl>
#### Outputs (1 - 3)
#### Outputs
<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>present_key</tt> (optional) : T</dt>
<dt><tt>present_key</tt> : T</dt>
<dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key (k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
<dt><tt>present_value</tt> : T</dt>
<dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value (k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.</dd>
</dl>

View file

@ -374,6 +374,7 @@ Status EfficientAttention(
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.total_sequence_length;
p.max_sequence_length = parameters.total_sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
@ -395,6 +396,7 @@ Status EfficientAttention(
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
p.output = data.output;
p.is_kv_bsnh = true;
p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
? data.scratch
: nullptr;

View file

@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.num_keys = params.kv_sequence_length;
if (params.causal) {
p.custom_mask_type = Attention::CausalFromTopLeft;
p.custom_mask_type = Attention::CausalFromBottomRight;
}
// Input format is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.qk_head_size;
p.v_strideH = params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
// We use max_sequence_length to calculate KV stride
if (params.is_kv_bsnh) {
// Input Q, K, V format is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.qk_head_size;
p.v_strideH = params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.num_heads * params.qk_head_size;
p.v_strideM = params.num_heads * params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.num_heads * params.qk_head_size;
p.v_strideM = params.num_heads * params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.max_sequence_length;
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.max_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
} else {
// Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.max_sequence_length * params.qk_head_size;
p.v_strideH = params.max_sequence_length * params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.qk_head_size;
p.v_strideM = params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length;
p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length;
p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
}
}
constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;

View file

@ -14,10 +14,12 @@ namespace cuda {
struct MemoryEfficientAttentionParams {
int32_t sm;
bool is_half;
bool is_kv_bsnh = true;
int32_t batch_size;
int32_t num_heads;
int32_t sequence_length;
int32_t kv_sequence_length;
int32_t max_sequence_length;
int32_t qk_head_size;
int32_t v_head_size;
bool causal;

View file

@ -6,9 +6,8 @@
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention.h"
#include "contrib_ops/cuda/bert/group_query_attention_helper.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
// #include "contrib_ops/cpu/utils/console_dumper.h"
using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
@ -55,6 +54,13 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
#else
disable_flash_attention_ = true;
#endif
#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
}
template <typename T>
@ -92,18 +98,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
Tensor* output = context->Output(0, output_shape);
std::vector<int64_t> present_dims;
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
present_dims = {
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
} else { // BNSH
present_dims = {
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
}
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);
#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
onnxruntime::flash::is_supported(device_prop,
@ -143,8 +137,47 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto seqlens_k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif
// only kernel implemented for gqa right now
ORT_ENFORCE(use_flash_attention);
#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
// allocate buffers
size_t kv_buffer_bytes = 0;
// need a buffer if we must ungroup kv
const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads);
if (use_memory_efficient_attention && needs_buff) {
kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size);
}
size_t fmha_buffer_bytes = 0;
if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
}
auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
#else
constexpr bool use_memory_efficient_attention = false;
auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
#endif
std::vector<int64_t> present_dims;
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
present_dims = {
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
} else { // BNSH
present_dims = {
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
}
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
@ -155,6 +188,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_buffer != nullptr) {
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}
@ -167,6 +201,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (seqlens_k_buffer != nullptr) {
data.seqlens_k = reinterpret_cast<int*>(seqlens_k_buffer.get());
}
if (k_buffer != nullptr) {
data.k = reinterpret_cast<CudaT*>(k_buffer.get());
data.v = reinterpret_cast<CudaT*>(v_buffer.get());
}
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}
cublasHandle_t cublas = GetCublasHandle(context);

View file

@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel {
bool is_past_bsnh_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
};
} // namespace cuda

View file

@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query,
// query (Q) : (B, S, D)
// key (K) : (B, S+, D_kv)
// value (V) : (B, S+, D_kv)
ORT_UNUSED_PARAMETER(value);
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = Q_K_V_BSNH;
const auto& query_dims = query->Shape().GetDims();
const auto& key_dims = key->Shape().GetDims();
const auto& value_dims = value->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query,
int q_hidden_size = static_cast<int>(query_dims[2]);
int head_size = static_cast<int>(q_hidden_size) / num_heads;
int kv_sequence_length = sequence_length;
int kv_hidden_size = (key_dims.size() == 3)
? static_cast<int>(key_dims[2])
: (kv_num_heads * static_cast<int>(key_dims[3]));
int kv_sequence_length = static_cast<int>(key_dims[1]);
int kv_hidden_size = static_cast<int>(key_dims[2]);
int max_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query,
"Input 'past_key' and 'past_value' shall be both present or both absent");
}
if (key != nullptr) {
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}
if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}
if (key_dims[2] != value_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)");
}
qkv_format = Q_K_V_BSNH;
kv_sequence_length = static_cast<int>(key_dims[1]);
} else {
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Missing key tensor.");
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}
if (value != nullptr) {
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
}
if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
}
if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}
} else {
if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Missing value tensor.");
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
}
if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
}
if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}
// When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly.
int32_t past_sequence_length = 0;
int present_sequence_length = 0;
int present_sequence_length = kv_sequence_length;
if (past_seq_len != nullptr) {
if (past_key == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Past KV must be present as share-buffer when using past_seq_len pointer.");
}
if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"past_sequence_length tensor must be of one element when using past kv.");
@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query,
} else {
past_sequence_length = static_cast<int32_t>(*((*past_seq_len).template Data<int64_t>()));
}
if (past_sequence_length + kv_sequence_length > max_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length");
}
present_sequence_length = max_sequence_length;
} else if (past_key != nullptr) {
past_sequence_length = max_sequence_length; // this is the length of past_key tensor

View file

@ -37,6 +37,7 @@ limitations under the License.
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
@ -47,6 +48,8 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
////////// Auxiliary Kernels for KV prep
// Kernel for seqlens_k
__global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) {
int id = blockDim.x * blockIdx.x + threadIdx.x;
@ -75,7 +78,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen,
const int present_head_stride = is_bsnh ? H : present_seqlen * H;
// past_kv: BPNH or BNPH
// new_kv: BLNH or BNLH
// new_kv: BLNH
// present_kv: BTNH or BNTH, where T = P + L
const int past_seqlen = present_seqlen - new_seqlen;
@ -95,33 +98,32 @@ __global__ void ConcatNewToPastKV(const int new_seqlen,
}
}
// Use when (H*)*num_heads > 1024
template <typename T>
__global__ void ConcatNewToPastKVLarge(const int new_seqlen,
const int H,
const int num_heads,
const T* past_kv,
const T* new_kv,
T* present_kv,
const bool is_bsnh) {
// Use when (H*)*num_heads > 1024
int h = threadIdx.x;
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
int i = threadIdx.x + (blockDim.x * blockIdx.x);
if (i < H * num_heads) {
const int h = i % H;
const int n = i / H;
const int s = blockIdx.y;
const int b = blockIdx.z;
const int present_seqlen = gridDim.y;
const int present_seqlen = gridDim.x;
const int num_heads = blockDim.y;
const int thread_stride = blockDim.x;
const int present_batch_stride = present_seqlen * num_heads * H;
const int row_stride = is_bsnh ? num_heads * H : H;
const int present_head_stride = is_bsnh ? H : present_seqlen * H;
const int present_batch_stride = present_seqlen * num_heads * H;
const int row_stride = is_bsnh ? num_heads * H : H;
const int present_head_stride = is_bsnh ? H : present_seqlen * H;
// past_kv: BPNH or BNPH
// new_kv: BLNH
// present_kv: BTNH or BNTH, where T = P + L
const int past_seqlen = present_seqlen - new_seqlen;
// past_kv: BPNH or BNPH
// new_kv: BLNH or BNLH
// present_kv: BTNH or BNTH, where T = P + L
const int past_seqlen = present_seqlen - new_seqlen;
while (h < H) {
int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
if (s < past_seqlen) {
const int past_batch_stride = past_seqlen * num_heads * H;
@ -135,21 +137,296 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen,
const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
present_kv[out_offset] = new_kv[in_offset];
}
h += thread_stride;
}
}
// Concat new to past in present. Supports past BSNH or past BNSH
template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
Stream* ort_stream,
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<T>& data) {
assert(data.use_flash_attention);
// Launches the kernels that build present_key / present_value from the past
// tensors followed by the new key / value tensors.
//
// Geometry (batch, sequence lengths, kv head count, head size) and the past KV
// layout are read from `parameters`; the tensor pointers come from `data`.
// All pointers are reinterpreted as float2 and head_size is divided by 4 so
// each thread copies one float2 chunk (4 float16 elements at a time).
//
// When H * kv_num_heads fits in one thread block the 2D-block kernel is used
// (one block per (token, batch)); otherwise the "Large" kernel is launched
// with 256-thread blocks tiling the H * kv_num_heads axis.
//
// Returns the status of cudaGetLastError() after the launches were issued.
Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters,
                               GroupQueryAttentionData<T>& data,
                               cudaStream_t stream,
                               const int max_threads_per_block) {
  AttentionQkvFormat past_kv_format = parameters.past_kv_format;
  assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
  const bool past_is_bsnh = (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);

  const int batch_size = parameters.batch_size;
  const int kv_sequence_length = parameters.kv_sequence_length;
  const int present_sequence_length = parameters.present_sequence_length;
  const int kv_num_heads = parameters.kv_num_heads;
  const int head_size = parameters.head_size;

  // divide by 4 so kernel can operate on 4 float16 elements at a time.
  const int H = head_size / 4;

  // Vectorized views of the six tensors involved.
  const float2* past_k = reinterpret_cast<const float2*>(data.past_key);
  const float2* past_v = reinterpret_cast<const float2*>(data.past_value);
  const float2* new_k = reinterpret_cast<const float2*>(data.key);
  const float2* new_v = reinterpret_cast<const float2*>(data.value);
  float2* present_k = reinterpret_cast<float2*>(data.present_key);
  float2* present_v = reinterpret_cast<float2*>(data.present_value);

  if (H * kv_num_heads > max_threads_per_block) {
    // Too wide for a single block: tile the flattened (head, chunk) axis.
    int steps = (H * kv_num_heads + 255) / 256;
    const dim3 grid(steps, present_sequence_length, batch_size);
    const dim3 block(256, 1, 1);
    ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(
        kv_sequence_length, H, kv_num_heads, past_k, new_k, present_k, past_is_bsnh);
    ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(
        kv_sequence_length, H, kv_num_heads, past_v, new_v, present_v, past_is_bsnh);
  } else {
    const dim3 grid(present_sequence_length, batch_size, 1);
    const dim3 block(H, kv_num_heads, 1);
    ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(
        kv_sequence_length, past_k, new_k, present_k, past_is_bsnh);
    ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(
        kv_sequence_length, past_v, new_v, present_v, past_is_bsnh);
  }
  return CUDA_CALL(cudaGetLastError());
}
// Copies the new key/value tokens into an existing KV buffer, starting at
// token position `past_seqlen`.
//
// Expected launch layout (read from the built-in variables below):
//   grid  = (new_seqlen, batch_size)
//   block = (H, num_heads)
// Each thread copies exactly one element of type T.
//
//   kv_buff : destination, BTNH when is_bsnh is true, BNTH otherwise,
//             with T = present_seqlen.
//   new_kv  : source, always BLNH with L = new_seqlen.
template <typename T>
__global__ void ConcatKVInPlace(const int past_seqlen,
                                const int present_seqlen,
                                T* kv_buff,
                                const T* new_kv,
                                const bool is_bsnh) {
  // The launch configuration doubles as the tensor geometry.
  const int head_width = blockDim.x;
  const int head_count = blockDim.y;
  const int appended_len = gridDim.x;

  const int elem = threadIdx.x;
  const int head = threadIdx.y;
  const int token = blockIdx.x;
  const int batch = blockIdx.y;

  // Source offset: the new KV is always BSNH.
  const int src_batch_stride = appended_len * head_count * head_width;
  const int src_row_stride = head_count * head_width;
  const int src_head_stride = head_width;
  const int src = batch * src_batch_stride + token * src_row_stride + head * src_head_stride + elem;

  // Destination offset: row/head strides depend on the buffer layout.
  int dst_row_stride;
  int dst_head_stride;
  if (is_bsnh) {
    dst_row_stride = head_count * head_width;
    dst_head_stride = head_width;
  } else {
    dst_row_stride = head_width;
    dst_head_stride = present_seqlen * head_width;
  }
  const int dst_batch_stride = present_seqlen * head_count * head_width;
  const int dst = batch * dst_batch_stride + (token + past_seqlen) * dst_row_stride + head * dst_head_stride + elem;

  kv_buff[dst] = new_kv[src];
}
// Variant of the in-place KV append for when H * num_heads does not fit in a
// single thread block.
//
// Expected launch layout (read from the built-in variables below):
//   grid  = (ceil(H * num_heads / blockDim.x), new_seqlen, batch_size)
//   block = 1D; threads past H * num_heads exit without touching memory.
// Each active thread copies exactly one element of type T.
//
//   kv_buff : destination, BTNH when is_bsnh is true, BNTH otherwise,
//             with T = present_seqlen.
//   new_kv  : source, always BLNH with L = new_seqlen.
template <typename T>
__global__ void ConcatKVInPlaceLarge(const int past_seqlen,
                                     const int present_seqlen,
                                     const int H,
                                     const int num_heads,
                                     T* kv_buff,
                                     const T* new_kv,
                                     const bool is_bsnh) {
  // Flattened index over the (head, element) plane.
  const int flat = threadIdx.x + (blockDim.x * blockIdx.x);
  if (flat >= H * num_heads) {
    return;
  }

  const int elem = flat % H;
  const int head = flat / H;
  const int token = blockIdx.y;
  const int batch = blockIdx.z;
  const int appended_len = gridDim.y;

  // Source offset: the new KV is always BSNH.
  const int src_batch_stride = appended_len * num_heads * H;
  const int src_row_stride = num_heads * H;
  const int src_head_stride = H;
  const int src = batch * src_batch_stride + token * src_row_stride + head * src_head_stride + elem;

  // Destination offset: row/head strides depend on the buffer layout.
  int dst_row_stride;
  int dst_head_stride;
  if (is_bsnh) {
    dst_row_stride = num_heads * H;
    dst_head_stride = H;
  } else {
    dst_row_stride = H;
    dst_head_stride = present_seqlen * H;
  }
  const int dst_batch_stride = present_seqlen * num_heads * H;
  const int dst = batch * dst_batch_stride + (token + past_seqlen) * dst_row_stride + head * dst_head_stride + elem;

  kv_buff[dst] = new_kv[src];
}
// Appends the new key/value tokens to the present key/value buffers in place.
//
// Geometry (batch, sequence lengths, kv head count, head size) and the KV
// buffer layout are read from `parameters`. data.key is copied into
// data.present_key and data.value into data.present_value, both starting at
// token position parameters.past_sequence_length. Kernels are enqueued on
// `stream`.
//
// All pointers are reinterpreted as float2 and head_size is divided by 4, so
// each thread moves one float2 chunk.
// NOTE(review): this presumes T is a 16-bit type and head_size is a multiple
// of 4 - confirm against the callers.
//
// When H * kv_num_heads fits in one thread block (max_threads_per_block) the
// 2D-block kernel is used; otherwise the "Large" kernel is launched with
// 256-thread blocks tiling the H * kv_num_heads axis.
//
// Returns the status of cudaGetLastError() after the launches were issued.
template <typename T>
Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
                             GroupQueryAttentionData<T>& data,
                             cudaStream_t stream,
                             const int max_threads_per_block) {
  const int batch_size = parameters.batch_size;
  const int kv_sequence_length = parameters.kv_sequence_length;
  const int present_sequence_length = parameters.present_sequence_length;
  const int past_sequence_length = parameters.past_sequence_length;
  const int kv_num_heads = parameters.kv_num_heads;
  const int head_size = parameters.head_size;
  AttentionQkvFormat past_kv_format = parameters.past_kv_format;
  assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
  const bool is_bsnh = (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);

  const int H = head_size / 4;  // number of float2 chunks per head
  if (H * kv_num_heads <= max_threads_per_block) {
    const dim3 grid(kv_sequence_length, batch_size, 1);
    const dim3 block(H, kv_num_heads, 1);
    ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(past_sequence_length,
                                                        present_sequence_length,
                                                        reinterpret_cast<float2*>(data.present_key),
                                                        reinterpret_cast<const float2*>(data.key),
                                                        is_bsnh);
    ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(past_sequence_length,
                                                        present_sequence_length,
                                                        reinterpret_cast<float2*>(data.present_value),
                                                        reinterpret_cast<const float2*>(data.value),
                                                        is_bsnh);
  } else {
    // Integer ceil-div: exact, and avoids a float/double round trip on the host.
    const int steps = (H * kv_num_heads + 255) / 256;
    const dim3 grid(steps, kv_sequence_length, batch_size);
    const dim3 block(256, 1, 1);
    ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(past_sequence_length,
                                                             present_sequence_length,
                                                             H,
                                                             kv_num_heads,
                                                             reinterpret_cast<float2*>(data.present_key),
                                                             reinterpret_cast<const float2*>(data.key),
                                                             is_bsnh);
    ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(past_sequence_length,
                                                             present_sequence_length,
                                                             H,
                                                             kv_num_heads,
                                                             reinterpret_cast<float2*>(data.present_value),
                                                             reinterpret_cast<const float2*>(data.value),
                                                             is_bsnh);
  }
  return CUDA_CALL(cudaGetLastError());
}
// Expands grouped KV heads so that every query head gets its own copy of the
// KV head it maps to (used to feed the memory efficient attention kernel).
//
// Expected launch layout (read from the built-in variables below):
//   grid  = (out_seqlen, batch_size)
//   block = (H, q_num_heads)
// Each thread copies exactly one element of type T.
//
//   kv_in  : grouped input with kv_num_heads heads and sequence length in_seqlen.
//   kv_out : ungrouped output with q_num_heads heads and sequence length out_seqlen.
// The same `is_bsnh` flag selects the row/head strides of BOTH tensors
// (BSNH when true, BNSH otherwise).
// Output head n reads input head n / (q_num_heads / kv_num_heads).
template <typename T>
__global__ void Ungroup(const T* kv_in,
                        T* kv_out,
                        const int in_seqlen,
                        const int kv_num_heads,
                        const bool is_bsnh) {
  // The launch configuration doubles as the output geometry.
  const int H = blockDim.x;
  const int q_num_heads = blockDim.y;
  const int out_seqlen = gridDim.x;

  const int elem = threadIdx.x;
  const int dst_head = threadIdx.y;
  const int token = blockIdx.x;
  const int batch = blockIdx.y;

  // Several consecutive query heads share one KV head.
  const int heads_per_group = q_num_heads / kv_num_heads;
  const int src_head = dst_head / heads_per_group;

  int dst_row_stride;
  int dst_head_stride;
  int src_row_stride;
  int src_head_stride;
  if (is_bsnh) {
    dst_row_stride = q_num_heads * H;
    dst_head_stride = H;
    src_row_stride = kv_num_heads * H;
    src_head_stride = H;
  } else {
    dst_row_stride = H;
    dst_head_stride = out_seqlen * H;
    src_row_stride = H;
    src_head_stride = in_seqlen * H;
  }
  const int dst_batch_stride = out_seqlen * q_num_heads * H;
  const int src_batch_stride = in_seqlen * kv_num_heads * H;

  const int dst = dst_batch_stride * batch + dst_row_stride * token + dst_head_stride * dst_head + elem;
  const int src = src_batch_stride * batch + src_row_stride * token + src_head_stride * src_head + elem;
  kv_out[dst] = kv_in[src];
}
// Variant of the KV ungroup kernel for when H * q_num_heads does not fit in a
// single thread block.
//
// Expected launch layout (read from the built-in variables below):
//   grid  = (ceil(H * q_num_heads / blockDim.x), out_seqlen, batch_size)
//   block = 1D; threads past H * q_num_heads exit without touching memory.
// Each active thread copies exactly one element of type T.
//
//   kv_in  : grouped input with kv_num_heads heads and sequence length in_seqlen.
//   kv_out : ungrouped output with q_num_heads heads and sequence length out_seqlen.
// The same `is_bsnh` flag selects the row/head strides of BOTH tensors
// (BSNH when true, BNSH otherwise).
// Output head n reads input head n / (q_num_heads / kv_num_heads).
template <typename T>
__global__ void UngroupLarge(const T* kv_in,
                             T* kv_out,
                             const int H,
                             const int in_seqlen,
                             const int q_num_heads,
                             const int kv_num_heads,
                             const bool is_bsnh) {
  // Flattened index along the H * q_num_heads elements of one output token.
  const int flat = threadIdx.x + (blockDim.x * blockIdx.x);
  if (flat >= H * q_num_heads) {
    return;
  }

  const int out_seqlen = gridDim.y;
  const int token = blockIdx.y;
  const int batch = blockIdx.z;

  const int elem = flat % H;
  const int dst_head = flat / H;
  // Several consecutive query heads share one KV head.
  const int heads_per_group = q_num_heads / kv_num_heads;
  const int src_head = dst_head / heads_per_group;

  int dst_row_stride;
  int dst_head_stride;
  int src_row_stride;
  int src_head_stride;
  if (is_bsnh) {
    dst_row_stride = q_num_heads * H;
    dst_head_stride = H;
    src_row_stride = kv_num_heads * H;
    src_head_stride = H;
  } else {
    dst_row_stride = H;
    dst_head_stride = out_seqlen * H;
    src_row_stride = H;
    src_head_stride = in_seqlen * H;
  }
  const int dst_batch_stride = out_seqlen * q_num_heads * H;
  const int src_batch_stride = in_seqlen * kv_num_heads * H;

  const int dst = dst_batch_stride * batch + dst_row_stride * token + dst_head_stride * dst_head + elem;
  const int src = src_batch_stride * batch + src_row_stride * token + src_head_stride * src_head + elem;
  kv_out[dst] = kv_in[src];
}
// Ungroups K and V for use in the memory efficient attention kernel: expands
// tensors holding kv_num_heads heads into buffers holding num_heads heads.
//
//   parameters            : supplies batch_size, num_heads, kv_num_heads, head_size.
//   k_buff / v_buff       : destination buffers, sequence length buff_seqlen.
//   k_og / v_og           : grouped source tensors, sequence length og_seqlen.
//   is_bsnh               : forwarded to the kernels, where it selects the
//                           row/head strides of both source and destination.
//   stream                : CUDA stream the kernels are enqueued on.
//   max_threads_per_block : threshold for choosing between the two kernels.
//
// Tensors are addressed as float2, with H = head_size / 4 chunks per head.
// NOTE(review): this presumes the underlying element type is 16-bit and
// head_size is a multiple of 4 - confirm against the callers.
//
// When H * num_heads fits in one thread block the 2D-block kernel is used;
// otherwise the "Large" kernel is launched with 256-thread blocks tiling the
// H * num_heads axis.
//
// Returns the status of cudaGetLastError() after the launches were issued.
Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters,
                     float2* k_buff, float2* v_buff,
                     const float2* k_og, const float2* v_og,
                     const int buff_seqlen, const int og_seqlen,
                     const bool is_bsnh,
                     cudaStream_t stream,
                     const int max_threads_per_block) {
  const int batch_size = parameters.batch_size;
  const int num_heads = parameters.num_heads;
  const int kv_num_heads = parameters.kv_num_heads;
  const int head_size = parameters.head_size;

  const int H = head_size / 4;  // number of float2 chunks per head
  if (H * num_heads <= max_threads_per_block) {
    const dim3 grid(buff_seqlen, batch_size, 1);
    const dim3 block(H, num_heads, 1);
    Ungroup<float2><<<grid, block, 0, stream>>>(k_og,
                                                k_buff,
                                                og_seqlen,
                                                kv_num_heads,
                                                is_bsnh);
    Ungroup<float2><<<grid, block, 0, stream>>>(v_og,
                                                v_buff,
                                                og_seqlen,
                                                kv_num_heads,
                                                is_bsnh);
  } else {
    // Integer ceil-div: exact, and avoids a float/double round trip on the host.
    const int steps = (H * num_heads + 255) / 256;
    const dim3 grid(steps, buff_seqlen, batch_size);
    const dim3 block(256, 1, 1);
    UngroupLarge<float2><<<grid, block, 0, stream>>>(k_og,
                                                     k_buff,
                                                     H,
                                                     og_seqlen,
                                                     num_heads,
                                                     kv_num_heads,
                                                     is_bsnh);
    UngroupLarge<float2><<<grid, block, 0, stream>>>(v_og,
                                                     v_buff,
                                                     H,
                                                     og_seqlen,
                                                     num_heads,
                                                     kv_num_heads,
                                                     is_bsnh);
  }
  return CUDA_CALL(cudaGetLastError());
}
////////// Launch Kernels
#if USE_FLASH_ATTENTION
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
template <typename T>
Status FlashAttention(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<T>& data,
float scale) {
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
@ -160,108 +437,177 @@ Status QkvToContext(
const int head_size = parameters.head_size;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(head_size)) : parameters.scale;
if (data.use_flash_attention) {
assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(parameters.num_heads % parameters.kv_num_heads == 0);
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
bool is_causal = parameters.is_unidirectional;
bool is_causal = parameters.is_unidirectional;
if (data.past_key != nullptr && data.past_key == data.present_key) {
// Share buffer case
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
if (data.past_key == nullptr && data.present_key == nullptr) {
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
device_prop, stream, query, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, head_size,
parameters.sequence_length, parameters.kv_sequence_length, scale, is_causal, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum)));
// Launch kernel to copy seqlen
int thr_per_blk = 256;
int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
} else if (data.past_key == data.present_key) {
// Assume past and present kv share buffer.
assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
assert(parameters.past_sequence_length >= 0);
assert(data.past_value != nullptr);
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
reinterpret_cast<void*>(data.seqlens_k), batch_size, num_heads, kv_num_heads,
head_size, sequence_length, present_sequence_length, kv_sequence_length,
scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum)));
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
} else {
// Not share buffer or no past (prompt generation)
// Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inefficient
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
// Launch kernel to copy seqlen
int thr_per_blk = 256;
int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
DUMP_TENSOR_INIT();
DUMP_TENSOR("seqlens_k", data.seqlens_k, 1, batch_size);
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast<void*>(data.softmax_lse),
batch_size, num_heads, kv_num_heads, head_size,
sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum), past_bsnh));
}
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
reinterpret_cast<void*>(data.seqlens_k), batch_size, num_heads, kv_num_heads,
head_size, sequence_length, present_sequence_length, kv_sequence_length,
scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum)));
DUMP_TENSOR_INIT();
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
} else if (data.present_key != nullptr && (data.past_key != nullptr || kv_sequence_length == present_sequence_length)) {
assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
// Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inefficient
if (head_size % 4 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "requires head_size be divisible by 4");
}
const int H = head_size / 4;
if (H * kv_num_heads <= max_threads_per_block) {
const dim3 grid(present_sequence_length, batch_size, 1);
const dim3 block(H, kv_num_heads, 1);
ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
reinterpret_cast<const float2*>(data.past_key),
reinterpret_cast<const float2*>(data.key),
reinterpret_cast<float2*>(data.present_key),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
reinterpret_cast<const float2*>(data.past_value),
reinterpret_cast<const float2*>(data.value),
reinterpret_cast<float2*>(data.present_value),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
} else {
const dim3 grid(present_sequence_length, batch_size, 1);
const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1);
ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
H,
reinterpret_cast<const float2*>(data.past_key),
reinterpret_cast<const float2*>(data.key),
reinterpret_cast<float2*>(data.present_key),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
H,
reinterpret_cast<const float2*>(data.past_value),
reinterpret_cast<const float2*>(data.value),
reinterpret_cast<float2*>(data.present_value),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
}
return Status::OK();
}
#endif
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
#if USE_MEMORY_EFFICIENT_ATTENTION
template <typename T>
Status EfficientAttention(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<T>& data,
float scale) {
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int past_sequence_length = parameters.past_sequence_length;
const int present_sequence_length = parameters.present_sequence_length;
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
// Launch kernel to copy seqlen
int thr_per_blk = 256;
int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast<void*>(data.softmax_lse),
batch_size, num_heads, kv_num_heads, head_size,
sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum), past_bsnh));
const void* query = reinterpret_cast<const void*>(data.query);
const void* key = reinterpret_cast<const void*>(data.key);
const void* value = reinterpret_cast<const void*>(data.value);
if (data.past_key != nullptr) {
// Past key case
// concatenate new kv to past kv
if (data.past_key == data.present_key) {
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block));
} else {
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
}
const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
if (num_heads == kv_num_heads) {
// Use present kv directly if not grouped
key = reinterpret_cast<const void*>(data.present_key);
value = reinterpret_cast<const void*>(data.present_value);
} else {
// Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path
float2* k_buff = reinterpret_cast<float2*>(data.k);
float2* v_buff = reinterpret_cast<float2*>(data.v);
const float2* k_og = reinterpret_cast<const float2*>(data.present_key);
const float2* v_og = reinterpret_cast<const float2*>(data.present_value);
ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, past_sequence_length + kv_sequence_length,
present_sequence_length, is_bsnh, stream, max_threads_per_block));
key = reinterpret_cast<const void*>(data.k);
value = reinterpret_cast<const void*>(data.v);
}
} else if (num_heads == kv_num_heads) {
// no past or present and no need to ungroup... still copy kv into present buffer
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
key = reinterpret_cast<const void*>(data.present_key);
value = reinterpret_cast<const void*>(data.present_value);
} else {
// intermediate buffer so q and kv have same num heads... still copy kv into present buffer
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
float2* k_buff = reinterpret_cast<float2*>(data.k);
float2* v_buff = reinterpret_cast<float2*>(data.v);
const float2* k_og = reinterpret_cast<const float2*>(data.present_key);
const float2* v_og = reinterpret_cast<const float2*>(data.present_value);
ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, kv_sequence_length,
kv_sequence_length, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH, stream,
max_threads_per_block));
key = reinterpret_cast<const void*>(data.k);
value = reinterpret_cast<const void*>(data.v);
}
DUMP_TENSOR_INIT();
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
p.is_half = sizeof(T) == 2;
p.batch_size = batch_size;
p.num_heads = num_heads;
p.sequence_length = sequence_length;
p.kv_sequence_length = past_sequence_length + kv_sequence_length;
p.max_sequence_length = (num_heads == kv_num_heads) ? present_sequence_length : past_sequence_length + kv_sequence_length;
p.qk_head_size = head_size;
p.v_head_size = head_size;
p.causal = parameters.is_unidirectional;
p.scale = scale;
p.seqlen_k_ptr = nullptr;
p.seqstart_q_ptr = nullptr;
p.seqstart_k_ptr = nullptr;
p.query = query;
p.key = key;
p.value = value;
p.attn_bias = nullptr;
p.is_attn_bias_batched = false;
p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
p.output = data.output;
p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float))
? data.fmha_buffer
: nullptr;
p.stream = stream;
run_memory_efficient_attention(p);
return Status::OK();
DUMP_TENSOR_INIT();
DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
return Status::OK();
}
#endif
////////// API Functions
// Entry point for GroupQueryAttention: forwards to one of the fused attention
// implementations according to the flags carried in `data`.
//
// The flash path is checked first, then the memory-efficient path; each is only
// compiled in when its USE_* macro is enabled. If neither flag selects a
// compiled-in kernel, an INVALID_ARGUMENT status is returned because no unfused
// fallback exists.
//
// device_prop  device properties forwarded to the selected implementation
// cublas       not used by this function
// ort_stream   stream wrapper whose handle is cast to cudaStream_t
// parameters   attention parameters; `scale` and `head_size` are read here
// data         buffers plus the use_flash_attention / use_memory_efficient_attention flags
template <typename T>
Status QkvToContext(
    const cudaDeviceProp& device_prop,
    cublasHandle_t& cublas,
    Stream* ort_stream,
    contrib::GroupQueryAttentionParameters& parameters,
    GroupQueryAttentionData<T>& data) {
  cudaStream_t cuda_stream = static_cast<cudaStream_t>(ort_stream->GetHandle());

  // A scale of exactly zero is treated as "unset" and replaced by 1/sqrt(head_size).
  float softmax_scale = parameters.scale;
  if (softmax_scale == 0.0f) {
    softmax_scale = 1.f / sqrt(static_cast<float>(parameters.head_size));
  }

#if USE_FLASH_ATTENTION
  if (data.use_flash_attention) {
    return FlashAttention(device_prop, cuda_stream, parameters, data, softmax_scale);
  }
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
  if (data.use_memory_efficient_attention) {
    return EfficientAttention(device_prop, cuda_stream, parameters, data, softmax_scale);
  }
#endif

  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet.");
}

View file

@ -14,19 +14,28 @@ namespace cuda {
// Bundle of raw buffer pointers and kernel-selection flags consumed by QkvToContext
// and the FlashAttention / EfficientAttention helpers it dispatches to.
// Every pointer defaults to nullptr and both flags default to false, so a
// default-constructed instance selects no kernel.
// NOTE(review): ownership and memory space of the buffers are not visible here;
// they are handed to CUDA kernel launches, so presumably they are device buffers
// owned by the caller - confirm.
template <typename T>
struct GroupQueryAttentionData {
// Input Tensors
// New query/key/value for this call; read (never written) by both attention paths.
const T* query = nullptr;
const T* key = nullptr;
const T* value = nullptr;
// Past KV state. The attention helpers test past_key against nullptr ("no past")
// and against present_key (past and present share one buffer).
const T* past_key = nullptr;
const T* past_value = nullptr;
// Flash buffers
// Passed through to onnxruntime::flash::mha_fwd / mha_fwd_kvcache.
T* softmax_lse = nullptr;
T* softmax_lse_accum = nullptr;
T* out_accum = nullptr;
// Handed to the repeat_seqlen kernel launch together with past_sequence_length and
// batch_size, then passed to mha_fwd_kvcache.
// NOTE(review): presumably one entry per batch item - confirm against repeat_seqlen.
int* seqlens_k = nullptr;
// Memory Efficient buffers
// Used as MemoryEfficientAttentionParams::workspace when need_workspace(...) is true.
T* fmha_buffer = nullptr;
// Intermediate key/value buffers: LaunchUngroup writes into them when
// num_heads != kv_num_heads, and they then serve as the key/value inputs
// of the memory-efficient kernel.
T* k = nullptr;
T* v = nullptr;
// Output Tensors
// Attention result; passed as the output pointer to both kernel families.
T* output = nullptr;
// Present KV state. Used as the KV cache by the flash path and as the key/value
// source by the memory-efficient path after the concat helpers have run.
T* present_key = nullptr;
T* present_value = nullptr;
// Kernel Flags
// Checked by QkvToContext in this order: flash first, then memory efficient.
bool use_flash_attention = false;
bool use_memory_efficient_attention = false;
};
template <typename T>

View file

@ -507,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass(
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
p.is_half = sizeof(T) == 2;
p.is_kv_bsnh = true;
p.batch_size = parameters.batch_size;
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.sequence_length;
p.max_sequence_length = parameters.sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;

View file

@ -688,6 +688,7 @@ Status FusedAttentionCutlass(
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.sequence_length;
p.max_sequence_length = parameters.sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;
@ -702,6 +703,7 @@ Status FusedAttentionCutlass(
p.attn_bias = data.relative_position_bias;
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
p.output = data.output;
p.is_kv_bsnh = true;
p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float))
? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v)))
: nullptr;

View file

@ -1041,15 +1041,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key"
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
"T",
OpSchema::Optional)
"T")
.Output(2,
"present_value",
"present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value"
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
"T",
OpSchema::Optional)
"T")
.TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {

View file

@ -147,6 +147,7 @@ class SymbolicShapeInference:
"GatherElements": self._infer_GatherElements,
"GatherND": self._infer_GatherND,
"Identity": self._pass_on_shape_and_type,
"AllReduce": self._pass_on_shape_and_type,
"If": self._infer_If,
"Loop": self._infer_Loop,
"MatMul": self._infer_MatMul,

View file

@ -1272,7 +1272,7 @@ def find_past_seq_len_usage(subg: GraphProto):
return tensor_names_to_rename, nodes_to_remove
def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0):
def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1):
past_seq_len = past_seq_len_input
if past_seq_len not in model.get_graphs_input_names():
# Add model input for past sequence length
@ -1282,6 +1282,10 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads
# Replace MultiHeadAttention with GroupQueryAttention
for node in model.model.graph.node:
if node.op_type == "MultiHeadAttention":
num_heads_mha = 0
for att in node.attribute:
if att.name == "num_heads":
num_heads_mha = att.i
gqa_node = onnx.helper.make_node(
"GroupQueryAttention",
inputs=[
@ -1295,8 +1299,8 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads
outputs=node.output,
name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
domain="com.microsoft",
num_heads=node.attribute[0].i,
kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads,
num_heads=num_heads_mha // world_size,
kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
is_past_bsnh=0,
)
model.model.graph.node.remove(node)

View file

@ -130,3 +130,8 @@ class Fusion:
for node in nodes:
if node not in self.nodes_to_remove:
self.nodes_to_remove.append(node)
def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]):
    """Queue ``nodes`` for removal, skipping duplicates and protected nodes.

    Args:
        nodes: Candidate nodes, visited in order.
        nodes_to_keep: Nodes that must never be added to ``self.nodes_to_remove``.

    A candidate is appended to ``self.nodes_to_remove`` only when it is not
    already in that list and is not in ``nodes_to_keep``. Nothing is returned;
    the list is updated in place.
    """
    for candidate in nodes:
        # Same short-circuit order as the membership tests it replaces:
        # the removal list is consulted first, then the keep list.
        if candidate in self.nodes_to_remove or candidate in nodes_to_keep:
            continue
        self.nodes_to_remove.append(candidate)

View file

@ -323,6 +323,7 @@ class FusionRotaryAttention(FusionAttention):
# qkv_nodes_1 is for LLaMA-2 Microsoft
# qkv_nodes_2 is for LLaMA-2 Hugging Face
# qkv_nodes_3 is for the distributed LLaMA-2 Hugging Face model
qkv_nodes = None
qkv_nodes_1 = self.model.match_parent_path(
normalize_node,
@ -334,18 +335,27 @@ class FusionRotaryAttention(FusionAttention):
["MatMul", "Reshape", "Transpose", "MatMul"],
[1, 0, 0, 0],
)
qkv_nodes_3 = self.model.match_parent_path(
normalize_node,
["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"],
[1, 0, 0, 0, 0],
)
if qkv_nodes_1 is not None:
_, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
qkv_nodes = qkv_nodes_1
elif qkv_nodes_2 is not None:
_, reshape_qkv, _, matmul_qkv = qkv_nodes_2
qkv_nodes = qkv_nodes_2
elif qkv_nodes_3 is not None:
_, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3
qkv_nodes = qkv_nodes_3
else:
logger.debug("fuse_rotary_attention: failed to match qkv nodes")
return
# v_nodes_1 is for LLaMA-2 Microsoft
# v_nodes_3 is for LLaMA-2 Hugging Face
# v_nodes_4 is for LLaMA-2 70B model
past_v, present_v, past_seq_len = "", "", ""
v_nodes = None
v_nodes_1 = self.model.match_parent_path(
@ -363,6 +373,118 @@ class FusionRotaryAttention(FusionAttention):
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
_, v_nodes_4, _ = self.model.match_parent_paths_all(
matmul_qkv,
[
(
["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Expand",
"Where",
"Equal",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Expand",
"Where",
"Equal",
"Mul",
"ConstantOfShape",
"Shape",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Expand",
"Where",
"ConstantOfShape",
"Shape",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Expand",
"Where",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0],
),
(
["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 1, 0, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Concat",
"Unsqueeze",
"Mul",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
),
(
["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 1, 2, 0, 0, 0, 1, 0, 0],
),
(
["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 1, 3, 0, 0, 0, 1, 0, 0],
),
],
output_name_to_node=None,
)
if v_nodes_1 is not None:
reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
v_nodes = v_nodes_1
@ -388,6 +510,11 @@ class FusionRotaryAttention(FusionAttention):
transpose_v, reshape_v, matmul_v = v_nodes_3
v_nodes = v_nodes_3
present_v = transpose_v.output[0]
elif v_nodes_4 is not None and len(v_nodes_4) == 9:
concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:]
v_nodes = v_nodes_4
past_v = concat_v.input[0]
present_v = concat_v.output[0]
else:
logger.debug("fuse_rotary_attention: failed to match v path")
return
@ -461,6 +588,7 @@ class FusionRotaryAttention(FusionAttention):
# k_nodes_1 is for LLaMA-2 Microsoft
# k_nodes_2 is for LLaMA-2 Hugging Face
# k_nodes_4 is for LLaMA-2 70B Hugging Face
past_k, present_k = "", ""
k_nodes = None
k_nodes_1 = self.model.match_parent_path(
@ -478,6 +606,174 @@ class FusionRotaryAttention(FusionAttention):
["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
[1, 0, 1, 0, 0, 0],
)
_, k_nodes_4, _ = self.model.match_parent_paths_all(
matmul_qk,
[
(
[
"Transpose",
"Reshape",
"Expand",
"Unsqueeze",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Expand",
"Where",
"Equal",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Expand",
"Where",
"Equal",
"Mul",
"ConstantOfShape",
"Shape",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Expand",
"Where",
"ConstantOfShape",
"Shape",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Expand",
"Where",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Concat",
"Unsqueeze",
"Mul",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0],
),
],
output_name_to_node=None,
)
if k_nodes_1 is not None:
reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
k_nodes = k_nodes_1
@ -505,6 +801,12 @@ class FusionRotaryAttention(FusionAttention):
k_nodes = k_nodes_3
past_k = concat_k.input[0]
present_k = concat_k.output[0]
elif k_nodes_4 is not None and len(k_nodes_4) == 9:
reshape_k, matmul_k = k_nodes_4[0][-2:]
concat_k, rotary_k = k_nodes_4[0][-5:-3]
k_nodes = k_nodes_4
past_k = concat_k.input[0]
present_k = concat_k.output[0]
else:
logger.debug("fuse_rotary_attention: failed to match k nodes")
return
@ -552,7 +854,7 @@ class FusionRotaryAttention(FusionAttention):
return
root_output = reshape_qkv_2.output[0]
elif qkv_nodes == qkv_nodes_2:
elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3):
if not self.check_runtime_shape_paths_for_nodes(
reshape_qkv,
reshape_q,
@ -573,6 +875,9 @@ class FusionRotaryAttention(FusionAttention):
# Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
rotary_k.output[0] = rotary_k.name + "_output_0"
if qkv_nodes == qkv_nodes_3:
qkv_nodes = qkv_nodes[1:]
new_node = self.create_mha_node(
matmul_q.input[0],
root_output,
@ -594,7 +899,14 @@ class FusionRotaryAttention(FusionAttention):
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend(qkv_nodes[1:])
self.nodes_to_remove.extend(v_nodes[:-1])
if v_nodes != v_nodes_4:
self.nodes_to_remove.extend(v_nodes[:-1])
else:
nodes_to_keep = [v_nodes[0][-1]]
for temp_path in v_nodes:
self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
self.nodes_to_remove.extend(qk_nodes)
if k_nodes == k_nodes_1:
@ -608,6 +920,10 @@ class FusionRotaryAttention(FusionAttention):
self.nodes_to_remove.append(k_nodes[1])
self.nodes_to_remove.append(k_nodes[3])
self.nodes_to_remove.append(k_nodes[4])
elif k_nodes == k_nodes_4:
nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
for temp_path in k_nodes:
self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
if q_nodes == q_nodes_1:
self.nodes_to_remove.extend(q_nodes[:-2])

View file

@ -10,6 +10,8 @@ Please note the package versions needed for using LLaMA-2 in the `requirements.t
- Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file.
- `requirements-quant.txt`
- For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor)
- `requirements-70b-model.txt`
- For running the LLaMA-2 70B model on multiple GPUs
- `requirements.txt`
- Package versions needed in each of the above files
@ -79,6 +81,15 @@ model.save_pretrained(name.split("/")[-1] + "-onnx")
Here are some additional examples for exporting LLaMA-2.
Export Model with Different GPU Device Ids
```
# From source using first GPU:
$ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b
# From wheel using second GPU:
$ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b
```
Export Saved Model on Disk
```
# From source:
@ -153,6 +164,19 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu
```
Export LLaMA-2 70B sharded model into 4 partitions
```
# From source:
# 1. Install necessary packages from requirements-70b-model.txt
# 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command:
$ ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/
# 3. Shard and export the LLaMA-2 70B model. With FP16, you will need at least 140GB of GPU memory to load the model. Therefore, you will need at least 4 40GB A100 GPUs or 2 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command:
$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda
```
## Benchmark LLaMA-2
Here are some examples of how you can benchmark LLaMA-2.
@ -220,11 +244,11 @@ python3 -m models.llama.benchmark \
--device cuda
```
6. ONNX Runtime, FP32, convert_to_onnx
6. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \
--benchmark-type ort-convert-to-onnx \
--ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
--ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
--batch-sizes "1 2" \
@ -232,11 +256,11 @@ python3 -m models.llama.benchmark \
--device cpu
```
7. ONNX Runtime, FP16, convert_to_onnx
7. ONNX Runtime, FP16, convert_to_onnx, use 5th GPU
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \
--benchmark-type ort-convert-to-onnx \
--ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
--ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \

View file

@ -11,6 +11,7 @@ import numpy as np
import onnx
import psutil
import torch
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings,
get_merged_sample_with_past_kv_inputs,
@ -133,6 +134,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
use_fp16=args.use_fp16,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
@ -144,6 +146,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
use_fp16=args.use_fp16,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
elif args.benchmark_type == "ort-msft":
@ -244,10 +247,10 @@ def get_model(args: argparse.Namespace):
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
# Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
logger.info(f"Loading model from {args.ort_model_path}")
logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
start_time = time.time()
model = ort.InferenceSession(
args.ort_model_path,
args.ort_model_path.format(args.rank),
sess_options,
providers=[args.execution_provider],
)
@ -315,10 +318,11 @@ def time_fn(args, fn, inputs):
latency = total_time / args.num_runs
throughput = args.batch_size / latency
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Sequence Length: {args.sequence_length}")
logger.info(f"Latency: {latency} s")
logger.info(f"Throughput: {throughput} tps")
if args.rank == 0:
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Sequence Length: {args.sequence_length}")
logger.info(f"Latency: {latency} s")
logger.info(f"Throughput: {throughput} tps")
return
@ -358,7 +362,8 @@ def measure_fn(args, fn, inputs):
process.cpu_percent(interval=0.1)
fn(inputs)
logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
if args.rank == 0:
logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
# Measure memory usage
gc.collect()
@ -451,7 +456,7 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings(
model, inputs, args.device, int(args.device_id), kv_cache_ortvalues
model, inputs, args.device, int(args.rank), kv_cache_ortvalues
)
setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding, kv_cache_ortvalues
@ -511,7 +516,7 @@ def run_inference(args, init_inputs, iter_inputs, model):
raise Exception(f"Cannot recognize {args.benchmark_type}")
def get_args():
def get_args(rank=0):
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
@ -569,7 +574,7 @@ def get_args():
parser.add_argument(
"-s",
"--sequence-lengths",
default="8 16 32 64 128 256 512",
default="32 64 128 256 512",
)
parser.add_argument(
"-d",
@ -606,9 +611,9 @@ def get_args():
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
args.execution_provider = (args.execution_provider, {"device_id": rank})
elif args.execution_provider == "ROCMExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
args.execution_provider = (args.execution_provider, {"device_id": rank})
args.device = "cuda"
# Check that paths have been specified for any benchmarking with ORT
@ -635,14 +640,19 @@ def get_args():
def main():
args = get_args()
rank = get_rank()
world_size = get_size()
args = get_args(rank)
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
args.rank = rank
args.world_size = world_size
tokenizer = LlamaTokenizer.from_pretrained(args.model_name)
config = LlamaConfig.from_pretrained(args.model_name)
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"
setattr(args, "tokenizer", tokenizer) # noqa: B010
@ -656,7 +666,7 @@ def main():
# Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
onnx_model = onnx.load_model(args.ort_model_path, load_external_data=False)
onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
@ -666,7 +676,8 @@ def main():
# Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
if args.rank == 0:
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
setattr(args, "batch_size", int(batch_size)) # noqa: B010
setattr(args, "sequence_length", int(sequence_length)) # noqa: B010

View file

@ -0,0 +1,12 @@
#!/bin/bash
# Launch benchmark.py under Open MPI.
#
# Usage: <this script> [NUM_GPUS] [benchmark.py arguments...]
#   NUM_GPUS  value passed to mpirun's --npernode option (default: 1)
#   Every argument after NUM_GPUS is forwarded to benchmark.py.

NUM_GPUS=${1:-1}

# Keep the launcher command in an array so each option stays its own word;
# a flat string would have to be re-split by the shell when it is executed.
MPI=(
  mpirun --allow-run-as-root
  -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl '^openib' -mca btl_tcp_if_include eth0
  --tag-output --npernode "$NUM_GPUS" --bind-to numa
  -x MIOPEN_FIND_MODE=1
)

# "${@:2}" forwards the remaining arguments one-to-one, so values that contain
# spaces or glob characters reach benchmark.py exactly as the caller typed them.
"${MPI[@]}" python benchmark.py "${@:2}"

View file

@ -247,6 +247,7 @@ def main():
torch.backends.cudnn.benchmark = True
all_results = []
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
# Benchmark PyTorch without torch.compile
if args.hf_pt_eager:
@ -266,8 +267,6 @@ def main():
args.sequence_lengths,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
@ -298,8 +297,6 @@ def main():
args.sequence_lengths,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
@ -332,8 +329,6 @@ def main():
args.sequence_lengths,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
@ -366,8 +361,6 @@ def main():
args.sequence_lengths,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
@ -399,8 +392,6 @@ def main():
args.sequence_lengths,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",

View file

@ -0,0 +1,12 @@
#!/bin/bash
# Launch convert_to_onnx.py under Open MPI.
#
# Usage: <this script> [NUM_GPUS] [convert_to_onnx.py arguments...]
#   NUM_GPUS  value passed to mpirun's --npernode option (default: 1)
#   Every argument after NUM_GPUS is forwarded to convert_to_onnx.py.

NUM_GPUS=${1:-1}

# Keep the launcher command in an array so each option stays its own word;
# a flat string would have to be re-split by the shell when it is executed.
MPI=(
  mpirun --allow-run-as-root
  -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl '^openib' -mca btl_tcp_if_include eth0
  --tag-output --npernode "$NUM_GPUS" --bind-to numa
  -x MIOPEN_FIND_MODE=1
)

# "${@:2}" forwards the remaining arguments one-to-one, so values that contain
# spaces or glob characters reach convert_to_onnx.py exactly as the caller typed them.
"${MPI[@]}" python convert_to_onnx.py "${@:2}"

View file

@ -1,16 +1,16 @@
import argparse
import logging
import os
import tempfile
import shutil
from itertools import chain
from typing import List
import onnx
import torch
from benchmark_helper import Precision, prepare_environment, setup_logger
from convert_generation import replace_mha_with_gqa
from dist_settings import barrier, get_rank, get_size, init_dist
from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
from llama_parity import main as parity_check
from llama_torch import setup_torch_model
from onnx_model import OnnxModel
from optimizer import optimize_model
from packaging import version
@ -18,8 +18,11 @@ from transformers import LlamaConfig, LlamaForCausalLM
from onnxruntime import quantization as ort_quantization
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger
from onnxruntime.transformers.convert_generation import replace_mha_with_gqa
logger = logging.getLogger("")
init_dist()
def get_model_dynamic_axes(input_names: List[str], output_names: List[str]):
@ -129,7 +132,9 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st
# del onnx_model
# temp_dir.cleanup()
#
def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
def run_dynamo_export(
args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1
):
from torch._dynamo import config
config.capture_scalar_outputs = True
@ -150,9 +155,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_model_fp32.onnx.data")
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
@ -160,7 +165,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll
# Export decoder_with_past_model.onnx
input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs(
l_config, device, batch_size, sequence_length
l_config, device, batch_size, sequence_length, world_size=world_size
)
temp_dir = args.output # tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
@ -172,9 +177,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx")
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_with_past_model_fp32.onnx.data")
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
@ -183,10 +188,21 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll
logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")
def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
def _prepare_dir(dir_path):
if not os.path.exists(dir_path):
os.makedirs(dir_path)
def run_torchscript_separate_export(
args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1
):
# Dummy values for export
batch_size, sequence_length = 2, 8
device = torch.device("cpu")
# set device used to export model
# for llama-2-70b we will use current gpus to speed up export process
# for other models, we will use CPU to make sure we have enough memory to do export
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
# Export decoder_model.onnx
decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length)
@ -199,8 +215,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon
),
]
dynamic_axes = get_model_dynamic_axes(input_names, output_names)
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
# Keep temporary files out of the system temp dir: the 70b model is very large and could fill up that disk.
# Use a separate temp folder per rank so that ranks do not race on the same files.
temp_dir = f"./temp_{rank}"
_prepare_dir(temp_dir)
temp_path = os.path.join(temp_dir, "temp.onnx")
torch.onnx.export(
llama,
args=decoder_inputs,
@ -218,18 +238,25 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(
onnx_model,
output_path,
f"{args.model_name}_decoder_model_fp32.onnx.data",
f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data",
)
del onnx_model
temp_dir.cleanup()
shutil.rmtree(temp_dir)
# Export decoder_with_past_model.onnx
decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length)
decoder_with_past_inputs = get_sample_with_past_kv_inputs(
l_config,
device,
batch_size,
sequence_length,
use_fp16=args.precision == Precision.FLOAT16,
world_size=world_size,
)
input_names = [
"input_ids",
"attention_mask",
@ -247,8 +274,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon
),
]
dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names)
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
# Keep temporary files out of the system temp dir: the 70b model is very large and could fill up that disk.
# Use a separate temp folder per rank so that ranks do not race on the same files.
temp_dir = f"./temp_past_{rank}"
_prepare_dir(temp_dir)
temp_path = os.path.join(temp_dir, "temp.onnx")
torch.onnx.export(
llama,
args=decoder_with_past_inputs,
@ -266,27 +297,45 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx")
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(
onnx_model,
output_path,
f"{args.model_name}_decoder_with_past_model_fp32.onnx.data",
f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data",
)
del onnx_model
temp_dir.cleanup()
shutil.rmtree(temp_dir)
logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!")
logger.info(
f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!"
)
def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
def run_torchscript_merged_export(
args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1
):
# Dummy values for export
batch_size, sequence_length, past_sequence_length = 2, 8, 0
device = torch.device("cpu")
# set device used to export model
# for llama-2-70b we will use current gpus to speed up export process
# for other models, we will use CPU to make sure we have enough memory to do export
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
# Export decoder_merged_model.onnx
decoder_merged_inputs = get_merged_sample_with_past_kv_inputs(
l_config, device, batch_size, sequence_length, past_sequence_length
l_config,
device,
batch_size,
sequence_length,
past_sequence_length,
max_seq_len=max_sequence_length,
use_fp16=args.precision == Precision.FLOAT16,
world_size=world_size,
)
input_names = [
"input_ids",
@ -305,8 +354,12 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi
),
]
dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
# Keep temporary files out of the system temp dir: the 70b model is very large and could fill up that disk.
# Use a separate temp folder per rank so that ranks do not race on the same files.
temp_dir = f"./temp_{rank}"
_prepare_dir(temp_dir)
temp_path = os.path.join(temp_dir, "temp.onnx")
torch.onnx.export(
llama,
args=decoder_merged_inputs,
@ -324,17 +377,17 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
output_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx")
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(
onnx_model,
output_path,
f"{args.model_name}_decoder_merged_model_fp32.onnx.data",
f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx.data",
)
del onnx_model
temp_dir.cleanup()
shutil.rmtree(temp_dir)
logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!")
logger.info(f"The {args.model_name} merged ONNX model has been successfully created with the TorchScript exporter!")
# Optimize the model as FP32
@ -357,12 +410,16 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str):
remove_existing_model(input_path)
def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str]):
decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx")
def convert_to_float16(
args: argparse.Namespace, config: LlamaConfig, old_paths: List[str], rank: int = 0, world_size: int = 1
):
decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx")
decoder_with_past_model_fp16_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx"
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx"
)
decoder_merged_model_fp16_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx"
)
decoder_merged_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp16.onnx")
new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path]
logger.info("Converting to float16...")
@ -370,7 +427,7 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths:
if os.path.exists(fp32_path):
model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True))
model.convert_float_to_float16(keep_io_types=False)
model = use_group_query_attention(config, model)
model = use_group_query_attention(config, model, world_size)
model.save_model_to_file(fp16_path, use_external_data_format=True)
del model
logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!")
@ -380,9 +437,11 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths:
return new_paths
def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel):
def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1):
# Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes
fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads)
fp16_model_opt = replace_mha_with_gqa(
fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size
)
fp16_model_opt.prune_graph()
fp16_model_opt.update_graph(allow_remove_graph_inputs=True)
return fp16_model_opt
@ -406,7 +465,7 @@ def smooth_quant(
calibration_sampling_size=[args.calibration_sampling_size],
recipes={
"optypes_to_exclude_output_quant": ["MatMul"],
"smooth_quant": args.smooth_quant,
"smooth_quant": True,
"smooth_quant_args": {"alpha": args.smooth_quant_alpha},
},
op_type_dict={
@ -526,15 +585,6 @@ def get_args():
help="Execution provider to verify parity with",
)
parser.add_argument(
"-id",
"--device-id",
required=False,
type=str,
default="0",
help="Device ID for GPUs",
)
parser.add_argument(
"-r",
"--reexport",
@ -655,6 +705,14 @@ def get_args():
)
parser.set_defaults(use_dynamo_export=False)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default="./model_cache",
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
)
args = parser.parse_args()
return args
@ -673,144 +731,182 @@ def main():
remove_existing_files(args.output)
logger.info(f"Arguments: {args}")
world_size = get_size()
rank = get_rank()
# Load model and config
use_auth_token = args.input == os.path.join(".")
setattr(args, "use_auth_token", use_auth_token) # noqa: B010
location = args.model_name if use_auth_token else args.input
l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token)
llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, use_cache=True)
original_model_name = args.model_name
setattr(args, "original_model_name", original_model_name) # noqa: B010
args.model_name = args.model_name.split("/")[-1]
# Set model paths for FP32 model
decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
decoder_with_past_model_fp32_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx"
)
decoder_merged_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx")
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
setattr(args, "device", torch.device(args.device_name)) # noqa: B010
missing_separate_exports = (
args.no_merged
and not os.path.exists(decoder_model_fp32_path)
and not os.path.exists(decoder_with_past_model_fp32_path)
)
missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path)
location = args.original_model_name if use_auth_token else args.input
# Export to ONNX
if missing_separate_exports or missing_merged_export:
if args.use_dynamo_export and missing_separate_exports:
logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
logger.warning(
"Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
# use cuda for Llama-2-70b to speedup export, other models use CPU by default
l_config, llama = setup_torch_model(
args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None
)
assert l_config.num_attention_heads % world_size == 0 and l_config.num_key_value_heads % world_size == 0
barrier()
for i in range(world_size):
if i == rank:
# Set model paths for FP32 model
decoder_model_fp32_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx"
)
logger.warning(
"Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
decoder_with_past_model_fp32_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx"
)
run_dynamo_export(args, l_config, llama)
elif args.no_merged:
run_torchscript_separate_export(args, l_config, llama)
else:
run_torchscript_merged_export(args, l_config, llama)
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
# Set model paths to store FP32 optimized model
decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx")
decoder_with_past_model_fp32_opt_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_fp32_opt.onnx"
)
decoder_merged_model_fp32_opt_path = os.path.join(
args.output, f"{args.model_name}_decoder_merged_model_fp32_opt.onnx"
)
new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path]
# Run the optimizer script
logger.info("Optimizing models...")
for orig_path, opt_path in zip(old_paths, new_paths):
if os.path.exists(orig_path):
optimize_export(l_config, input_path=orig_path, output_path=opt_path)
# Re-assign default FP32 model paths as their optimized versions
decoder_model_fp32_path = decoder_model_fp32_opt_path
decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
logger.info(
f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
)
# Change precision of exported models from FP32
if args.precision == Precision.FLOAT16:
new_paths = convert_to_float16(args, l_config, old_paths)
elif args.precision == Precision.INT8:
decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx")
decoder_with_past_model_int8_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx"
)
decoder_merged_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int8.onnx")
new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]
if args.quantization_method == "smooth_quant":
if not args.no_merged:
logger.error("SmoothQuant must be used on separately exported models")
else:
logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8")
smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])
elif args.quantization_method == "quantize_dynamic":
logger.warning(
"The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
decoder_merged_model_fp32_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx"
)
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
logger.info("Quantizing to int8...")
for fp32_path, int8_path in zip(old_paths, new_paths):
if os.path.exists(fp32_path):
ort_quantization.quantize_dynamic(
fp32_path,
int8_path,
op_types_to_quantize=["MatMul", "Gemm", "Gather"]
if args.quantize_embedding_layer
else ["MatMul", "Gemm"],
per_channel=args.quantize_per_channel,
reduce_range=args.quantize_reduce_range,
use_external_data_format=True,
extra_options={"MatMulConstBOnly": True},
missing_separate_exports = (
args.no_merged
and not os.path.exists(decoder_model_fp32_path)
and not os.path.exists(decoder_with_past_model_fp32_path)
)
missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path)
# Export to ONNX
if missing_separate_exports or missing_merged_export:
if args.use_dynamo_export and missing_separate_exports:
logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
logger.warning(
"Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
)
logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!")
remove_existing_model(decoder_model_fp32_path)
logger.warning(
"Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
)
run_dynamo_export(args, l_config, llama)
elif args.no_merged:
run_torchscript_separate_export(args, l_config, llama, rank, world_size)
else:
run_torchscript_merged_export(args, l_config, llama, rank, world_size)
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
# Set model paths to store FP32 optimized model
decoder_model_fp32_opt_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx"
)
decoder_with_past_model_fp32_opt_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx"
)
decoder_merged_model_fp32_opt_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx"
)
new_paths = [
decoder_model_fp32_opt_path,
decoder_with_past_model_fp32_opt_path,
decoder_merged_model_fp32_opt_path,
]
else:
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
# Run the optimizer script
logger.info("Optimizing models...")
for orig_path, opt_path in zip(old_paths, new_paths):
if os.path.exists(orig_path):
optimize_export(l_config, input_path=orig_path, output_path=opt_path)
elif args.precision == Precision.INT4:
if args.execution_provider != "cpu":
old_paths = convert_to_float16(args, l_config, old_paths)
# Re-assign default FP32 model paths as their optimized versions
decoder_model_fp32_path = decoder_model_fp32_opt_path
decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
decoder_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int4.onnx")
decoder_with_past_model_int4_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_int4.onnx"
)
decoder_merged_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int4.onnx")
new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]
logger.info(
f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
)
for fp_path, int4_path in zip(old_paths, new_paths):
if os.path.exists(fp_path):
model = onnx.load_model(fp_path, load_external_data=True)
quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[])
quant.process()
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
del model
del quant
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
remove_existing_model(fp_path)
# Change precision of exported models from FP32
if args.precision == Precision.FLOAT16:
new_paths = convert_to_float16(args, l_config, old_paths, rank, world_size)
elif args.precision == Precision.INT8:
decoder_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
)
decoder_with_past_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
)
decoder_merged_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
)
new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]
if args.quantization_method == "smooth_quant":
if not args.no_merged:
logger.error("SmoothQuant must be used on separately exported models")
else:
logger.info(
f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
)
smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])
elif args.quantization_method == "quantize_dynamic":
logger.warning(
"The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
)
logger.info("Quantizing to int8...")
for fp32_path, int8_path in zip(old_paths, new_paths):
if os.path.exists(fp32_path):
ort_quantization.quantize_dynamic(
fp32_path,
int8_path,
op_types_to_quantize=["MatMul", "Gemm", "Gather"]
if args.quantize_embedding_layer
else ["MatMul", "Gemm"],
per_channel=args.quantize_per_channel,
reduce_range=args.quantize_reduce_range,
use_external_data_format=True,
extra_options={"MatMulConstBOnly": True},
)
logger.info(
f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
)
remove_existing_model(decoder_model_fp32_path)
logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
else:
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
elif args.precision == Precision.INT4:
if args.execution_provider != "cpu":
old_paths = convert_to_float16(args, l_config, old_paths, rank, world_size)
decoder_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
)
decoder_with_past_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
)
decoder_merged_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
)
new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]
for fp_path, int4_path in zip(old_paths, new_paths):
if os.path.exists(fp_path):
model = onnx.load_model(fp_path, load_external_data=True)
quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[])
quant.process()
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
del model
del quant
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
remove_existing_model(fp_path)
barrier()
logger.info("Verifying parity on all ONNX models created")
@ -824,7 +920,12 @@ def main():
# Verify parity on all saved ONNX models
for filename in os.listdir(args.output):
if ".data" in filename or ".onnx" not in filename:
if (
".data" in filename
or ".onnx" not in filename
or args.precision not in filename
or f"rank_{rank}" not in filename
):
continue
parity_cmd = [
@ -834,10 +935,10 @@ def main():
os.path.join(args.output, filename),
"-ep",
args.execution_provider,
"-id",
args.device_id,
"-fp",
args.precision,
"--cache_dir",
args.cache_dir,
]
if "with_past" in filename:
parity_cmd.append("--use_past_kv")
@ -845,6 +946,7 @@ def main():
parity_cmd.append("--merged")
try:
logger.debug(f"check parity with cmd: {parity_cmd}")
parity_check(parity_cmd)
except Exception as e:
logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True)

View file

@ -0,0 +1,45 @@
import os
import torch.distributed as dist
# Communicator handle read by get_rank(), get_size() and barrier();
# while it is None those helpers fall back to their single-process results.
comm = None
def init_dist():
    """Initialise ``torch.distributed`` based on the launcher's environment.

    Three cases are distinguished by environment variables:

    * ``LOCAL_RANK`` is set: ``RANK`` and ``WORLD_SIZE`` are read and an NCCL
      process group is created.
    * ``OMPI_COMM_WORLD_LOCAL_RANK`` is set: the ``mpi4py`` world communicator
      is stored in the module-level ``comm`` (so that ``get_rank()``,
      ``get_size()`` and ``barrier()`` use it) and an NCCL process group is
      created from the ``OMPI_COMM_WORLD_*`` variables.
    * Neither is set: single process, nothing is initialised.

    Raises:
        KeyError: If ``LOCAL_RANK`` is set but ``RANK`` or ``WORLD_SIZE`` is missing.
        ValueError: If one of the rank/size variables is not an integer.
    """
    # Without this declaration the assignment below would only bind a local
    # name and the module-level ``comm`` would stay None, making every process
    # report rank 0 / size 1 and turning barrier() into a no-op.
    global comm

    if "LOCAL_RANK" in os.environ:
        # Parsed only so that a malformed value fails here; the result is unused.
        int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
    elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        # Imported lazily so mpi4py is only required when launched via Open MPI.
        from mpi4py import MPI

        comm = MPI.COMM_WORLD

        # Parsed only so that a malformed value fails here; the result is unused.
        int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
        rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
        world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))

        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
    # Otherwise: single process, no initialisation is needed.
def get_rank():
    """Return ``comm.Get_rank()``, or 0 when no communicator is set."""
    if comm is None:
        # No communicator available: behave as the only process.
        return 0
    return comm.Get_rank()
def get_size():
    """Return ``comm.Get_size()``, or 1 when no communicator is set."""
    if comm is None:
        # No communicator available: a world consisting of this process only.
        return 1
    return comm.Get_size()
def barrier():
    """Call ``comm.Barrier()``; do nothing when no communicator is set."""
    if comm is None:
        return
    comm.Barrier()
def print_out(*args):
    """Print ``args`` only on the process for which ``get_rank()`` returns 0."""
    if get_rank() != 0:
        # Every other rank stays silent.
        return
    print(*args)

View file

@ -66,12 +66,13 @@ def get_sample_with_past_kv_inputs(
use_fp16: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
# position_ids is of shape (batch_size, 1)
position_ids = get_position_ids(attention_mask, use_past_kv=True)
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
@ -123,12 +124,13 @@ def get_merged_sample_with_past_kv_inputs(
use_fp16: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
# position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
@ -220,8 +222,8 @@ def get_msft_sample_inputs(
# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool):
num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads
def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
past_kv = [
(

View file

@ -6,7 +6,7 @@ from typing import List
import numpy as np
import torch
from benchmark_helper import setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings,
convert_inputs_for_ort,
@ -14,9 +14,11 @@ from llama_inputs import (
get_sample_inputs,
get_sample_with_past_kv_inputs,
)
from llama_torch import setup_torch_model
from transformers import LlamaConfig, LlamaForCausalLM
import onnxruntime as ort
from onnxruntime.transformers.benchmark_helper import setup_logger
logger = logging.getLogger("")
@ -30,6 +32,7 @@ def get_sequence_lengths(args: argparse.Namespace):
def get_inputs(args: argparse.Namespace, config: LlamaConfig):
# Dummy values for parity
world_size = get_size()
batch_size = 2
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args)
@ -43,10 +46,17 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig):
max_seq_len=max_sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
world_size=world_size,
)
elif args.use_past_kv:
inputs = get_sample_with_past_kv_inputs(
config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True
config,
args.device,
batch_size,
sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
world_size=world_size,
)
else:
inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
@ -66,6 +76,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
torch.cuda.synchronize()
end_time = time.time()
logger.info(f"PyTorch took {end_time - start_time} s")
del pt_model
# Run inference with ORT
past_sequence_length, _, max_sequence_length = get_sequence_lengths(args)
@ -76,12 +87,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
device=args.execution_provider,
device_id=int(args.device_id),
device_id=int(args.rank),
)
ep = f"{args.execution_provider.upper()}ExecutionProvider"
if ep == "CUDAExecutionProvider":
ep = (ep, {"device_id": args.device_id})
ep = (ep, {"device_id": args.rank})
ort_model = ort.InferenceSession(
args.onnx_model_path,
sess_options=ort.SessionOptions(),
@ -91,7 +102,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
# Add IO bindings for non-CPU execution providers
if args.execution_provider != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings(
ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues
ort_model, inputs, args.execution_provider, int(args.rank), kv_cache_ortvalues
)
io_binding.synchronize_inputs()
@ -101,6 +112,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
end_time = time.time()
ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits
del ort_model
else:
start_time = time.time()
@ -155,15 +167,6 @@ def get_args(argv: List[str]):
help="Execution provider to verify parity with",
)
parser.add_argument(
"-id",
"--device-id",
required=False,
type=str,
default="0",
help="Device ID for GPUs",
)
parser.add_argument(
"-v",
"--verbose",
@ -195,6 +198,14 @@ def get_args(argv: List[str]):
help="Precision of model",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default="./model_cache",
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
)
args = parser.parse_args() if argv == [] else parser.parse_args(argv)
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
@ -210,21 +221,23 @@ def main(argv: List[str] = []): # noqa: B006
args = get_args(argv)
setup_logger(args.verbose)
logger.info(f"Arguments: {args}")
rank = get_rank()
# Load model and config
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{args.device_id}") # noqa: B010
args.rank = rank
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
setattr(args, "device", torch.device(args.device_name)) # noqa: B010
use_auth_token = args.torch_model_directory == os.path.join(".")
location = args.model_name if use_auth_token else args.torch_model_directory
config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token)
llama = LlamaForCausalLM.from_pretrained(
config, llama = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
use_auth_token=use_auth_token,
use_cache=True,
).to(args.device)
device=args.device,
)
kv_cache_ortvalues = {}
if not args.merged:

View file

@ -0,0 +1,38 @@
import logging
import os
import torch
from dist_settings import barrier, get_rank, get_size
from transformers import LlamaConfig, LlamaForCausalLM
logger = logging.getLogger("")
def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, device=None):
    """Load the LLaMA config and PyTorch model for this rank and prepare the model for inference.

    Args:
        args: object exposing ``cache_dir``, the directory passed to both
            ``from_pretrained`` calls (created here if it does not exist).
        location: model name or path forwarded to ``from_pretrained``.
        use_auth_token: forwarded unchanged to ``from_pretrained``.
        torch_dtype: dtype forwarded to ``LlamaForCausalLM.from_pretrained``.
        device: if truthy, the loaded model is moved to it with ``.to(device)``.

    Returns:
        ``(l_config, llama)``: the config (with ``use_cache`` set to True) and
        the model, switched to eval mode with gradients disabled.
    """
    world_size = get_size()
    logger.info(f"world_size: {world_size}")
    rank = get_rank()
    # Wait for every rank before touching the cache directory.
    barrier()

    if not os.path.exists(args.cache_dir):
        os.makedirs(args.cache_dir, exist_ok=True)

    # Ranks take turns: on iteration i only the rank with rank % world_size == i
    # loads, and every rank then waits at the barrier, so loads happen one at a
    # time. NOTE(review): this assumes the trailing barrier() belongs to the loop
    # body rather than the `if`; presumably the serialisation limits peak host
    # memory / concurrent downloads -- confirm.
    for i in range(world_size):
        if i == rank % (world_size):
            l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir)
            l_config.use_cache = True
            llama = LlamaForCausalLM.from_pretrained(
                location,
                use_auth_token=use_auth_token,
                config=l_config,
                torch_dtype=torch_dtype,
                cache_dir=args.cache_dir,
            )
            if world_size > 1:
                # parallel_model() is not defined in this file; presumably it comes
                # from the patched transformers build this tooling installs -- verify.
                llama.parallel_model()
            if device:
                llama.to(device)
            # Inference only: eval mode and no gradient tracking.
            llama.eval()
            llama.requires_grad_(False)
        barrier()
    return l_config, llama

View file

@ -0,0 +1,4 @@
-r requirements.txt
git+https://github.com/frankdongms/transformers.git@frdong/shard_llama
mpi4py
psutil

View file

@ -337,6 +337,18 @@ class OnnxModel:
return i, matched, return_indice
return -1, None, None
def match_parent_paths_all(self, node, paths, output_name_to_node):
    """Try every candidate in ``paths`` against ``node`` and collect all that match.

    Each entry of ``paths`` must be a list or tuple; its first two items are
    forwarded to ``match_parent_path`` together with a fresh list that receives
    that call's return indices.

    Returns:
        Three parallel lists, one element per matching candidate: the
        candidate's position in ``paths``, the truthy value returned by
        ``match_parent_path``, and the return-indices list that call filled in.
        All three are empty when nothing matches.
    """
    hit_positions = []
    hit_nodes = []
    hit_input_indices = []
    for position, candidate in enumerate(paths):
        assert isinstance(candidate, (List, Tuple))
        input_indices = []
        parents = self.match_parent_path(node, candidate[0], candidate[1], output_name_to_node, input_indices)
        if not parents:
            continue
        hit_positions.append(position)
        hit_nodes.append(parents)
        hit_input_indices.append(input_indices)
    return hit_positions, hit_nodes, hit_input_indices
def match_parent_path(
self,
node,

View file

@ -10,7 +10,10 @@
# license information.
# -------------------------------------------------------------------------
import math
import os
import platform
import random
import unittest
import numpy
import torch
@ -22,6 +25,8 @@ from onnxruntime import InferenceSession, OrtValue, SessionOptions
torch.manual_seed(0)
pipeline_mode = True # Reduces number of tests so pipeline doesn't time out
class Formats:
BSNH = 0
@ -159,7 +164,7 @@ def create_multihead_attention_graph(config):
return model.SerializeToString()
def create_group_query_attention_graph_no_past(config, causal=False):
def create_group_query_attention_graph_no_past(config, causal=False, present_kv_format=Formats.BSNH):
nodes = [
helper.make_node(
"GroupQueryAttention",
@ -168,11 +173,12 @@ def create_group_query_attention_graph_no_past(config, causal=False):
"key",
"value",
],
["output"],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
num_heads=config.num_heads,
kv_num_heads=config.kv_num_heads,
unidirectional=1 if causal else 0,
is_past_bsnh=1 if present_kv_format == Formats.BSNH else 0,
domain="com.microsoft",
),
]
@ -213,6 +219,26 @@ def create_group_query_attention_graph_no_past(config, causal=False):
TensorProto.FLOAT16,
[config.batch_size, config.sequence_length, config.num_heads * config.head_size],
),
helper.make_tensor_value_info(
"present_key",
TensorProto.FLOAT16,
[
config.batch_size,
config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads,
config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length,
config.head_size,
],
),
helper.make_tensor_value_info(
"present_value",
TensorProto.FLOAT16,
[
config.batch_size,
config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads,
config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length,
config.head_size,
],
),
]
graph = helper.make_graph(
@ -514,7 +540,6 @@ def generate_token_offset(cu_seqlens, max_seqlen):
return numpy.asarray(token_offset + token_padset, dtype=numpy.int32)
# TODO(aciddelgado): rename
def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False):
onnx_model_str = create_packed_multihead_attention_graph(config)
qkv_unpad = torch.swapdims(qkv_unpad, 1, 2)
@ -548,8 +573,8 @@ def mha_func(q, k, v, config):
return output
def gqa_no_past_func(q, k, v, config, causal=True):
onnx_model_str = create_group_query_attention_graph_no_past(config, causal)
def gqa_no_past_func(q, k, v, config, causal=True, present_kv_format=Formats.BSNH):
onnx_model_str = create_group_query_attention_graph_no_past(config, causal, present_kv_format=present_kv_format)
q = torch.reshape(q, (config.batch_size, config.sequence_length, -1))
k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1))
v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1))
@ -560,7 +585,7 @@ def gqa_no_past_func(q, k, v, config, causal=True):
}
sess_options = SessionOptions()
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
ort_output = ort_session.run(None, ort_inputs)
ort_output, _, _ = ort_session.run(None, ort_inputs)
ort_output = numpy.array(ort_output)
output = torch.tensor(ort_output)
return output
@ -689,17 +714,12 @@ def attention_ref(
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if causal:
# causal_mask = torch.triu(
# torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
# )
causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device)
scores.masked_fill_(causal_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1)
if causal: # Some rows are completely masked out so we fill them with zero instead of NaN
attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
@ -1072,12 +1092,6 @@ def parity_check_gqa_past_no_buff(
out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
out = out.detach().cpu().numpy()
# print(present_k[0, 0, config.past_sequence_length, :10])
# print(k_cache_ref[0, 0, config.past_sequence_length, :10])
# print(k_cache_ref.shape)
# print(present_k - k_cache_ref.detach().cpu().numpy())
# Make sure past-present buffer updating correctly
if past_format == Formats.BSNH:
assert numpy.allclose(
@ -1141,84 +1155,185 @@ def parity_check_gqa_past_no_buff(
)
class TestMHA(unittest.TestCase):
    """Parity sweeps for ``parity_check_mha`` over batch, sequence, head-count and head-size grids.

    Both tests need CUDA, Linux and a device with compute capability major
    version >= 8. When a requirement is not met the test is reported as
    *skipped* instead of silently passing without having run anything.
    """

    def _skip_unless_supported(self):
        """Skip the calling test unless CUDA, Linux and compute capability >= 8.x are all present."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        if platform.system() != "Linux":
            self.skipTest("test only runs on Linux")
        major, _ = torch.cuda.get_device_capability()
        if major < 8:
            self.skipTest("requires a GPU with compute capability >= 8.0")

    def test_packed_mha(self):
        """Run ``parity_check_mha(config, True)`` with equal query and key/value sequence lengths."""
        self._skip_unless_supported()
        print("-------- TEST PACKED MHA ---------")
        # pipeline_mode selects the reduced grids so CI does not time out.
        batches = [2] if pipeline_mode else [1, 5]
        seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]
        num_h = [1, 3] if pipeline_mode else [1, 6, 16]
        h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
        for b in batches:
            for s in seqs:
                for n in num_h:
                    for h in h_sizes:
                        config = Config(b, s, s, 0, n, n, h)
                        parity_check_mha(config, True)

    def test_mha(self):
        """Run ``parity_check_mha(config, False)`` over (sequence, kv sequence) length pairs."""
        self._skip_unless_supported()
        print("-------- TEST MHA ---------")
        batches = [2] if pipeline_mode else [1, 5]
        seqs = (
            [(1, 128), (113, 211), (2048, 2048)]
            if pipeline_mode
            else [
                (113, 203),
                (128, 217),
                (113, 211),
                (108, 256),
                (256, 512),
                (512, 256),
                (1024, 1024),
                (1023, 1024),
                (1024, 1023),
                (2048, 2048),
            ]
        )
        num_h = [1, 3] if pipeline_mode else [1, 6, 16]
        h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
        for b in batches:
            for s, s2 in seqs:
                for n in num_h:
                    for h in h_sizes:
                        config = Config(b, s, s2, 0, n, n, h)
                        parity_check_mha(config, False)
class TestGQA(unittest.TestCase):
    """Parity sweeps for GroupQueryAttention, with and without past key/value state.

    Each test first runs with ``ORT_DISABLE_FLASH_ATTENTION=1`` (the sweep the
    tests label "memory efficient") on devices with compute capability >= 5.3,
    then, on Linux with compute capability major >= 8, repeats the sweep with
    the variable set to ``"0"`` (labelled "flash attention").

    A test is reported as *skipped* when nothing could be run. The environment
    variable is restored after every test so the setting does not leak into
    other tests in the same process.
    """

    def setUp(self):
        # Remember the caller's value now; addCleanup runs even if the test fails or is skipped.
        self.addCleanup(self._restore_flash_env, os.environ.get("ORT_DISABLE_FLASH_ATTENTION"))

    @staticmethod
    def _restore_flash_env(previous):
        """Put ORT_DISABLE_FLASH_ATTENTION back to ``previous`` (``None`` means it was unset)."""
        if previous is None:
            os.environ.pop("ORT_DISABLE_FLASH_ATTENTION", None)
        else:
            os.environ["ORT_DISABLE_FLASH_ATTENTION"] = previous

    @staticmethod
    def _sweep_no_past(batches, seqs, num_h, h_sizes):
        """Run ``parity_check_gqa_no_past`` for every grid combination, causal and non-causal."""
        for b in batches:
            for s, s2 in seqs:
                for n, n2 in num_h:
                    for h in h_sizes:
                        for causal in [True, False]:
                            config = Config(b, s, s2, 0, n, n2, h)
                            parity_check_gqa_no_past(config, causal=causal)

    @staticmethod
    def _sweep_past(batches, seqs, num_h, h_sizes):
        """Run both past-state parity checks for every grid combination and both KV formats.

        Draws one value from the global ``random`` stream per configuration, so
        the order of calls to this helper determines the configurations tested.
        """
        for b in batches:
            for s, s2 in seqs:
                for n, n2 in num_h:
                    for h in h_sizes:
                        for causal in [True]:
                            for past_kv_format in [Formats.BNSH, Formats.BSNH]:
                                # Fourth Config argument: random value in [1, s2 - s], or 0 when s2 <= s.
                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
                                config = Config(b, s, s2, sp, n, n2, h)
                                parity_check_gqa_past(
                                    config,
                                    causal=causal,
                                    past_format=past_kv_format,
                                    rtol=1e-3,
                                    atol=1e-3,
                                )
                                parity_check_gqa_past_no_buff(
                                    config,
                                    causal=causal,
                                    past_format=past_kv_format,
                                    rtol=1e-3,
                                    atol=1e-3,
                                )

    def test_gqa_no_past(self):
        """Sweep GroupQueryAttention parity without past state."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        major, minor = torch.cuda.get_device_capability()
        torch.manual_seed(69)
        print("-------- TEST GQA ---------")
        # pipeline_mode selects the reduced grids so CI does not time out.
        batches = [2] if pipeline_mode else [1, 5]
        seqs = (
            [(1, 128), (113, 211), (2048, 2048)]
            if pipeline_mode
            else [
                (113, 203),
                (128, 217),
                (113, 211),
                (108, 256),
                (256, 512),
                (1024, 1024),
                (1023, 1024),
                (2048, 2048),
            ]
        )
        num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
        h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
        if major < 5 or (major == 5 and minor < 3):
            self.skipTest("requires a GPU with compute capability >= 5.3")
        print("------- MEMORY EFFICIENT ATTENTION ---------")
        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
        self._sweep_no_past(batches, seqs, num_h, h_sizes)
        if major < 8 or platform.system() != "Linux":
            # The first sweep already ran, so this is a plain return rather than a skip.
            return
        print("------- FLASH ATTENTION --------")
        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
        self._sweep_no_past(batches, seqs, num_h, h_sizes)

    def test_gqa_past(self):
        """Sweep GroupQueryAttention parity with past state, in both KV formats."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        major, minor = torch.cuda.get_device_capability()
        if major < 5 or (major == 5 and minor < 3):
            self.skipTest("requires a GPU with compute capability >= 5.3")
        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
        print("-------- TEST GQA PAST ---------")
        print("-------- MEMORY EFFICEINT --------")
        batches = [2] if pipeline_mode else [1, 2]
        seqs = (
            [(1, 128), (3, 1024), (64, 2048)]
            if pipeline_mode
            else [
                (1, 128),
                (1, 339),
                (3, 1024),
                (64, 800),
                (64, 256),
                (3, 799),
                (64, 2048),
                (16, 20000),
                (1, 128 * 512),
                (16, 128 * 512),
                (128, 128),
            ]
        )
        num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
        h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
        # Seeded once; the second sweep below continues the same random stream.
        random.seed(69)
        self._sweep_past(batches, seqs, num_h, h_sizes)
        if major < 8 or platform.system() != "Linux":
            # The first sweep already ran, so this is a plain return rather than a skip.
            return
        print("------- FLASH ATTENTION -------")
        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
        self._sweep_past(batches, seqs, num_h, h_sizes)
if __name__ == "__main__":
print("-------- TEST PACKED MHA ---------")
for b in [5]:
for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]:
for n in [6]:
for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
config = Config(b, s, s, 0, n, n, h)
parity_check_mha(config, True)
print("-------- TEST MHA ---------")
for b in [5]:
for s, s2 in [
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
]:
for n in [6]:
for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
config = Config(b, s, s2, 0, n, n, h)
parity_check_mha(config, False)
print("-------- TEST GQA ---------")
for b in [5]:
for s, s2 in [
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
]:
for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
for causal in [True, False]:
config = Config(b, s, s2, 0, n, n2, h)
parity_check_gqa_no_past(config, causal=causal)
print("-------- TEST GQA PAST ---------")
random.seed(69)
for b in [2]:
for s, s2 in [
(1, 128),
(1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 512),
(16, 128 * 512),
(128, 128),
]:
for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
for causal in [True]:
for past_kv_format in [Formats.BNSH, Formats.BSNH]:
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
config = Config(b, s, s2, sp, n, n2, h)
parity_check_gqa_past(
config,
causal=causal,
past_format=past_kv_format,
rtol=1e-3,
atol=1e-3,
)
parity_check_gqa_past_no_buff(
config,
causal=causal,
past_format=past_kv_format,
rtol=1e-3,
atol=1e-3,
)
unittest.main()

View file

@ -96,7 +96,7 @@ class TestRotaryAttentionFusion(unittest.TestCase):
helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]),
helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size),
]
if model_type in {"past", "merged", "llama2_msft"}:
if model_type in {"past", "merged", "llama2_msft", "70b_distributed_merged"}:
inputs.extend(
[
helper.make_tensor_value_info(
@ -164,14 +164,14 @@ class TestRotaryAttentionFusion(unittest.TestCase):
if is_fused or model_type == "llama2_msft":
# q_out/k_out
return f"{node_type}_out"
if model_type in {"no_past", "past", "merged"}:
if model_type in {"no_past", "past", "merged", "70b_distributed_merged"}:
if node_type == "k":
return "k_before_rope"
return "q_before_rope"
return ""
def get_first_rope_output(node_type: str):
if is_fused or model_type in {"llama2_msft", "past", "merged"}:
if is_fused or model_type in {"llama2_msft", "past", "merged", "70b_distributed_merged"}:
if node_type == "q":
return "q_rope"
return "k_rope"
@ -295,23 +295,225 @@ class TestRotaryAttentionFusion(unittest.TestCase):
)
k_nodes = [reshape_k_node, transpose_k_1_node]
if model_type in {"past", "merged"}:
if model_type == "70b_distributed_merged":
concat_k_node = helper.make_node(
"Concat",
inputs=["past_key", "k_rope"],
outputs=["present_key"],
axis=2,
)
k_nodes.append(concat_k_node)
shape_k1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k1_out"], name="Shape_k1")
shape_k2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k2_out"], name="Shape_k2")
shape_k3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k3_out"], name="Shape_k3")
shape_k4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k4_out"], name="Shape_k4")
transpose_k_2_node = helper.make_node(
"Transpose",
inputs=["present_key"],
outputs=["k"],
name="Transpose_k_2",
perm=[0, 1, 3, 2],
)
return k_nodes + [transpose_k_2_node] # noqa: RUF005
gather_k_1 = helper.make_node(
"Gather",
inputs=["shape_k1_out", "one"],
outputs=["gather_k1_out"],
name="Gather_k_1",
axis=0,
)
gather_k_2 = helper.make_node(
"Gather",
inputs=["shape_k2_out", "one"],
outputs=["gather_k2_out"],
name="Gather_k_2",
axis=0,
)
gather_k_3 = helper.make_node(
"Gather",
inputs=["shape_k3_out", "one"],
outputs=["gather_k3_out"],
name="Gather_k_3",
axis=0,
)
gather_k_4 = helper.make_node(
"Gather",
inputs=["shape_k4_out", "one"],
outputs=["gather_k4_out"],
name="Gather_k_4",
axis=0,
)
unsqueeze_k_1 = helper.make_node(
"Unsqueeze",
inputs=["present_value", "zero"],
outputs=["unsqueeze_k1_out"],
name="Unsqueeze_k1",
)
unsqueeze_k_2 = helper.make_node(
"Unsqueeze",
inputs=["gather_k1_out", "zero"],
outputs=["unsqueeze_k2_out"],
name="Unsqueeze_k2",
)
unsqueeze_k_3 = helper.make_node(
"Unsqueeze",
inputs=["gather_k2_out", "zero"],
outputs=["unsqueeze_k3_out"],
name="Unsqueeze_k3",
)
unsqueeze_k_4 = helper.make_node(
"Unsqueeze",
inputs=["gather_k3_out", "zero"],
outputs=["unsqueeze_k4_out"],
name="Unsqueeze_k4",
)
unsqueeze_k_5 = helper.make_node(
"Unsqueeze",
inputs=["gather_k4_out", "zero"],
outputs=["unsqueeze_k5_out"],
name="Unsqueeze_k5",
)
concat_k_2 = helper.make_node(
"Concat",
inputs=["unsqueeze_k2_out", "unsqueeze_k3_out", "One", "unsqueeze_k4_out", "unsqueeze_k5_out"],
outputs=["concat_k2_ouot"],
name="Concat_k2",
axis=0,
)
reshape_k_2 = helper.make_node(
"Reshape",
inputs=["concat_k2_ouot", "One"],
outputs=["reshape_k2_out"],
name="Reshape_k_2",
)
shape_k5 = helper.make_node("Shape", inputs=["reshape_k2_out"], outputs=["shape_k5_out"], name="Shape_k5")
constant_of_shape_k_1 = helper.make_node(
"ConstantOfShape",
inputs=["shape_k5_out"],
outputs=["constant_of_shape_k1_out"],
name="ConstantOfShape_k1",
)
mul_k_1 = helper.make_node(
"Mul",
inputs=["constant_of_shape_k1_out", "One"],
outputs=["mul_k1_out"],
name="mul_k1",
)
equal_k_1 = helper.make_node(
"Equal",
inputs=["reshape_k2_out", "mul_k1_out"],
outputs=["equal_k_1_out"],
name="equal_k1",
)
where_k_1 = helper.make_node(
"Where",
inputs=["equal_k_1_out", "constant_of_shape_k1_out", "reshape_k2_out"],
outputs=["where_k_1_out"],
name="where_k1",
)
unsqueeze_k_6 = helper.make_node(
"Unsqueeze",
inputs=["gather_k1_out", "zero"],
outputs=["unsqueeze_k6_out"],
name="Unsqueeze_k6",
)
mul_k_2 = helper.make_node(
"Mul",
inputs=["gather_k2_out", "One"],
outputs=["mul_k2_out"],
name="mul_k2",
)
unsqueeze_k_7 = helper.make_node(
"Unsqueeze",
inputs=["mul_k2_out", "zero"],
outputs=["unsqueeze_k7_out"],
name="Unsqueeze_k7",
)
unsqueeze_k_8 = helper.make_node(
"Unsqueeze",
inputs=["gather_k3_out", "zero"],
outputs=["unsqueeze_k8_out"],
name="Unsqueeze_k8",
)
unsqueeze_k_9 = helper.make_node(
"Unsqueeze",
inputs=["gather_k4_out", "zero"],
outputs=["unsqueeze_k9_out"],
name="Unsqueeze_k9",
)
concat_k_3 = helper.make_node(
"Concat",
inputs=["unsqueeze_k6_out", "unsqueeze_k7_out", "unsqueeze_k8_out", "unsqueeze_k9_out"],
outputs=["concat_k3_out"],
name="Concat_k3",
axis=0,
)
expand_k_1 = helper.make_node(
"Expand",
inputs=["unsqueeze_k1_out", "where_k_1_out"],
outputs=["expand_k1_out"],
name="expand_k1",
)
reshape_k_3 = helper.make_node(
"Reshape",
inputs=["expand_k1_out", "concat_k3_out"],
outputs=["reshape_k3_out"],
name="Reshape_k_3",
)
transpose_k_2_node = helper.make_node(
"Transpose",
inputs=["reshape_k3_out"],
outputs=["k"],
name="Transpose_k_2",
perm=[0, 1, 3, 2],
)
k_nodes_for_70b_model = [
concat_k_node,
shape_k1,
shape_k2,
shape_k3,
shape_k4,
gather_k_1,
gather_k_2,
gather_k_3,
gather_k_4,
unsqueeze_k_1,
unsqueeze_k_2,
unsqueeze_k_3,
unsqueeze_k_4,
unsqueeze_k_5,
concat_k_2,
reshape_k_2,
shape_k5,
constant_of_shape_k_1,
mul_k_1,
equal_k_1,
where_k_1,
unsqueeze_k_6,
mul_k_2,
unsqueeze_k_7,
unsqueeze_k_8,
unsqueeze_k_9,
concat_k_3,
expand_k_1,
reshape_k_3,
transpose_k_2_node,
]
k_nodes.extend(k_nodes_for_70b_model)
return k_nodes
else:
if model_type in {"past", "merged"}:
concat_k_node = helper.make_node(
"Concat",
inputs=["past_key", "k_rope"],
outputs=["present_key"],
axis=2,
)
k_nodes.append(concat_k_node)
transpose_k_2_node = helper.make_node(
"Transpose",
inputs=["present_key"],
outputs=["k"],
name="Transpose_k_2",
perm=[0, 1, 3, 2],
)
return k_nodes + [transpose_k_2_node] # noqa: RUF005
def create_k_path(self, model_type: str):
if model_type == "llama2_msft":
@ -505,7 +707,7 @@ class TestRotaryAttentionFusion(unittest.TestCase):
if model_type == "no_past":
return v_nodes
if model_type in {"past", "merged"}:
if model_type in {"past", "merged", "70b_distributed_merged"}:
concat_v_node = helper.make_node(
"Concat",
inputs=["past_value", "transpose_v_1_out"],
@ -513,7 +715,194 @@ class TestRotaryAttentionFusion(unittest.TestCase):
name="Concat_v",
axis=2,
)
return v_nodes + [concat_v_node] # noqa: RUF005
if model_type != "70b_distributed_merged":
return v_nodes + [concat_v_node] # noqa: RUF005
shape_v1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_1_out"], name="Shape_v1")
shape_v2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_2_out"], name="Shape_v2")
shape_v3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_3_out"], name="Shape_v3")
shape_v4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_4_out"], name="Shape_v4")
gather_v_1 = helper.make_node(
"Gather",
inputs=["shape_1_out", "one"],
outputs=["gather_1_out"],
name="Gather_v1",
axis=0,
)
gather_v_2 = helper.make_node(
"Gather",
inputs=["shape_2_out", "one"],
outputs=["gather_2_out"],
name="Gather_v2",
axis=0,
)
gather_v_3 = helper.make_node(
"Gather",
inputs=["shape_3_out", "one"],
outputs=["gather_3_out"],
name="Gather_v3",
axis=0,
)
gather_v_4 = helper.make_node(
"Gather",
inputs=["shape_4_out", "one"],
outputs=["gather_4_out"],
name="Gather_v4",
axis=0,
)
unsqueeze_v_1 = helper.make_node(
"Unsqueeze",
inputs=["present_value", "zero"],
outputs=["unsqueeze_v1_out"],
name="Unsqueeze_v1",
)
unsqueeze_v_2 = helper.make_node(
"Unsqueeze",
inputs=["gather_1_out", "zero"],
outputs=["unsqueeze_v2_out"],
name="Unsqueeze_v2",
)
unsqueeze_v_3 = helper.make_node(
"Unsqueeze",
inputs=["gather_2_out", "zero"],
outputs=["unsqueeze_v3_out"],
name="Unsqueeze_v3",
)
unsqueeze_v_4 = helper.make_node(
"Unsqueeze",
inputs=["gather_3_out", "zero"],
outputs=["unsqueeze_v4_out"],
name="Unsqueeze_v4",
)
unsqueeze_v_5 = helper.make_node(
"Unsqueeze",
inputs=["gather_4_out", "zero"],
outputs=["unsqueeze_v5_out"],
name="Unsqueeze_v5",
)
concat_v_2 = helper.make_node(
"Concat",
inputs=["unsqueeze_v2_out", "unsqueeze_v3_out", "One", "unsqueeze_v4_out", "unsqueeze_v5_out"],
outputs=["concat_v2_ouot"],
name="Concat_v2",
axis=0,
)
reshape_v_2 = helper.make_node(
"Reshape",
inputs=["concat_v2_ouot", "One"],
outputs=["reshape_v2_out"],
name="Reshape_v2",
)
shape_v5 = helper.make_node("Shape", inputs=["reshape_v2_out"], outputs=["shape_5_out"], name="Shape_v5")
constant_of_shape_v_1 = helper.make_node(
"ConstantOfShape",
inputs=["shape_5_out"],
outputs=["constant_of_shape_v1_out"],
name="ConstantOfShape_v1",
)
mul_v_1 = helper.make_node(
"Mul",
inputs=["constant_of_shape_v1_out", "One"],
outputs=["mul_v1_out"],
name="mul_v1",
)
equal_v_1 = helper.make_node(
"Equal",
inputs=["reshape_v2_out", "mul_v1_out"],
outputs=["equal_v_1_out"],
name="equal_v1",
)
where_v_1 = helper.make_node(
"Where",
inputs=["equal_v_1_out", "constant_of_shape_v1_out", "reshape_v2_out"],
outputs=["where_v_1_out"],
name="where_v1",
)
unsqueeze_v_6 = helper.make_node(
"Unsqueeze",
inputs=["gather_1_out", "zero"],
outputs=["unsqueeze_v6_out"],
name="Unsqueeze_v6",
)
mul_v_2 = helper.make_node(
"Mul",
inputs=["gather_2_out", "One"],
outputs=["mul_v2_out"],
name="mul_v2",
)
unsqueeze_v_7 = helper.make_node(
"Unsqueeze",
inputs=["mul_v2_out", "zero"],
outputs=["unsqueeze_v7_out"],
name="Unsqueeze_v7",
)
unsqueeze_v_8 = helper.make_node(
"Unsqueeze",
inputs=["gather_3_out", "zero"],
outputs=["unsqueeze_v8_out"],
name="Unsqueeze_v8",
)
unsqueeze_v_9 = helper.make_node(
"Unsqueeze",
inputs=["gather_4_out", "zero"],
outputs=["unsqueeze_v9_out"],
name="Unsqueeze_v9",
)
concat_v_3 = helper.make_node(
"Concat",
inputs=["unsqueeze_v6_out", "unsqueeze_v7_out", "unsqueeze_v8_out", "unsqueeze_v9_out"],
outputs=["concat_v3_out"],
name="Concat_v3",
axis=0,
)
expand_v_1 = helper.make_node(
"Expand",
inputs=["unsqueeze_v1_out", "where_v_1_out"],
outputs=["expand_v1_out"],
name="expand_v1",
)
reshape_v_3 = helper.make_node(
"Reshape",
inputs=["expand_v1_out", "concat_v3_out"],
outputs=["reshape_v3_out"],
name="Reshape_v3",
)
v_nodes_for_70b_model = [
concat_v_node,
shape_v1,
shape_v2,
shape_v3,
shape_v4,
gather_v_1,
gather_v_2,
gather_v_3,
gather_v_4,
unsqueeze_v_1,
unsqueeze_v_2,
unsqueeze_v_3,
unsqueeze_v_4,
unsqueeze_v_5,
concat_v_2,
reshape_v_2,
shape_v5,
constant_of_shape_v_1,
mul_v_1,
equal_v_1,
where_v_1,
unsqueeze_v_6,
mul_v_2,
unsqueeze_v_7,
unsqueeze_v_8,
unsqueeze_v_9,
concat_v_3,
expand_v_1,
reshape_v_3,
]
v_nodes.extend(v_nodes_for_70b_model)
return v_nodes
# Create extra nodes for `position_ids`
unsqueeze_v_node = helper.make_node(
@ -672,7 +1061,28 @@ class TestRotaryAttentionFusion(unittest.TestCase):
return extra_nodes
def create_end_nodes(self):
def create_end_nodes(self, model_type):
if model_type == "70b_distributed_merged":
matmul_o_node = helper.make_node(
"MatMul",
inputs=["attn_output", "o_weight"],
outputs=["output_proj"],
name="MatMul_o_proj",
)
all_reduce = helper.make_node(
"AllReduce",
inputs=["output_proj"],
outputs=["allreduce_proj"],
name="allreduce_proj",
)
end_node = helper.make_node(
"Add",
inputs=["zero", "allreduce_proj"],
outputs=["output_0"],
name="Add_normalize_node",
)
return [matmul_o_node, all_reduce, end_node]
matmul_o_node = helper.make_node(
"MatMul",
inputs=["attn_output", "o_weight"],
@ -711,7 +1121,7 @@ class TestRotaryAttentionFusion(unittest.TestCase):
num_heads=self.num_heads,
)
end_nodes = self.create_end_nodes()
end_nodes = self.create_end_nodes(model_type)
graph = helper.make_graph(
nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes,
@ -740,7 +1150,7 @@ class TestRotaryAttentionFusion(unittest.TestCase):
reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes))
extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes)
end_nodes = self.create_end_nodes()
end_nodes = self.create_end_nodes(model_type)
first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes
second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes
@ -790,6 +1200,11 @@ class TestRotaryAttentionFusion(unittest.TestCase):
interleaved = False
self.check_models(model_type, interleaved)
def test_hf_70b_distributed_decoder_merged_model(self):
model_type = "70b_distributed_merged"
interleaved = False
self.check_models(model_type, interleaved)
if __name__ == "__main__":
unittest.main()