mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Cherry pick LLaMA to rel-1.16.2 (round 2) (#18245)
2nd round of cherry pick LLaMA related changes to 1.16.2 release. --------- Co-authored-by: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Co-authored-by: Frank Dong <123416088+frank-dong-ms@users.noreply.github.com>
This commit is contained in:
parent
2f57f1e4d7
commit
70b8cda979
31 changed files with 2088 additions and 550 deletions
|
|
@ -113,6 +113,11 @@ set(contrib_ops_excluded_files
|
|||
"cuda_contrib_kernels.h"
|
||||
"inverse.cc"
|
||||
"fused_conv.cc"
|
||||
"bert/group_query_attention_helper.h"
|
||||
"bert/group_query_attention.h"
|
||||
"bert/group_query_attention.cc"
|
||||
"bert/group_query_attention_impl.h"
|
||||
"bert/group_query_attention_impl.cu"
|
||||
)
|
||||
|
||||
if (NOT onnxruntime_ENABLE_ATEN)
|
||||
|
|
|
|||
|
|
@ -2265,14 +2265,14 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs (1 - 3)
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>output</tt> : T</dt>
|
||||
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
|
||||
<dt><tt>present_key</tt> (optional) : T</dt>
|
||||
<dt><tt>present_key</tt> : T</dt>
|
||||
<dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
|
||||
<dt><tt>present_value</tt> (optional) : T</dt>
|
||||
<dt><tt>present_value</tt> : T</dt>
|
||||
<dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
|
||||
</dl>
|
||||
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ Status EfficientAttention(
|
|||
p.num_heads = parameters.num_heads;
|
||||
p.sequence_length = parameters.sequence_length;
|
||||
p.kv_sequence_length = parameters.total_sequence_length;
|
||||
p.max_sequence_length = parameters.total_sequence_length;
|
||||
p.qk_head_size = parameters.head_size;
|
||||
p.v_head_size = parameters.v_head_size;
|
||||
p.causal = parameters.is_unidirectional;
|
||||
|
|
@ -395,6 +396,7 @@ Status EfficientAttention(
|
|||
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
|
||||
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
|
||||
p.output = data.output;
|
||||
p.is_kv_bsnh = true;
|
||||
p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
|
||||
? data.scratch
|
||||
: nullptr;
|
||||
|
|
|
|||
|
|
@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
|
|||
p.num_keys = params.kv_sequence_length;
|
||||
|
||||
if (params.causal) {
|
||||
p.custom_mask_type = Attention::CausalFromTopLeft;
|
||||
p.custom_mask_type = Attention::CausalFromBottomRight;
|
||||
}
|
||||
|
||||
// Input format is BxSxNxH, output is BxSxNxH
|
||||
p.q_strideH = params.qk_head_size;
|
||||
p.k_strideH = params.qk_head_size;
|
||||
p.v_strideH = params.v_head_size;
|
||||
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
|
||||
// We use max_sequence_length to calculate KV stride
|
||||
if (params.is_kv_bsnh) {
|
||||
// Input Q, K, V format is BxSxNxH, output is BxSxNxH
|
||||
p.q_strideH = params.qk_head_size;
|
||||
p.k_strideH = params.qk_head_size;
|
||||
p.v_strideH = params.v_head_size;
|
||||
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
|
||||
|
||||
p.q_strideM = params.num_heads * params.qk_head_size;
|
||||
p.k_strideM = params.num_heads * params.qk_head_size;
|
||||
p.v_strideM = params.num_heads * params.v_head_size;
|
||||
p.o_strideM = params.num_heads * params.v_head_size;
|
||||
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
|
||||
p.q_strideM = params.num_heads * params.qk_head_size;
|
||||
p.k_strideM = params.num_heads * params.qk_head_size;
|
||||
p.v_strideM = params.num_heads * params.v_head_size;
|
||||
p.o_strideM = params.num_heads * params.v_head_size;
|
||||
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
|
||||
|
||||
p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
|
||||
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
|
||||
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
|
||||
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
|
||||
p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
|
||||
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.max_sequence_length;
|
||||
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.max_sequence_length;
|
||||
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
|
||||
} else {
|
||||
// Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
|
||||
p.q_strideH = params.qk_head_size;
|
||||
p.k_strideH = params.max_sequence_length * params.qk_head_size;
|
||||
p.v_strideH = params.max_sequence_length * params.v_head_size;
|
||||
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
|
||||
|
||||
p.q_strideM = params.num_heads * params.qk_head_size;
|
||||
p.k_strideM = params.qk_head_size;
|
||||
p.v_strideM = params.v_head_size;
|
||||
p.o_strideM = params.num_heads * params.v_head_size;
|
||||
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
|
||||
|
||||
p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length;
|
||||
p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length;
|
||||
p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
|
||||
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
|
||||
|
|
|
|||
|
|
@ -14,10 +14,12 @@ namespace cuda {
|
|||
struct MemoryEfficientAttentionParams {
|
||||
int32_t sm;
|
||||
bool is_half;
|
||||
bool is_kv_bsnh = true;
|
||||
int32_t batch_size;
|
||||
int32_t num_heads;
|
||||
int32_t sequence_length;
|
||||
int32_t kv_sequence_length;
|
||||
int32_t max_sequence_length;
|
||||
int32_t qk_head_size;
|
||||
int32_t v_head_size;
|
||||
bool causal;
|
||||
|
|
|
|||
|
|
@ -6,9 +6,8 @@
|
|||
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/group_query_attention.h"
|
||||
#include "contrib_ops/cuda/bert/group_query_attention_helper.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
|
||||
// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
// #include "contrib_ops/cpu/utils/console_dumper.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace ::onnxruntime::common;
|
||||
|
|
@ -55,6 +54,13 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
|
|||
#else
|
||||
disable_flash_attention_ = true;
|
||||
#endif
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
disable_memory_efficient_attention_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
|
||||
#else
|
||||
disable_memory_efficient_attention_ = true;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -92,18 +98,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
std::vector<int64_t> present_dims;
|
||||
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
|
||||
present_dims = {
|
||||
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
|
||||
} else { // BNSH
|
||||
present_dims = {
|
||||
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
|
||||
}
|
||||
TensorShape present_shape(present_dims);
|
||||
Tensor* present_key = context->Output(1, present_shape);
|
||||
Tensor* present_value = context->Output(2, present_shape);
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
bool use_flash_attention = !disable_flash_attention_ &&
|
||||
onnxruntime::flash::is_supported(device_prop,
|
||||
|
|
@ -143,8 +137,47 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
auto seqlens_k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
|
||||
#endif
|
||||
|
||||
// only kernel implemented for gqa right now
|
||||
ORT_ENFORCE(use_flash_attention);
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
int sm = (device_prop.major * 10) + device_prop.minor;
|
||||
bool use_memory_efficient_attention =
|
||||
!use_flash_attention &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
(parameters.head_size & 7) == 0 &&
|
||||
parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length &&
|
||||
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2);
|
||||
// allocate buffers
|
||||
size_t kv_buffer_bytes = 0;
|
||||
// need a buffer if we must ungroup kv
|
||||
const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads);
|
||||
if (use_memory_efficient_attention && needs_buff) {
|
||||
kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size);
|
||||
}
|
||||
size_t fmha_buffer_bytes = 0;
|
||||
if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
|
||||
fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
|
||||
}
|
||||
auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
|
||||
auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
|
||||
auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
|
||||
#else
|
||||
constexpr bool use_memory_efficient_attention = false;
|
||||
auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
|
||||
auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
|
||||
auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
|
||||
#endif
|
||||
|
||||
std::vector<int64_t> present_dims;
|
||||
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
|
||||
present_dims = {
|
||||
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
|
||||
} else { // BNSH
|
||||
present_dims = {
|
||||
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
|
||||
}
|
||||
TensorShape present_shape(present_dims);
|
||||
Tensor* present_key = context->Output(1, present_shape);
|
||||
Tensor* present_value = context->Output(2, present_shape);
|
||||
|
||||
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
|
||||
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
|
||||
|
|
@ -155,6 +188,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
|
||||
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
|
||||
data.use_flash_attention = use_flash_attention;
|
||||
data.use_memory_efficient_attention = use_memory_efficient_attention;
|
||||
if (softmax_lse_buffer != nullptr) {
|
||||
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
|
||||
}
|
||||
|
|
@ -167,6 +201,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
if (seqlens_k_buffer != nullptr) {
|
||||
data.seqlens_k = reinterpret_cast<int*>(seqlens_k_buffer.get());
|
||||
}
|
||||
if (k_buffer != nullptr) {
|
||||
data.k = reinterpret_cast<CudaT*>(k_buffer.get());
|
||||
data.v = reinterpret_cast<CudaT*>(v_buffer.get());
|
||||
}
|
||||
if (fmha_buffer != nullptr) {
|
||||
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
|
||||
}
|
||||
|
||||
cublasHandle_t cublas = GetCublasHandle(context);
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel {
|
|||
bool is_past_bsnh_;
|
||||
float scale_;
|
||||
bool disable_flash_attention_;
|
||||
bool disable_memory_efficient_attention_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query,
|
|||
// query (Q) : (B, S, D)
|
||||
// key (K) : (B, S+, D_kv)
|
||||
// value (V) : (B, S+, D_kv)
|
||||
ORT_UNUSED_PARAMETER(value);
|
||||
|
||||
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
|
||||
AttentionQkvFormat past_kv_format = Q_K_V_BSNH;
|
||||
|
||||
const auto& query_dims = query->Shape().GetDims();
|
||||
const auto& key_dims = key->Shape().GetDims();
|
||||
const auto& value_dims = value->Shape().GetDims();
|
||||
|
||||
if (query_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
|
||||
|
|
@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query,
|
|||
int q_hidden_size = static_cast<int>(query_dims[2]);
|
||||
int head_size = static_cast<int>(q_hidden_size) / num_heads;
|
||||
|
||||
int kv_sequence_length = sequence_length;
|
||||
int kv_hidden_size = (key_dims.size() == 3)
|
||||
? static_cast<int>(key_dims[2])
|
||||
: (kv_num_heads * static_cast<int>(key_dims[3]));
|
||||
int kv_sequence_length = static_cast<int>(key_dims[1]);
|
||||
int kv_hidden_size = static_cast<int>(key_dims[2]);
|
||||
|
||||
int max_sequence_length = 0;
|
||||
if (past_key != nullptr && past_value != nullptr) {
|
||||
|
|
@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query,
|
|||
"Input 'past_key' and 'past_value' shall be both present or both absent");
|
||||
}
|
||||
|
||||
if (key != nullptr) {
|
||||
const auto& key_dims = key->Shape().GetDims();
|
||||
if (key_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
|
||||
key_dims.size());
|
||||
}
|
||||
if (query_dims[0] != key_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'key' shall have same dim 0 (batch size)");
|
||||
}
|
||||
|
||||
if (num_heads % kv_num_heads != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
|
||||
num_heads % kv_num_heads);
|
||||
}
|
||||
if (key_dims[2] != value_dims[2]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)");
|
||||
}
|
||||
|
||||
qkv_format = Q_K_V_BSNH;
|
||||
kv_sequence_length = static_cast<int>(key_dims[1]);
|
||||
} else {
|
||||
if (key_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
|
||||
key_dims.size());
|
||||
}
|
||||
if (query_dims[0] != key_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Missing key tensor.");
|
||||
"Input 'query' and 'key' shall have same dim 0 (batch size)");
|
||||
}
|
||||
|
||||
if (value != nullptr) {
|
||||
const auto& value_dims = value->Shape().GetDims();
|
||||
if (value_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
|
||||
value_dims.size());
|
||||
}
|
||||
|
||||
if (query_dims[0] != value_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
|
||||
}
|
||||
|
||||
if (value_dims[2] != kv_hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
|
||||
}
|
||||
} else {
|
||||
if (num_heads % kv_num_heads != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Missing value tensor.");
|
||||
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
|
||||
num_heads % kv_num_heads);
|
||||
}
|
||||
|
||||
const auto& value_dims = value->Shape().GetDims();
|
||||
if (value_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
|
||||
value_dims.size());
|
||||
}
|
||||
|
||||
if (query_dims[0] != value_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
|
||||
}
|
||||
|
||||
if (value_dims[2] != kv_hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
|
||||
}
|
||||
|
||||
// When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly.
|
||||
int32_t past_sequence_length = 0;
|
||||
int present_sequence_length = 0;
|
||||
int present_sequence_length = kv_sequence_length;
|
||||
if (past_seq_len != nullptr) {
|
||||
if (past_key == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Past KV must be present as share-buffer when using past_seq_len pointer.");
|
||||
}
|
||||
if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"past_sequence_length tensor must be of one element when using past kv.");
|
||||
|
|
@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query,
|
|||
} else {
|
||||
past_sequence_length = static_cast<int32_t>(*((*past_seq_len).template Data<int64_t>()));
|
||||
}
|
||||
if (past_sequence_length + kv_sequence_length > max_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length");
|
||||
}
|
||||
present_sequence_length = max_sequence_length;
|
||||
} else if (past_key != nullptr) {
|
||||
past_sequence_length = max_sequence_length; // this is the length of past_key tensor
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#include "contrib_ops/cuda/bert/bert_padding.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
|
||||
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
|
|
@ -47,6 +48,8 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
////////// Auxiliary Kernels for KV prep
|
||||
|
||||
// Kernel for seqlens_k
|
||||
__global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) {
|
||||
int id = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
|
@ -75,7 +78,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen,
|
|||
const int present_head_stride = is_bsnh ? H : present_seqlen * H;
|
||||
|
||||
// past_kv: BPNH or BNPH
|
||||
// new_kv: BLNH or BNLH
|
||||
// new_kv: BLNH
|
||||
// present_kv: BTNH or BNTH, where T = P + L
|
||||
const int past_seqlen = present_seqlen - new_seqlen;
|
||||
|
||||
|
|
@ -95,33 +98,32 @@ __global__ void ConcatNewToPastKV(const int new_seqlen,
|
|||
}
|
||||
}
|
||||
|
||||
// Use when (H*)*num_heads > 1024
|
||||
template <typename T>
|
||||
__global__ void ConcatNewToPastKVLarge(const int new_seqlen,
|
||||
const int H,
|
||||
const int num_heads,
|
||||
const T* past_kv,
|
||||
const T* new_kv,
|
||||
T* present_kv,
|
||||
const bool is_bsnh) {
|
||||
// Use when (H*)*num_heads > 1024
|
||||
int h = threadIdx.x;
|
||||
const int n = threadIdx.y;
|
||||
const int s = blockIdx.x;
|
||||
const int b = blockIdx.y;
|
||||
int i = threadIdx.x + (blockDim.x * blockIdx.x);
|
||||
if (i < H * num_heads) {
|
||||
const int h = i % H;
|
||||
const int n = i / H;
|
||||
const int s = blockIdx.y;
|
||||
const int b = blockIdx.z;
|
||||
const int present_seqlen = gridDim.y;
|
||||
|
||||
const int present_seqlen = gridDim.x;
|
||||
const int num_heads = blockDim.y;
|
||||
const int thread_stride = blockDim.x;
|
||||
const int present_batch_stride = present_seqlen * num_heads * H;
|
||||
const int row_stride = is_bsnh ? num_heads * H : H;
|
||||
const int present_head_stride = is_bsnh ? H : present_seqlen * H;
|
||||
|
||||
const int present_batch_stride = present_seqlen * num_heads * H;
|
||||
const int row_stride = is_bsnh ? num_heads * H : H;
|
||||
const int present_head_stride = is_bsnh ? H : present_seqlen * H;
|
||||
// past_kv: BPNH or BNPH
|
||||
// new_kv: BLNH
|
||||
// present_kv: BTNH or BNTH, where T = P + L
|
||||
const int past_seqlen = present_seqlen - new_seqlen;
|
||||
|
||||
// past_kv: BPNH or BNPH
|
||||
// new_kv: BLNH or BNLH
|
||||
// present_kv: BTNH or BNTH, where T = P + L
|
||||
const int past_seqlen = present_seqlen - new_seqlen;
|
||||
|
||||
while (h < H) {
|
||||
int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
|
||||
if (s < past_seqlen) {
|
||||
const int past_batch_stride = past_seqlen * num_heads * H;
|
||||
|
|
@ -135,21 +137,296 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen,
|
|||
const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
|
||||
present_kv[out_offset] = new_kv[in_offset];
|
||||
}
|
||||
h += thread_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Concat new to past in present. Supports past BSNH or past BNSH
|
||||
template <typename T>
|
||||
Status QkvToContext(
|
||||
const cudaDeviceProp& device_prop,
|
||||
cublasHandle_t& cublas,
|
||||
Stream* ort_stream,
|
||||
contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<T>& data) {
|
||||
assert(data.use_flash_attention);
|
||||
Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int kv_sequence_length = parameters.kv_sequence_length;
|
||||
const int present_sequence_length = parameters.present_sequence_length;
|
||||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
|
||||
|
||||
assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time.
|
||||
if (H * kv_num_heads <= max_threads_per_block) {
|
||||
const dim3 grid(present_sequence_length, batch_size, 1);
|
||||
const dim3 block(H, kv_num_heads, 1);
|
||||
ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
|
||||
reinterpret_cast<const float2*>(data.past_key),
|
||||
reinterpret_cast<const float2*>(data.key),
|
||||
reinterpret_cast<float2*>(data.present_key),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
|
||||
reinterpret_cast<const float2*>(data.past_value),
|
||||
reinterpret_cast<const float2*>(data.value),
|
||||
reinterpret_cast<float2*>(data.present_value),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
} else {
|
||||
int steps = (H * kv_num_heads + 255) / 256;
|
||||
const dim3 grid(steps, present_sequence_length, batch_size);
|
||||
const dim3 block(256, 1, 1);
|
||||
ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
|
||||
H,
|
||||
kv_num_heads,
|
||||
reinterpret_cast<const float2*>(data.past_key),
|
||||
reinterpret_cast<const float2*>(data.key),
|
||||
reinterpret_cast<float2*>(data.present_key),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
|
||||
H,
|
||||
kv_num_heads,
|
||||
reinterpret_cast<const float2*>(data.past_value),
|
||||
reinterpret_cast<const float2*>(data.value),
|
||||
reinterpret_cast<float2*>(data.present_value),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
}
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
// Kernel to append new kv to kv buffer in place
|
||||
template <typename T>
|
||||
__global__ void ConcatKVInPlace(const int past_seqlen,
|
||||
const int present_seqlen,
|
||||
T* kv_buff,
|
||||
const T* new_kv,
|
||||
const bool is_bsnh) { // refers to kv buff; otherwise bnsh
|
||||
const int h = threadIdx.x;
|
||||
const int n = threadIdx.y;
|
||||
const int s = blockIdx.x;
|
||||
const int b = blockIdx.y;
|
||||
|
||||
const int new_seqlen = gridDim.x;
|
||||
const int num_heads = blockDim.y;
|
||||
const int H = blockDim.x;
|
||||
|
||||
const int present_batch_stride = present_seqlen * num_heads * H;
|
||||
const int present_row_stride = is_bsnh ? num_heads * H : H;
|
||||
const int present_head_stride = is_bsnh ? H : present_seqlen * H;
|
||||
|
||||
// kv_buff: BTNH or BNTH with buffered memory for new
|
||||
// new_kv: BLNH
|
||||
|
||||
int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h;
|
||||
// Note: new KV always BSNH
|
||||
const int new_batch_stride = new_seqlen * num_heads * H;
|
||||
const int new_row_stride = num_heads * H;
|
||||
const int new_head_stride = H;
|
||||
const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h;
|
||||
kv_buff[out_offset] = new_kv[in_offset];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ConcatKVInPlaceLarge(const int past_seqlen,
|
||||
const int present_seqlen,
|
||||
const int H,
|
||||
const int num_heads,
|
||||
T* kv_buff,
|
||||
const T* new_kv,
|
||||
const bool is_bsnh) { // refers to kv buff; otherwise bnsh
|
||||
int i = threadIdx.x + (blockDim.x * blockIdx.x);
|
||||
if (i < H * num_heads) {
|
||||
const int h = i % H;
|
||||
const int n = i / H;
|
||||
const int s = blockIdx.y;
|
||||
const int b = blockIdx.z;
|
||||
const int new_seqlen = gridDim.y;
|
||||
|
||||
const int present_batch_stride = present_seqlen * num_heads * H;
|
||||
const int present_row_stride = is_bsnh ? num_heads * H : H;
|
||||
const int present_head_stride = is_bsnh ? H : present_seqlen * H;
|
||||
|
||||
// kv_buff: BTNH or BNTH with buffered memory for new
|
||||
// new_kv: BLNH
|
||||
|
||||
int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h;
|
||||
// Note: new KV always BSNH
|
||||
const int new_batch_stride = new_seqlen * num_heads * H;
|
||||
const int new_row_stride = num_heads * H;
|
||||
const int new_head_stride = H;
|
||||
const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h;
|
||||
kv_buff[out_offset] = new_kv[in_offset];
|
||||
}
|
||||
}
|
||||
|
||||
// Concat new to kv buffer in place
|
||||
template <typename T>
|
||||
Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int kv_sequence_length = parameters.kv_sequence_length;
|
||||
const int present_sequence_length = parameters.present_sequence_length;
|
||||
const int past_sequence_length = parameters.past_sequence_length;
|
||||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
|
||||
assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
const int H = head_size / 4;
|
||||
if (H * kv_num_heads <= max_threads_per_block) {
|
||||
const dim3 grid(kv_sequence_length, batch_size, 1);
|
||||
const dim3 block(H, kv_num_heads, 1);
|
||||
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(past_sequence_length,
|
||||
present_sequence_length,
|
||||
reinterpret_cast<float2*>(data.present_key),
|
||||
reinterpret_cast<const float2*>(data.key),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(past_sequence_length,
|
||||
present_sequence_length,
|
||||
reinterpret_cast<float2*>(data.present_value),
|
||||
reinterpret_cast<const float2*>(data.value),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
} else {
|
||||
int steps = int(ceil(float(H * kv_num_heads) / 256.0));
|
||||
const dim3 grid(steps, kv_sequence_length, batch_size);
|
||||
const dim3 block(256, 1, 1);
|
||||
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(past_sequence_length,
|
||||
present_sequence_length,
|
||||
H,
|
||||
kv_num_heads,
|
||||
reinterpret_cast<float2*>(data.present_key),
|
||||
reinterpret_cast<const float2*>(data.key),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(past_sequence_length,
|
||||
present_sequence_length,
|
||||
H,
|
||||
kv_num_heads,
|
||||
reinterpret_cast<float2*>(data.present_value),
|
||||
reinterpret_cast<const float2*>(data.value),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
}
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh
|
||||
template <typename T>
|
||||
__global__ void Ungroup(const T* kv_in,
|
||||
T* kv_out,
|
||||
const int in_seqlen,
|
||||
const int kv_num_heads,
|
||||
const bool is_bsnh) {
|
||||
const int h = threadIdx.x;
|
||||
const int out_n = threadIdx.y;
|
||||
const int s = blockIdx.x;
|
||||
const int b = blockIdx.y;
|
||||
|
||||
const int out_seqlen = gridDim.x;
|
||||
const int q_num_heads = blockDim.y;
|
||||
const int H = blockDim.x;
|
||||
|
||||
const int q_kv_head_ratio = q_num_heads / kv_num_heads;
|
||||
const int out_batch_stride = out_seqlen * q_num_heads * H;
|
||||
const int out_row_stride = is_bsnh ? q_num_heads * H : H;
|
||||
const int out_head_stride = is_bsnh ? H : out_seqlen * H;
|
||||
|
||||
const int in_batch_stride = in_seqlen * kv_num_heads * H;
|
||||
const int in_row_stride = is_bsnh ? kv_num_heads * H : H;
|
||||
const int in_head_stride = is_bsnh ? H : in_seqlen * H;
|
||||
const int in_n = out_n / q_kv_head_ratio;
|
||||
|
||||
const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h;
|
||||
const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h;
|
||||
kv_out[out_offset] = kv_in[in_offset];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void UngroupLarge(const T* kv_in,
|
||||
T* kv_out,
|
||||
const int H,
|
||||
const int in_seqlen,
|
||||
const int q_num_heads,
|
||||
const int kv_num_heads,
|
||||
const bool is_bsnh) {
|
||||
int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements
|
||||
if (i < H * q_num_heads) {
|
||||
const int out_seqlen = gridDim.y;
|
||||
const int s = blockIdx.y;
|
||||
const int b = blockIdx.z;
|
||||
|
||||
const int q_kv_head_ratio = q_num_heads / kv_num_heads;
|
||||
const int out_batch_stride = out_seqlen * q_num_heads * H;
|
||||
const int out_row_stride = is_bsnh ? q_num_heads * H : H;
|
||||
const int out_head_stride = is_bsnh ? H : out_seqlen * H;
|
||||
|
||||
const int in_batch_stride = in_seqlen * kv_num_heads * H;
|
||||
const int in_row_stride = is_bsnh ? kv_num_heads * H : H;
|
||||
const int in_head_stride = is_bsnh ? H : in_seqlen * H;
|
||||
|
||||
const int h = i % H;
|
||||
const int out_n = i / H;
|
||||
const int in_n = out_n / q_kv_head_ratio;
|
||||
const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h;
|
||||
const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h;
|
||||
kv_out[out_offset] = kv_in[in_offset];
|
||||
}
|
||||
}
|
||||
|
||||
// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it.
|
||||
Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters,
|
||||
float2* k_buff, float2* v_buff,
|
||||
const float2* k_og, const float2* v_og,
|
||||
const int buff_seqlen, const int og_seqlen,
|
||||
const bool is_bsnh,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
|
||||
const int H = head_size / 4;
|
||||
if (H * num_heads <= max_threads_per_block) {
|
||||
const dim3 grid(buff_seqlen, batch_size, 1);
|
||||
const dim3 block(H, num_heads, 1);
|
||||
Ungroup<float2><<<grid, block, 0, stream>>>(k_og,
|
||||
k_buff,
|
||||
og_seqlen,
|
||||
kv_num_heads,
|
||||
is_bsnh);
|
||||
Ungroup<float2><<<grid, block, 0, stream>>>(v_og,
|
||||
v_buff,
|
||||
og_seqlen,
|
||||
kv_num_heads,
|
||||
is_bsnh);
|
||||
} else {
|
||||
int steps = int(ceil(float(H * num_heads) / 256.0));
|
||||
const dim3 grid(steps, buff_seqlen, batch_size);
|
||||
const dim3 block(256, 1, 1);
|
||||
UngroupLarge<float2><<<grid, block, 0, stream>>>(k_og,
|
||||
k_buff,
|
||||
H,
|
||||
og_seqlen,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
is_bsnh);
|
||||
UngroupLarge<float2><<<grid, block, 0, stream>>>(v_og,
|
||||
v_buff,
|
||||
H,
|
||||
og_seqlen,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
is_bsnh);
|
||||
}
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
////////// Launch Kernels
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
|
||||
template <typename T>
|
||||
Status FlashAttention(
|
||||
const cudaDeviceProp& device_prop,
|
||||
cudaStream_t stream,
|
||||
contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<T>& data,
|
||||
float scale) {
|
||||
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
|
|
@ -160,108 +437,177 @@ Status QkvToContext(
|
|||
const int head_size = parameters.head_size;
|
||||
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
|
||||
|
||||
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(head_size)) : parameters.scale;
|
||||
if (data.use_flash_attention) {
|
||||
assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
assert(parameters.num_heads % parameters.kv_num_heads == 0);
|
||||
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
|
||||
void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
|
||||
void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
|
||||
|
||||
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
|
||||
void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
|
||||
void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
|
||||
bool is_causal = parameters.is_unidirectional;
|
||||
|
||||
bool is_causal = parameters.is_unidirectional;
|
||||
if (data.past_key != nullptr && data.past_key == data.present_key) {
|
||||
// Share buffer case
|
||||
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
|
||||
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
|
||||
|
||||
if (data.past_key == nullptr && data.present_key == nullptr) {
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
|
||||
device_prop, stream, query, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
|
||||
parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, head_size,
|
||||
parameters.sequence_length, parameters.kv_sequence_length, scale, is_causal, parameters.num_splits,
|
||||
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum)));
|
||||
// Launch kernel to copy seqlen
|
||||
int thr_per_blk = 256;
|
||||
int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
|
||||
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
|
||||
|
||||
} else if (data.past_key == data.present_key) {
|
||||
// Assume past and present kv share buffer.
|
||||
assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
assert(parameters.past_sequence_length >= 0);
|
||||
assert(data.past_value != nullptr);
|
||||
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
|
||||
device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
|
||||
reinterpret_cast<void*>(data.seqlens_k), batch_size, num_heads, kv_num_heads,
|
||||
head_size, sequence_length, present_sequence_length, kv_sequence_length,
|
||||
scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
|
||||
reinterpret_cast<void*>(data.out_accum)));
|
||||
|
||||
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
|
||||
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
|
||||
} else {
|
||||
// Not share buffer or no past (prompt generation)
|
||||
// Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
|
||||
|
||||
// Launch kernel to copy seqlen
|
||||
int thr_per_blk = 256;
|
||||
int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
|
||||
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
|
||||
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
|
||||
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("seqlens_k", data.seqlens_k, 1, batch_size);
|
||||
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
|
||||
device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast<void*>(data.softmax_lse),
|
||||
batch_size, num_heads, kv_num_heads, head_size,
|
||||
sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits,
|
||||
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum), past_bsnh));
|
||||
}
|
||||
|
||||
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
|
||||
device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
|
||||
reinterpret_cast<void*>(data.seqlens_k), batch_size, num_heads, kv_num_heads,
|
||||
head_size, sequence_length, present_sequence_length, kv_sequence_length,
|
||||
scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
|
||||
reinterpret_cast<void*>(data.out_accum)));
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
|
||||
|
||||
} else if (data.present_key != nullptr && (data.past_key != nullptr || kv_sequence_length == present_sequence_length)) {
|
||||
assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
// Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient
|
||||
if (head_size % 4 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "requires head_size be divisible by 4");
|
||||
}
|
||||
const int H = head_size / 4;
|
||||
if (H * kv_num_heads <= max_threads_per_block) {
|
||||
const dim3 grid(present_sequence_length, batch_size, 1);
|
||||
const dim3 block(H, kv_num_heads, 1);
|
||||
ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
|
||||
reinterpret_cast<const float2*>(data.past_key),
|
||||
reinterpret_cast<const float2*>(data.key),
|
||||
reinterpret_cast<float2*>(data.present_key),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
|
||||
reinterpret_cast<const float2*>(data.past_value),
|
||||
reinterpret_cast<const float2*>(data.value),
|
||||
reinterpret_cast<float2*>(data.present_value),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
} else {
|
||||
const dim3 grid(present_sequence_length, batch_size, 1);
|
||||
const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1);
|
||||
ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
|
||||
H,
|
||||
reinterpret_cast<const float2*>(data.past_key),
|
||||
reinterpret_cast<const float2*>(data.key),
|
||||
reinterpret_cast<float2*>(data.present_key),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
|
||||
H,
|
||||
reinterpret_cast<const float2*>(data.past_value),
|
||||
reinterpret_cast<const float2*>(data.value),
|
||||
reinterpret_cast<float2*>(data.present_value),
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
|
||||
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
template <typename T>
|
||||
Status EfficientAttention(
|
||||
const cudaDeviceProp& device_prop,
|
||||
cudaStream_t stream,
|
||||
contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<T>& data,
|
||||
float scale) {
|
||||
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int kv_sequence_length = parameters.kv_sequence_length;
|
||||
const int past_sequence_length = parameters.past_sequence_length;
|
||||
const int present_sequence_length = parameters.present_sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
|
||||
|
||||
// Launch kernel to copy seqlen
|
||||
int thr_per_blk = 256;
|
||||
int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
|
||||
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
|
||||
|
||||
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
|
||||
device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast<void*>(data.softmax_lse),
|
||||
batch_size, num_heads, kv_num_heads, head_size,
|
||||
sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits,
|
||||
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum), past_bsnh));
|
||||
const void* query = reinterpret_cast<const void*>(data.query);
|
||||
const void* key = reinterpret_cast<const void*>(data.key);
|
||||
const void* value = reinterpret_cast<const void*>(data.value);
|
||||
if (data.past_key != nullptr) {
|
||||
// Past key case
|
||||
// concatenate new kv to past kv
|
||||
if (data.past_key == data.present_key) {
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
|
||||
}
|
||||
const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
|
||||
if (num_heads == kv_num_heads) {
|
||||
// Use present kv directly if not grouped
|
||||
key = reinterpret_cast<const void*>(data.present_key);
|
||||
value = reinterpret_cast<const void*>(data.present_value);
|
||||
} else {
|
||||
// Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path
|
||||
float2* k_buff = reinterpret_cast<float2*>(data.k);
|
||||
float2* v_buff = reinterpret_cast<float2*>(data.v);
|
||||
const float2* k_og = reinterpret_cast<const float2*>(data.present_key);
|
||||
const float2* v_og = reinterpret_cast<const float2*>(data.present_value);
|
||||
ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, past_sequence_length + kv_sequence_length,
|
||||
present_sequence_length, is_bsnh, stream, max_threads_per_block));
|
||||
key = reinterpret_cast<const void*>(data.k);
|
||||
value = reinterpret_cast<const void*>(data.v);
|
||||
}
|
||||
} else if (num_heads == kv_num_heads) {
|
||||
// no past or present and no need to ungroup... still copy kv into present buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
|
||||
key = reinterpret_cast<const void*>(data.present_key);
|
||||
value = reinterpret_cast<const void*>(data.present_value);
|
||||
} else {
|
||||
// intermediate buffer so q and kv have same num heads... still copy kv into present buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
|
||||
float2* k_buff = reinterpret_cast<float2*>(data.k);
|
||||
float2* v_buff = reinterpret_cast<float2*>(data.v);
|
||||
const float2* k_og = reinterpret_cast<const float2*>(data.present_key);
|
||||
const float2* v_og = reinterpret_cast<const float2*>(data.present_value);
|
||||
ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, kv_sequence_length,
|
||||
kv_sequence_length, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH, stream,
|
||||
max_threads_per_block));
|
||||
key = reinterpret_cast<const void*>(data.k);
|
||||
value = reinterpret_cast<const void*>(data.v);
|
||||
}
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
|
||||
MemoryEfficientAttentionParams p;
|
||||
p.sm = device_prop.major * 10 + device_prop.minor;
|
||||
p.is_half = sizeof(T) == 2;
|
||||
p.batch_size = batch_size;
|
||||
p.num_heads = num_heads;
|
||||
p.sequence_length = sequence_length;
|
||||
p.kv_sequence_length = past_sequence_length + kv_sequence_length;
|
||||
p.max_sequence_length = (num_heads == kv_num_heads) ? present_sequence_length : past_sequence_length + kv_sequence_length;
|
||||
p.qk_head_size = head_size;
|
||||
p.v_head_size = head_size;
|
||||
p.causal = parameters.is_unidirectional;
|
||||
p.scale = scale;
|
||||
p.seqlen_k_ptr = nullptr;
|
||||
p.seqstart_q_ptr = nullptr;
|
||||
p.seqstart_k_ptr = nullptr;
|
||||
p.query = query;
|
||||
p.key = key;
|
||||
p.value = value;
|
||||
p.attn_bias = nullptr;
|
||||
p.is_attn_bias_batched = false;
|
||||
p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
|
||||
p.output = data.output;
|
||||
p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float))
|
||||
? data.fmha_buffer
|
||||
: nullptr;
|
||||
p.stream = stream;
|
||||
run_memory_efficient_attention(p);
|
||||
|
||||
return Status::OK();
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
////////// API Functions
|
||||
|
||||
template <typename T>
|
||||
Status QkvToContext(
|
||||
const cudaDeviceProp& device_prop,
|
||||
cublasHandle_t& cublas,
|
||||
Stream* ort_stream,
|
||||
contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<T>& data) {
|
||||
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
|
||||
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size)) : parameters.scale;
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
if (data.use_flash_attention) {
|
||||
return FlashAttention(device_prop, stream, parameters, data, scale);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
if (data.use_memory_efficient_attention) {
|
||||
return EfficientAttention(device_prop, stream, parameters, data, scale);
|
||||
}
|
||||
#endif
|
||||
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet.");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,19 +14,28 @@ namespace cuda {
|
|||
|
||||
template <typename T>
|
||||
struct GroupQueryAttentionData {
|
||||
// Input Tensors
|
||||
const T* query = nullptr;
|
||||
const T* key = nullptr;
|
||||
const T* value = nullptr;
|
||||
const T* past_key = nullptr;
|
||||
const T* past_value = nullptr;
|
||||
// Flash buffers
|
||||
T* softmax_lse = nullptr;
|
||||
T* softmax_lse_accum = nullptr;
|
||||
T* out_accum = nullptr;
|
||||
int* seqlens_k = nullptr;
|
||||
// Memory Efficient buffers
|
||||
T* fmha_buffer = nullptr;
|
||||
T* k = nullptr;
|
||||
T* v = nullptr;
|
||||
// Output Tensors
|
||||
T* output = nullptr;
|
||||
T* present_key = nullptr;
|
||||
T* present_value = nullptr;
|
||||
// Kernel Flags
|
||||
bool use_flash_attention = false;
|
||||
bool use_memory_efficient_attention = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -507,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass(
|
|||
MemoryEfficientAttentionParams p;
|
||||
p.sm = device_prop.major * 10 + device_prop.minor;
|
||||
p.is_half = sizeof(T) == 2;
|
||||
p.is_kv_bsnh = true;
|
||||
p.batch_size = parameters.batch_size;
|
||||
p.num_heads = parameters.num_heads;
|
||||
p.sequence_length = parameters.sequence_length;
|
||||
p.kv_sequence_length = parameters.sequence_length;
|
||||
p.max_sequence_length = parameters.sequence_length;
|
||||
p.qk_head_size = parameters.head_size;
|
||||
p.v_head_size = parameters.v_head_size;
|
||||
p.causal = false;
|
||||
|
|
|
|||
|
|
@ -688,6 +688,7 @@ Status FusedAttentionCutlass(
|
|||
p.num_heads = parameters.num_heads;
|
||||
p.sequence_length = parameters.sequence_length;
|
||||
p.kv_sequence_length = parameters.sequence_length;
|
||||
p.max_sequence_length = parameters.sequence_length;
|
||||
p.qk_head_size = parameters.head_size;
|
||||
p.v_head_size = parameters.v_head_size;
|
||||
p.causal = false;
|
||||
|
|
@ -702,6 +703,7 @@ Status FusedAttentionCutlass(
|
|||
p.attn_bias = data.relative_position_bias;
|
||||
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
|
||||
p.output = data.output;
|
||||
p.is_kv_bsnh = true;
|
||||
p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float))
|
||||
? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v)))
|
||||
: nullptr;
|
||||
|
|
|
|||
|
|
@ -1041,15 +1041,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
"present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key"
|
||||
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
|
||||
"kv_sequence_length.",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
"T")
|
||||
.Output(2,
|
||||
"present_value",
|
||||
"present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value"
|
||||
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
|
||||
"kv_sequence_length.",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
"T")
|
||||
.TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.")
|
||||
.TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
|
|
|
|||
|
|
@ -147,6 +147,7 @@ class SymbolicShapeInference:
|
|||
"GatherElements": self._infer_GatherElements,
|
||||
"GatherND": self._infer_GatherND,
|
||||
"Identity": self._pass_on_shape_and_type,
|
||||
"AllReduce": self._pass_on_shape_and_type,
|
||||
"If": self._infer_If,
|
||||
"Loop": self._infer_Loop,
|
||||
"MatMul": self._infer_MatMul,
|
||||
|
|
|
|||
|
|
@ -1272,7 +1272,7 @@ def find_past_seq_len_usage(subg: GraphProto):
|
|||
return tensor_names_to_rename, nodes_to_remove
|
||||
|
||||
|
||||
def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0):
|
||||
def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1):
|
||||
past_seq_len = past_seq_len_input
|
||||
if past_seq_len not in model.get_graphs_input_names():
|
||||
# Add model input for past sequence length
|
||||
|
|
@ -1282,6 +1282,10 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads
|
|||
# Replace MultiHeadAttention with GroupQueryAttention
|
||||
for node in model.model.graph.node:
|
||||
if node.op_type == "MultiHeadAttention":
|
||||
num_heads_mha = 0
|
||||
for att in node.attribute:
|
||||
if att.name == "num_heads":
|
||||
num_heads_mha = att.i
|
||||
gqa_node = onnx.helper.make_node(
|
||||
"GroupQueryAttention",
|
||||
inputs=[
|
||||
|
|
@ -1295,8 +1299,8 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads
|
|||
outputs=node.output,
|
||||
name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
|
||||
domain="com.microsoft",
|
||||
num_heads=node.attribute[0].i,
|
||||
kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads,
|
||||
num_heads=num_heads_mha // world_size,
|
||||
kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
|
||||
is_past_bsnh=0,
|
||||
)
|
||||
model.model.graph.node.remove(node)
|
||||
|
|
|
|||
|
|
@ -130,3 +130,8 @@ class Fusion:
|
|||
for node in nodes:
|
||||
if node not in self.nodes_to_remove:
|
||||
self.nodes_to_remove.append(node)
|
||||
|
||||
def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]):
|
||||
for node in nodes:
|
||||
if node not in self.nodes_to_remove and node not in nodes_to_keep:
|
||||
self.nodes_to_remove.append(node)
|
||||
|
|
|
|||
|
|
@ -323,6 +323,7 @@ class FusionRotaryAttention(FusionAttention):
|
|||
|
||||
# qkv_nodes_1 is for LLaMA-2 Microsoft
|
||||
# qkv_nodes_2 is for LLaMA-2 Hugging Face
|
||||
# qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model
|
||||
qkv_nodes = None
|
||||
qkv_nodes_1 = self.model.match_parent_path(
|
||||
normalize_node,
|
||||
|
|
@ -334,18 +335,27 @@ class FusionRotaryAttention(FusionAttention):
|
|||
["MatMul", "Reshape", "Transpose", "MatMul"],
|
||||
[1, 0, 0, 0],
|
||||
)
|
||||
qkv_nodes_3 = self.model.match_parent_path(
|
||||
normalize_node,
|
||||
["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"],
|
||||
[1, 0, 0, 0, 0],
|
||||
)
|
||||
if qkv_nodes_1 is not None:
|
||||
_, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
|
||||
qkv_nodes = qkv_nodes_1
|
||||
elif qkv_nodes_2 is not None:
|
||||
_, reshape_qkv, _, matmul_qkv = qkv_nodes_2
|
||||
qkv_nodes = qkv_nodes_2
|
||||
elif qkv_nodes_3 is not None:
|
||||
_, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3
|
||||
qkv_nodes = qkv_nodes_3
|
||||
else:
|
||||
logger.debug("fuse_rotary_attention: failed to match qkv nodes")
|
||||
return
|
||||
|
||||
# v_nodes_1 is for LLaMA-2 Microsoft
|
||||
# v_nodes_3 is for LLaMA-2 Hugging Face
|
||||
# v_nodes_4 is for LLaMA-2 70B model
|
||||
past_v, present_v, past_seq_len = "", "", ""
|
||||
v_nodes = None
|
||||
v_nodes_1 = self.model.match_parent_path(
|
||||
|
|
@ -363,6 +373,118 @@ class FusionRotaryAttention(FusionAttention):
|
|||
["Transpose", "Reshape", "MatMul"],
|
||||
[1, 0, 0],
|
||||
)
|
||||
_, v_nodes_4, _ = self.model.match_parent_paths_all(
|
||||
matmul_qkv,
|
||||
[
|
||||
(
|
||||
["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"],
|
||||
[1, 0, 0, 0, 1, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Reshape",
|
||||
"Expand",
|
||||
"Where",
|
||||
"Equal",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Reshape",
|
||||
"Expand",
|
||||
"Where",
|
||||
"Equal",
|
||||
"Mul",
|
||||
"ConstantOfShape",
|
||||
"Shape",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Reshape",
|
||||
"Expand",
|
||||
"Where",
|
||||
"ConstantOfShape",
|
||||
"Shape",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Reshape",
|
||||
"Expand",
|
||||
"Where",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0],
|
||||
),
|
||||
(
|
||||
["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
|
||||
[1, 1, 0, 0, 0, 0, 1, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Mul",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
|
||||
),
|
||||
(
|
||||
["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
|
||||
[1, 1, 2, 0, 0, 0, 1, 0, 0],
|
||||
),
|
||||
(
|
||||
["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
|
||||
[1, 1, 3, 0, 0, 0, 1, 0, 0],
|
||||
),
|
||||
],
|
||||
output_name_to_node=None,
|
||||
)
|
||||
if v_nodes_1 is not None:
|
||||
reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
|
||||
v_nodes = v_nodes_1
|
||||
|
|
@ -388,6 +510,11 @@ class FusionRotaryAttention(FusionAttention):
|
|||
transpose_v, reshape_v, matmul_v = v_nodes_3
|
||||
v_nodes = v_nodes_3
|
||||
present_v = transpose_v.output[0]
|
||||
elif v_nodes_4 is not None and len(v_nodes_4) == 9:
|
||||
concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:]
|
||||
v_nodes = v_nodes_4
|
||||
past_v = concat_v.input[0]
|
||||
present_v = concat_v.output[0]
|
||||
else:
|
||||
logger.debug("fuse_rotary_attention: failed to match v path")
|
||||
return
|
||||
|
|
@ -461,6 +588,7 @@ class FusionRotaryAttention(FusionAttention):
|
|||
|
||||
# k_nodes_1 is for LLaMA-2 Microsoft
|
||||
# k_nodes_2 is for LLaMA-2 Hugging Face
|
||||
# k_nodes_4 is for LLaMA-2 70B Hugging Face
|
||||
past_k, present_k = "", ""
|
||||
k_nodes = None
|
||||
k_nodes_1 = self.model.match_parent_path(
|
||||
|
|
@ -478,6 +606,174 @@ class FusionRotaryAttention(FusionAttention):
|
|||
["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
|
||||
[1, 0, 1, 0, 0, 0],
|
||||
)
|
||||
_, k_nodes_4, _ = self.model.match_parent_paths_all(
|
||||
matmul_qk,
|
||||
[
|
||||
(
|
||||
[
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"Expand",
|
||||
"Unsqueeze",
|
||||
"Concat",
|
||||
"RotaryEmbedding",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 0, 0, 0, 1, 0, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"Expand",
|
||||
"Where",
|
||||
"Equal",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"RotaryEmbedding",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"Expand",
|
||||
"Where",
|
||||
"Equal",
|
||||
"Mul",
|
||||
"ConstantOfShape",
|
||||
"Shape",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"RotaryEmbedding",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"Expand",
|
||||
"Where",
|
||||
"ConstantOfShape",
|
||||
"Shape",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"RotaryEmbedding",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"Expand",
|
||||
"Where",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"RotaryEmbedding",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"RotaryEmbedding",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Mul",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"RotaryEmbedding",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"RotaryEmbedding",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0],
|
||||
),
|
||||
(
|
||||
[
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"Concat",
|
||||
"Unsqueeze",
|
||||
"Gather",
|
||||
"Shape",
|
||||
"Concat",
|
||||
"RotaryEmbedding",
|
||||
"Transpose",
|
||||
"Reshape",
|
||||
"MatMul",
|
||||
],
|
||||
[1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0],
|
||||
),
|
||||
],
|
||||
output_name_to_node=None,
|
||||
)
|
||||
if k_nodes_1 is not None:
|
||||
reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
|
||||
k_nodes = k_nodes_1
|
||||
|
|
@ -505,6 +801,12 @@ class FusionRotaryAttention(FusionAttention):
|
|||
k_nodes = k_nodes_3
|
||||
past_k = concat_k.input[0]
|
||||
present_k = concat_k.output[0]
|
||||
elif k_nodes_4 is not None and len(k_nodes_4) == 9:
|
||||
reshape_k, matmul_k = k_nodes_4[0][-2:]
|
||||
concat_k, rotary_k = k_nodes_4[0][-5:-3]
|
||||
k_nodes = k_nodes_4
|
||||
past_k = concat_k.input[0]
|
||||
present_k = concat_k.output[0]
|
||||
else:
|
||||
logger.debug("fuse_rotary_attention: failed to match k nodes")
|
||||
return
|
||||
|
|
@ -552,7 +854,7 @@ class FusionRotaryAttention(FusionAttention):
|
|||
return
|
||||
root_output = reshape_qkv_2.output[0]
|
||||
|
||||
elif qkv_nodes == qkv_nodes_2:
|
||||
elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3):
|
||||
if not self.check_runtime_shape_paths_for_nodes(
|
||||
reshape_qkv,
|
||||
reshape_q,
|
||||
|
|
@ -573,6 +875,9 @@ class FusionRotaryAttention(FusionAttention):
|
|||
# Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
|
||||
rotary_k.output[0] = rotary_k.name + "_output_0"
|
||||
|
||||
if qkv_nodes == qkv_nodes_3:
|
||||
qkv_nodes = qkv_nodes[1:]
|
||||
|
||||
new_node = self.create_mha_node(
|
||||
matmul_q.input[0],
|
||||
root_output,
|
||||
|
|
@ -594,7 +899,14 @@ class FusionRotaryAttention(FusionAttention):
|
|||
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
|
||||
|
||||
self.nodes_to_remove.extend(qkv_nodes[1:])
|
||||
self.nodes_to_remove.extend(v_nodes[:-1])
|
||||
|
||||
if v_nodes != v_nodes_4:
|
||||
self.nodes_to_remove.extend(v_nodes[:-1])
|
||||
else:
|
||||
nodes_to_keep = [v_nodes[0][-1]]
|
||||
for temp_path in v_nodes:
|
||||
self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
|
||||
|
||||
self.nodes_to_remove.extend(qk_nodes)
|
||||
|
||||
if k_nodes == k_nodes_1:
|
||||
|
|
@ -608,6 +920,10 @@ class FusionRotaryAttention(FusionAttention):
|
|||
self.nodes_to_remove.append(k_nodes[1])
|
||||
self.nodes_to_remove.append(k_nodes[3])
|
||||
self.nodes_to_remove.append(k_nodes[4])
|
||||
elif k_nodes == k_nodes_4:
|
||||
nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
|
||||
for temp_path in k_nodes:
|
||||
self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
|
||||
|
||||
if q_nodes == q_nodes_1:
|
||||
self.nodes_to_remove.extend(q_nodes[:-2])
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ Please note the package versions needed for using LLaMA-2 in the `requirements.t
|
|||
- Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file.
|
||||
- `requirements-quant.txt`
|
||||
- For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor)
|
||||
- `requirements-70b-model.txt`
|
||||
- For running the LLaMA-2 70B model on multiple GPUs
|
||||
- `requirements.txt`
|
||||
- Package versions needed in each of the above files
|
||||
|
||||
|
|
@ -79,6 +81,15 @@ model.save_pretrained(name.split("/")[-1] + "-onnx")
|
|||
|
||||
Here are some additional examples for exporting LLaMA-2.
|
||||
|
||||
Export Model with Different GPU Device Ids
|
||||
```
|
||||
# From source using first GPU:
|
||||
$ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b
|
||||
|
||||
# From wheel using second GPU:
|
||||
$ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b
|
||||
```
|
||||
|
||||
Export Saved Model on Disk
|
||||
```
|
||||
# From source:
|
||||
|
|
@ -153,6 +164,19 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l
|
|||
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu
|
||||
```
|
||||
|
||||
Export LLaMA-2 70B sharded model into 4 partitions
|
||||
```
|
||||
# From source:
|
||||
# 1. Install necessary packages from requirements-70b-model.txt
|
||||
|
||||
# 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command:
|
||||
$ ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/
|
||||
|
||||
# 3. Shard and export the LLaMA-2 70B model. With FP16, you will need at least 140GB of GPU memory to load the model. Therefore, you will need at least 4 40GB A100 GPUs or 2 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command:
|
||||
$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda
|
||||
|
||||
```
|
||||
|
||||
## Benchmark LLaMA-2
|
||||
|
||||
Here are some examples of how you can benchmark LLaMA-2.
|
||||
|
|
@ -220,11 +244,11 @@ python3 -m models.llama.benchmark \
|
|||
--device cuda
|
||||
```
|
||||
|
||||
6. ONNX Runtime, FP32, convert_to_onnx
|
||||
6. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU
|
||||
```
|
||||
python3 -m models.llama.benchmark \
|
||||
CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \
|
||||
--benchmark-type ort-convert-to-onnx \
|
||||
--ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
|
||||
--ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
|
||||
--model-name meta-llama/Llama-2-7b-hf \
|
||||
--precision fp32 \
|
||||
--batch-sizes "1 2" \
|
||||
|
|
@ -232,11 +256,11 @@ python3 -m models.llama.benchmark \
|
|||
--device cpu
|
||||
```
|
||||
|
||||
7. ONNX Runtime, FP16, convert_to_onnx
|
||||
7. ONNX Runtime, FP16, convert_to_onnx, use 5th GPU
|
||||
```
|
||||
python3 -m models.llama.benchmark \
|
||||
CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \
|
||||
--benchmark-type ort-convert-to-onnx \
|
||||
--ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
|
||||
--ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
|
||||
--model-name meta-llama/Llama-2-7b-hf \
|
||||
--precision fp16 \
|
||||
--batch-sizes "1 2" \
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import numpy as np
|
|||
import onnx
|
||||
import psutil
|
||||
import torch
|
||||
from dist_settings import get_rank, get_size
|
||||
from llama_inputs import (
|
||||
add_io_bindings,
|
||||
get_merged_sample_with_past_kv_inputs,
|
||||
|
|
@ -133,6 +134,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
use_fp16=args.use_fp16,
|
||||
engine="ort",
|
||||
return_dict=True,
|
||||
world_size=args.world_size,
|
||||
)
|
||||
iter_inputs = get_merged_sample_with_past_kv_inputs(
|
||||
args.config,
|
||||
|
|
@ -144,6 +146,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
|||
use_fp16=args.use_fp16,
|
||||
engine="ort",
|
||||
return_dict=True,
|
||||
world_size=args.world_size,
|
||||
)
|
||||
|
||||
elif args.benchmark_type == "ort-msft":
|
||||
|
|
@ -244,10 +247,10 @@ def get_model(args: argparse.Namespace):
|
|||
|
||||
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
|
||||
# Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
|
||||
logger.info(f"Loading model from {args.ort_model_path}")
|
||||
logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
|
||||
start_time = time.time()
|
||||
model = ort.InferenceSession(
|
||||
args.ort_model_path,
|
||||
args.ort_model_path.format(args.rank),
|
||||
sess_options,
|
||||
providers=[args.execution_provider],
|
||||
)
|
||||
|
|
@ -315,10 +318,11 @@ def time_fn(args, fn, inputs):
|
|||
latency = total_time / args.num_runs
|
||||
throughput = args.batch_size / latency
|
||||
|
||||
logger.info(f"Batch Size: {args.batch_size}")
|
||||
logger.info(f"Sequence Length: {args.sequence_length}")
|
||||
logger.info(f"Latency: {latency} s")
|
||||
logger.info(f"Throughput: {throughput} tps")
|
||||
if args.rank == 0:
|
||||
logger.info(f"Batch Size: {args.batch_size}")
|
||||
logger.info(f"Sequence Length: {args.sequence_length}")
|
||||
logger.info(f"Latency: {latency} s")
|
||||
logger.info(f"Throughput: {throughput} tps")
|
||||
return
|
||||
|
||||
|
||||
|
|
@ -358,7 +362,8 @@ def measure_fn(args, fn, inputs):
|
|||
process.cpu_percent(interval=0.1)
|
||||
|
||||
fn(inputs)
|
||||
logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
|
||||
if args.rank == 0:
|
||||
logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
|
||||
|
||||
# Measure memory usage
|
||||
gc.collect()
|
||||
|
|
@ -451,7 +456,7 @@ def run_ort_inference(args, init_inputs, iter_inputs, model):
|
|||
# Add IO bindings for non-CPU execution providers
|
||||
if args.device != "cpu":
|
||||
io_binding, kv_cache_ortvalues = add_io_bindings(
|
||||
model, inputs, args.device, int(args.device_id), kv_cache_ortvalues
|
||||
model, inputs, args.device, int(args.rank), kv_cache_ortvalues
|
||||
)
|
||||
setattr(args, "io_binding", io_binding) # noqa: B010
|
||||
return io_binding, kv_cache_ortvalues
|
||||
|
|
@ -511,7 +516,7 @@ def run_inference(args, init_inputs, iter_inputs, model):
|
|||
raise Exception(f"Cannot recognize {args.benchmark_type}")
|
||||
|
||||
|
||||
def get_args():
|
||||
def get_args(rank=0):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-bt",
|
||||
|
|
@ -569,7 +574,7 @@ def get_args():
|
|||
parser.add_argument(
|
||||
"-s",
|
||||
"--sequence-lengths",
|
||||
default="8 16 32 64 128 256 512",
|
||||
default="32 64 128 256 512",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
|
|
@ -606,9 +611,9 @@ def get_args():
|
|||
if "ort" in args.benchmark_type:
|
||||
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
|
||||
if args.execution_provider == "CUDAExecutionProvider":
|
||||
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
|
||||
args.execution_provider = (args.execution_provider, {"device_id": rank})
|
||||
elif args.execution_provider == "ROCMExecutionProvider":
|
||||
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
|
||||
args.execution_provider = (args.execution_provider, {"device_id": rank})
|
||||
args.device = "cuda"
|
||||
|
||||
# Check that paths have been specified for any benchmarking with ORT
|
||||
|
|
@ -635,14 +640,19 @@ def get_args():
|
|||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
rank = get_rank()
|
||||
world_size = get_size()
|
||||
|
||||
args = get_args(rank)
|
||||
setup_logger(args.verbose)
|
||||
logger.info(args.__dict__)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
args.rank = rank
|
||||
args.world_size = world_size
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.model_name)
|
||||
config = LlamaConfig.from_pretrained(args.model_name)
|
||||
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
|
||||
target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
|
||||
use_fp16 = args.precision == "fp16"
|
||||
|
||||
setattr(args, "tokenizer", tokenizer) # noqa: B010
|
||||
|
|
@ -656,7 +666,7 @@ def main():
|
|||
|
||||
# Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
|
||||
if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
|
||||
onnx_model = onnx.load_model(args.ort_model_path, load_external_data=False)
|
||||
onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
|
||||
gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
|
||||
|
||||
use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
|
||||
|
|
@ -666,7 +676,8 @@ def main():
|
|||
|
||||
# Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
|
||||
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
|
||||
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
|
||||
if args.rank == 0:
|
||||
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
|
||||
setattr(args, "batch_size", int(batch_size)) # noqa: B010
|
||||
setattr(args, "sequence_length", int(sequence_length)) # noqa: B010
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
#!/bin/bash
|
||||
|
||||
NUM_GPUS=${1:-1}
|
||||
|
||||
MPI="mpirun --allow-run-as-root
|
||||
-mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0
|
||||
--tag-output --npernode $NUM_GPUS --bind-to numa
|
||||
-x MIOPEN_FIND_MODE=1"
|
||||
|
||||
CMD="$MPI python benchmark.py ${@:2}"
|
||||
|
||||
$CMD
|
||||
|
|
@ -247,6 +247,7 @@ def main():
|
|||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
all_results = []
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
|
||||
|
||||
# Benchmark PyTorch without torch.compile
|
||||
if args.hf_pt_eager:
|
||||
|
|
@ -266,8 +267,6 @@ def main():
|
|||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
|
|
@ -298,8 +297,6 @@ def main():
|
|||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
|
|
@ -332,8 +329,6 @@ def main():
|
|||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
|
|
@ -366,8 +361,6 @@ def main():
|
|||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
|
|
@ -399,8 +392,6 @@ def main():
|
|||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
#!/bin/bash
|
||||
|
||||
NUM_GPUS=${1:-1}
|
||||
|
||||
MPI="mpirun --allow-run-as-root
|
||||
-mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0
|
||||
--tag-output --npernode $NUM_GPUS --bind-to numa
|
||||
-x MIOPEN_FIND_MODE=1"
|
||||
|
||||
CMD="$MPI python convert_to_onnx.py ${@:2}"
|
||||
|
||||
$CMD
|
||||
|
|
@ -1,16 +1,16 @@
|
|||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from benchmark_helper import Precision, prepare_environment, setup_logger
|
||||
from convert_generation import replace_mha_with_gqa
|
||||
from dist_settings import barrier, get_rank, get_size, init_dist
|
||||
from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
|
||||
from llama_parity import main as parity_check
|
||||
from llama_torch import setup_torch_model
|
||||
from onnx_model import OnnxModel
|
||||
from optimizer import optimize_model
|
||||
from packaging import version
|
||||
|
|
@ -18,8 +18,11 @@ from transformers import LlamaConfig, LlamaForCausalLM
|
|||
|
||||
from onnxruntime import quantization as ort_quantization
|
||||
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
|
||||
from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger
|
||||
from onnxruntime.transformers.convert_generation import replace_mha_with_gqa
|
||||
|
||||
logger = logging.getLogger("")
|
||||
init_dist()
|
||||
|
||||
|
||||
def get_model_dynamic_axes(input_names: List[str], output_names: List[str]):
|
||||
|
|
@ -129,7 +132,9 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st
|
|||
# del onnx_model
|
||||
# temp_dir.cleanup()
|
||||
#
|
||||
def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
|
||||
def run_dynamo_export(
|
||||
args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1
|
||||
):
|
||||
from torch._dynamo import config
|
||||
|
||||
config.capture_scalar_outputs = True
|
||||
|
|
@ -150,9 +155,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll
|
|||
onnx.checker.check_model(temp_path)
|
||||
onnx.shape_inference.infer_shapes_path(temp_path)
|
||||
|
||||
output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
|
||||
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
|
||||
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
||||
save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_model_fp32.onnx.data")
|
||||
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data")
|
||||
del onnx_model
|
||||
os.system(
|
||||
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
|
||||
|
|
@ -160,7 +165,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll
|
|||
|
||||
# Export decoder_with_past_model.onnx
|
||||
input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs(
|
||||
l_config, device, batch_size, sequence_length
|
||||
l_config, device, batch_size, sequence_length, world_size=world_size
|
||||
)
|
||||
temp_dir = args.output # tempfile.TemporaryDirectory()
|
||||
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
|
||||
|
|
@ -172,9 +177,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll
|
|||
onnx.checker.check_model(temp_path)
|
||||
onnx.shape_inference.infer_shapes_path(temp_path)
|
||||
|
||||
output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx")
|
||||
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx")
|
||||
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
||||
save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_with_past_model_fp32.onnx.data")
|
||||
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data")
|
||||
del onnx_model
|
||||
os.system(
|
||||
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
|
||||
|
|
@ -183,10 +188,21 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll
|
|||
logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")
|
||||
|
||||
|
||||
def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
|
||||
def _prepare_dir(dir_path):
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
|
||||
|
||||
def run_torchscript_separate_export(
|
||||
args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1
|
||||
):
|
||||
# Dummy values for export
|
||||
batch_size, sequence_length = 2, 8
|
||||
device = torch.device("cpu")
|
||||
|
||||
# set device used to export model
|
||||
# for llama-2-70b we will use current gpus to speed up export process
|
||||
# for other models, we will use CPU to make sure we have enough memory to do export
|
||||
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
|
||||
|
||||
# Export decoder_model.onnx
|
||||
decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length)
|
||||
|
|
@ -199,8 +215,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon
|
|||
),
|
||||
]
|
||||
dynamic_axes = get_model_dynamic_axes(input_names, output_names)
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
temp_path = os.path.join(temp_dir.name, "temp.onnx")
|
||||
|
||||
# Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large.
|
||||
# Use temp folder per rank to avoid race condition here.
|
||||
temp_dir = f"./temp_{rank}"
|
||||
_prepare_dir(temp_dir)
|
||||
temp_path = os.path.join(temp_dir, "temp.onnx")
|
||||
torch.onnx.export(
|
||||
llama,
|
||||
args=decoder_inputs,
|
||||
|
|
@ -218,18 +238,25 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon
|
|||
onnx.checker.check_model(temp_path)
|
||||
onnx.shape_inference.infer_shapes_path(temp_path)
|
||||
|
||||
output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
|
||||
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
|
||||
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
||||
save_onnx_model(
|
||||
onnx_model,
|
||||
output_path,
|
||||
f"{args.model_name}_decoder_model_fp32.onnx.data",
|
||||
f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data",
|
||||
)
|
||||
del onnx_model
|
||||
temp_dir.cleanup()
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
# Export decoder_with_past_model.onnx
|
||||
decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length)
|
||||
decoder_with_past_inputs = get_sample_with_past_kv_inputs(
|
||||
l_config,
|
||||
device,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
use_fp16=args.precision == Precision.FLOAT16,
|
||||
world_size=world_size,
|
||||
)
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
|
|
@ -247,8 +274,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon
|
|||
),
|
||||
]
|
||||
dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names)
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
temp_path = os.path.join(temp_dir.name, "temp.onnx")
|
||||
|
||||
# Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large.
|
||||
# Use temp folder per rank to avoid race condition here.
|
||||
temp_dir = f"./temp_past_{rank}"
|
||||
_prepare_dir(temp_dir)
|
||||
temp_path = os.path.join(temp_dir, "temp.onnx")
|
||||
torch.onnx.export(
|
||||
llama,
|
||||
args=decoder_with_past_inputs,
|
||||
|
|
@ -266,27 +297,45 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon
|
|||
onnx.checker.check_model(temp_path)
|
||||
onnx.shape_inference.infer_shapes_path(temp_path)
|
||||
|
||||
output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx")
|
||||
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx")
|
||||
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
||||
save_onnx_model(
|
||||
onnx_model,
|
||||
output_path,
|
||||
f"{args.model_name}_decoder_with_past_model_fp32.onnx.data",
|
||||
f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data",
|
||||
)
|
||||
del onnx_model
|
||||
temp_dir.cleanup()
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!")
|
||||
logger.info(
|
||||
f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!"
|
||||
)
|
||||
|
||||
|
||||
def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
|
||||
def run_torchscript_merged_export(
|
||||
args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1
|
||||
):
|
||||
# Dummy values for export
|
||||
batch_size, sequence_length, past_sequence_length = 2, 8, 0
|
||||
device = torch.device("cpu")
|
||||
|
||||
# set device used to export model
|
||||
# for llama-2-70b we will use current gpus to speed up export process
|
||||
# for other models, we will use CPU to make sure we have enough memory to do export
|
||||
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
|
||||
|
||||
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
|
||||
max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
|
||||
|
||||
# Export decoder_merged_model.onnx
|
||||
decoder_merged_inputs = get_merged_sample_with_past_kv_inputs(
|
||||
l_config, device, batch_size, sequence_length, past_sequence_length
|
||||
l_config,
|
||||
device,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
past_sequence_length,
|
||||
max_seq_len=max_sequence_length,
|
||||
use_fp16=args.precision == Precision.FLOAT16,
|
||||
world_size=world_size,
|
||||
)
|
||||
input_names = [
|
||||
"input_ids",
|
||||
|
|
@ -305,8 +354,12 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi
|
|||
),
|
||||
]
|
||||
dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
temp_path = os.path.join(temp_dir.name, "temp.onnx")
|
||||
|
||||
# Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large.
|
||||
# Use temp folder per rank to avoid race condition here.
|
||||
temp_dir = f"./temp_{rank}"
|
||||
_prepare_dir(temp_dir)
|
||||
temp_path = os.path.join(temp_dir, "temp.onnx")
|
||||
torch.onnx.export(
|
||||
llama,
|
||||
args=decoder_merged_inputs,
|
||||
|
|
@ -324,17 +377,17 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi
|
|||
onnx.checker.check_model(temp_path)
|
||||
onnx.shape_inference.infer_shapes_path(temp_path)
|
||||
|
||||
output_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx")
|
||||
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx")
|
||||
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
||||
save_onnx_model(
|
||||
onnx_model,
|
||||
output_path,
|
||||
f"{args.model_name}_decoder_merged_model_fp32.onnx.data",
|
||||
f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx.data",
|
||||
)
|
||||
del onnx_model
|
||||
temp_dir.cleanup()
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!")
|
||||
logger.info(f"The {args.model_name} merged ONNX model has been successfully created with the TorchScript exporter!")
|
||||
|
||||
|
||||
# Optimize the model as FP32
|
||||
|
|
@ -357,12 +410,16 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str):
|
|||
remove_existing_model(input_path)
|
||||
|
||||
|
||||
def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str]):
|
||||
decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx")
|
||||
def convert_to_float16(
|
||||
args: argparse.Namespace, config: LlamaConfig, old_paths: List[str], rank: int = 0, world_size: int = 1
|
||||
):
|
||||
decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx")
|
||||
decoder_with_past_model_fp16_path = os.path.join(
|
||||
args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx"
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx"
|
||||
)
|
||||
decoder_merged_model_fp16_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx"
|
||||
)
|
||||
decoder_merged_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp16.onnx")
|
||||
new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path]
|
||||
|
||||
logger.info("Converting to float16...")
|
||||
|
|
@ -370,7 +427,7 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths:
|
|||
if os.path.exists(fp32_path):
|
||||
model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True))
|
||||
model.convert_float_to_float16(keep_io_types=False)
|
||||
model = use_group_query_attention(config, model)
|
||||
model = use_group_query_attention(config, model, world_size)
|
||||
model.save_model_to_file(fp16_path, use_external_data_format=True)
|
||||
del model
|
||||
logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!")
|
||||
|
|
@ -380,9 +437,11 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths:
|
|||
return new_paths
|
||||
|
||||
|
||||
def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel):
|
||||
def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1):
|
||||
# Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes
|
||||
fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads)
|
||||
fp16_model_opt = replace_mha_with_gqa(
|
||||
fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size
|
||||
)
|
||||
fp16_model_opt.prune_graph()
|
||||
fp16_model_opt.update_graph(allow_remove_graph_inputs=True)
|
||||
return fp16_model_opt
|
||||
|
|
@ -406,7 +465,7 @@ def smooth_quant(
|
|||
calibration_sampling_size=[args.calibration_sampling_size],
|
||||
recipes={
|
||||
"optypes_to_exclude_output_quant": ["MatMul"],
|
||||
"smooth_quant": args.smooth_quant,
|
||||
"smooth_quant": True,
|
||||
"smooth_quant_args": {"alpha": args.smooth_quant_alpha},
|
||||
},
|
||||
op_type_dict={
|
||||
|
|
@ -526,15 +585,6 @@ def get_args():
|
|||
help="Execution provider to verify parity with",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-id",
|
||||
"--device-id",
|
||||
required=False,
|
||||
type=str,
|
||||
default="0",
|
||||
help="Device ID for GPUs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--reexport",
|
||||
|
|
@ -655,6 +705,14 @@ def get_args():
|
|||
)
|
||||
parser.set_defaults(use_dynamo_export=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default="./model_cache",
|
||||
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
|
@ -673,144 +731,182 @@ def main():
|
|||
remove_existing_files(args.output)
|
||||
logger.info(f"Arguments: {args}")
|
||||
|
||||
world_size = get_size()
|
||||
rank = get_rank()
|
||||
|
||||
# Load model and config
|
||||
use_auth_token = args.input == os.path.join(".")
|
||||
setattr(args, "use_auth_token", use_auth_token) # noqa: B010
|
||||
|
||||
location = args.model_name if use_auth_token else args.input
|
||||
l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token)
|
||||
llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, use_cache=True)
|
||||
original_model_name = args.model_name
|
||||
setattr(args, "original_model_name", original_model_name) # noqa: B010
|
||||
args.model_name = args.model_name.split("/")[-1]
|
||||
|
||||
# Set model paths for FP32 model
|
||||
decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
|
||||
decoder_with_past_model_fp32_path = os.path.join(
|
||||
args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx"
|
||||
)
|
||||
decoder_merged_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx")
|
||||
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
|
||||
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
|
||||
setattr(args, "device", torch.device(args.device_name)) # noqa: B010
|
||||
|
||||
missing_separate_exports = (
|
||||
args.no_merged
|
||||
and not os.path.exists(decoder_model_fp32_path)
|
||||
and not os.path.exists(decoder_with_past_model_fp32_path)
|
||||
)
|
||||
missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path)
|
||||
location = args.original_model_name if use_auth_token else args.input
|
||||
|
||||
# Export to ONNX
|
||||
if missing_separate_exports or missing_merged_export:
|
||||
if args.use_dynamo_export and missing_separate_exports:
|
||||
logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
|
||||
logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
|
||||
logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
|
||||
logger.warning(
|
||||
"Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
|
||||
# use cuda for Llama-2-70b to speedup export, other models use CPU by default
|
||||
l_config, llama = setup_torch_model(
|
||||
args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None
|
||||
)
|
||||
|
||||
assert l_config.num_attention_heads % world_size == 0 and l_config.num_key_value_heads % world_size == 0
|
||||
|
||||
barrier()
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
# Set model paths for FP32 model
|
||||
decoder_model_fp32_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx"
|
||||
)
|
||||
logger.warning(
|
||||
"Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
|
||||
decoder_with_past_model_fp32_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx"
|
||||
)
|
||||
run_dynamo_export(args, l_config, llama)
|
||||
elif args.no_merged:
|
||||
run_torchscript_separate_export(args, l_config, llama)
|
||||
else:
|
||||
run_torchscript_merged_export(args, l_config, llama)
|
||||
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
|
||||
|
||||
# Set model paths to store FP32 optimized model
|
||||
decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx")
|
||||
decoder_with_past_model_fp32_opt_path = os.path.join(
|
||||
args.output, f"{args.model_name}_decoder_with_past_model_fp32_opt.onnx"
|
||||
)
|
||||
decoder_merged_model_fp32_opt_path = os.path.join(
|
||||
args.output, f"{args.model_name}_decoder_merged_model_fp32_opt.onnx"
|
||||
)
|
||||
new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path]
|
||||
|
||||
# Run the optimizer script
|
||||
logger.info("Optimizing models...")
|
||||
for orig_path, opt_path in zip(old_paths, new_paths):
|
||||
if os.path.exists(orig_path):
|
||||
optimize_export(l_config, input_path=orig_path, output_path=opt_path)
|
||||
|
||||
# Re-assign default FP32 model paths as their optimized versions
|
||||
decoder_model_fp32_path = decoder_model_fp32_opt_path
|
||||
decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
|
||||
decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
|
||||
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
|
||||
|
||||
logger.info(
|
||||
f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
|
||||
)
|
||||
|
||||
# Change precision of exported models from FP32
|
||||
if args.precision == Precision.FLOAT16:
|
||||
new_paths = convert_to_float16(args, l_config, old_paths)
|
||||
|
||||
elif args.precision == Precision.INT8:
|
||||
decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx")
|
||||
decoder_with_past_model_int8_path = os.path.join(
|
||||
args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx"
|
||||
)
|
||||
decoder_merged_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int8.onnx")
|
||||
new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]
|
||||
|
||||
if args.quantization_method == "smooth_quant":
|
||||
if not args.no_merged:
|
||||
logger.error("SmoothQuant must be used on separately exported models")
|
||||
else:
|
||||
logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8")
|
||||
smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])
|
||||
|
||||
elif args.quantization_method == "quantize_dynamic":
|
||||
logger.warning(
|
||||
"The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
|
||||
decoder_merged_model_fp32_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx"
|
||||
)
|
||||
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
|
||||
|
||||
logger.info("Quantizing to int8...")
|
||||
for fp32_path, int8_path in zip(old_paths, new_paths):
|
||||
if os.path.exists(fp32_path):
|
||||
ort_quantization.quantize_dynamic(
|
||||
fp32_path,
|
||||
int8_path,
|
||||
op_types_to_quantize=["MatMul", "Gemm", "Gather"]
|
||||
if args.quantize_embedding_layer
|
||||
else ["MatMul", "Gemm"],
|
||||
per_channel=args.quantize_per_channel,
|
||||
reduce_range=args.quantize_reduce_range,
|
||||
use_external_data_format=True,
|
||||
extra_options={"MatMulConstBOnly": True},
|
||||
missing_separate_exports = (
|
||||
args.no_merged
|
||||
and not os.path.exists(decoder_model_fp32_path)
|
||||
and not os.path.exists(decoder_with_past_model_fp32_path)
|
||||
)
|
||||
missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path)
|
||||
|
||||
# Export to ONNX
|
||||
if missing_separate_exports or missing_merged_export:
|
||||
if args.use_dynamo_export and missing_separate_exports:
|
||||
logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
|
||||
logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
|
||||
logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
|
||||
logger.warning(
|
||||
"Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
|
||||
)
|
||||
logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!")
|
||||
remove_existing_model(decoder_model_fp32_path)
|
||||
logger.warning(
|
||||
"Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
|
||||
)
|
||||
run_dynamo_export(args, l_config, llama)
|
||||
elif args.no_merged:
|
||||
run_torchscript_separate_export(args, l_config, llama, rank, world_size)
|
||||
else:
|
||||
run_torchscript_merged_export(args, l_config, llama, rank, world_size)
|
||||
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
|
||||
|
||||
logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
|
||||
# Set model paths to store FP32 optimized model
|
||||
decoder_model_fp32_opt_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx"
|
||||
)
|
||||
decoder_with_past_model_fp32_opt_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx"
|
||||
)
|
||||
decoder_merged_model_fp32_opt_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx"
|
||||
)
|
||||
new_paths = [
|
||||
decoder_model_fp32_opt_path,
|
||||
decoder_with_past_model_fp32_opt_path,
|
||||
decoder_merged_model_fp32_opt_path,
|
||||
]
|
||||
|
||||
else:
|
||||
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
|
||||
# Run the optimizer script
|
||||
logger.info("Optimizing models...")
|
||||
for orig_path, opt_path in zip(old_paths, new_paths):
|
||||
if os.path.exists(orig_path):
|
||||
optimize_export(l_config, input_path=orig_path, output_path=opt_path)
|
||||
|
||||
elif args.precision == Precision.INT4:
|
||||
if args.execution_provider != "cpu":
|
||||
old_paths = convert_to_float16(args, l_config, old_paths)
|
||||
# Re-assign default FP32 model paths as their optimized versions
|
||||
decoder_model_fp32_path = decoder_model_fp32_opt_path
|
||||
decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
|
||||
decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
|
||||
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
|
||||
|
||||
decoder_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int4.onnx")
|
||||
decoder_with_past_model_int4_path = os.path.join(
|
||||
args.output, f"{args.model_name}_decoder_with_past_model_int4.onnx"
|
||||
)
|
||||
decoder_merged_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int4.onnx")
|
||||
new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]
|
||||
logger.info(
|
||||
f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
|
||||
)
|
||||
|
||||
for fp_path, int4_path in zip(old_paths, new_paths):
|
||||
if os.path.exists(fp_path):
|
||||
model = onnx.load_model(fp_path, load_external_data=True)
|
||||
quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[])
|
||||
quant.process()
|
||||
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
|
||||
del model
|
||||
del quant
|
||||
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
|
||||
remove_existing_model(fp_path)
|
||||
# Change precision of exported models from FP32
|
||||
if args.precision == Precision.FLOAT16:
|
||||
new_paths = convert_to_float16(args, l_config, old_paths, rank, world_size)
|
||||
|
||||
elif args.precision == Precision.INT8:
|
||||
decoder_model_int8_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
|
||||
)
|
||||
decoder_with_past_model_int8_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
|
||||
)
|
||||
decoder_merged_model_int8_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
|
||||
)
|
||||
new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]
|
||||
|
||||
if args.quantization_method == "smooth_quant":
|
||||
if not args.no_merged:
|
||||
logger.error("SmoothQuant must be used on separately exported models")
|
||||
else:
|
||||
logger.info(
|
||||
f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
|
||||
)
|
||||
smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])
|
||||
|
||||
elif args.quantization_method == "quantize_dynamic":
|
||||
logger.warning(
|
||||
"The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
|
||||
)
|
||||
|
||||
logger.info("Quantizing to int8...")
|
||||
for fp32_path, int8_path in zip(old_paths, new_paths):
|
||||
if os.path.exists(fp32_path):
|
||||
ort_quantization.quantize_dynamic(
|
||||
fp32_path,
|
||||
int8_path,
|
||||
op_types_to_quantize=["MatMul", "Gemm", "Gather"]
|
||||
if args.quantize_embedding_layer
|
||||
else ["MatMul", "Gemm"],
|
||||
per_channel=args.quantize_per_channel,
|
||||
reduce_range=args.quantize_reduce_range,
|
||||
use_external_data_format=True,
|
||||
extra_options={"MatMulConstBOnly": True},
|
||||
)
|
||||
logger.info(
|
||||
f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
|
||||
)
|
||||
remove_existing_model(decoder_model_fp32_path)
|
||||
|
||||
logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
|
||||
|
||||
else:
|
||||
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
|
||||
|
||||
elif args.precision == Precision.INT4:
|
||||
if args.execution_provider != "cpu":
|
||||
old_paths = convert_to_float16(args, l_config, old_paths, rank, world_size)
|
||||
|
||||
decoder_model_int4_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
|
||||
)
|
||||
decoder_with_past_model_int4_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
|
||||
)
|
||||
decoder_merged_model_int4_path = os.path.join(
|
||||
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
|
||||
)
|
||||
new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]
|
||||
|
||||
for fp_path, int4_path in zip(old_paths, new_paths):
|
||||
if os.path.exists(fp_path):
|
||||
model = onnx.load_model(fp_path, load_external_data=True)
|
||||
quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[])
|
||||
quant.process()
|
||||
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
|
||||
del model
|
||||
del quant
|
||||
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
|
||||
remove_existing_model(fp_path)
|
||||
barrier()
|
||||
|
||||
logger.info("Verifying parity on all ONNX models created")
|
||||
|
||||
|
|
@ -824,7 +920,12 @@ def main():
|
|||
|
||||
# Verify parity on all saved ONNX models
|
||||
for filename in os.listdir(args.output):
|
||||
if ".data" in filename or ".onnx" not in filename:
|
||||
if (
|
||||
".data" in filename
|
||||
or ".onnx" not in filename
|
||||
or args.precision not in filename
|
||||
or f"rank_{rank}" not in filename
|
||||
):
|
||||
continue
|
||||
|
||||
parity_cmd = [
|
||||
|
|
@ -834,10 +935,10 @@ def main():
|
|||
os.path.join(args.output, filename),
|
||||
"-ep",
|
||||
args.execution_provider,
|
||||
"-id",
|
||||
args.device_id,
|
||||
"-fp",
|
||||
args.precision,
|
||||
"--cache_dir",
|
||||
args.cache_dir,
|
||||
]
|
||||
if "with_past" in filename:
|
||||
parity_cmd.append("--use_past_kv")
|
||||
|
|
@ -845,6 +946,7 @@ def main():
|
|||
parity_cmd.append("--merged")
|
||||
|
||||
try:
|
||||
logger.debug(f"check parity with cmd: {parity_cmd}")
|
||||
parity_check(parity_cmd)
|
||||
except Exception as e:
|
||||
logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
comm = None
|
||||
|
||||
|
||||
def init_dist():
|
||||
if "LOCAL_RANK" in os.environ:
|
||||
int(os.environ["LOCAL_RANK"])
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
|
||||
elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
|
||||
from mpi4py import MPI
|
||||
|
||||
comm = MPI.COMM_WORLD # noqa: F841
|
||||
|
||||
int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
|
||||
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
|
||||
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
|
||||
|
||||
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
|
||||
else:
|
||||
# don't need to do init for single process
|
||||
pass
|
||||
|
||||
|
||||
def get_rank():
|
||||
return comm.Get_rank() if comm is not None else 0
|
||||
|
||||
|
||||
def get_size():
|
||||
return comm.Get_size() if comm is not None else 1
|
||||
|
||||
|
||||
def barrier():
|
||||
if comm is not None:
|
||||
comm.Barrier()
|
||||
|
||||
|
||||
def print_out(*args):
|
||||
if get_rank() == 0:
|
||||
print(*args)
|
||||
|
|
@ -66,12 +66,13 @@ def get_sample_with_past_kv_inputs(
|
|||
use_fp16: bool = False,
|
||||
engine: str = "pt",
|
||||
return_dict: bool = False,
|
||||
world_size: int = 1,
|
||||
):
|
||||
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
|
||||
attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
|
||||
# position_ids is of shape (batch_size, 1)
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=True)
|
||||
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)
|
||||
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
|
||||
|
||||
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
|
||||
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
|
||||
|
|
@ -123,12 +124,13 @@ def get_merged_sample_with_past_kv_inputs(
|
|||
use_fp16: bool = False,
|
||||
engine: str = "pt",
|
||||
return_dict: bool = False,
|
||||
world_size: int = 1,
|
||||
):
|
||||
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
|
||||
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
|
||||
# position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
|
||||
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)
|
||||
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
|
||||
|
||||
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
|
||||
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
|
||||
|
|
@ -220,8 +222,8 @@ def get_msft_sample_inputs(
|
|||
|
||||
# Create past_key_values
|
||||
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
|
||||
def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool):
|
||||
num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads
|
||||
def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
|
||||
num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads
|
||||
torch_dtype = torch.float16 if use_fp16 else torch.float32
|
||||
past_kv = [
|
||||
(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import List
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from benchmark_helper import setup_logger
|
||||
from dist_settings import get_rank, get_size
|
||||
from llama_inputs import (
|
||||
add_io_bindings,
|
||||
convert_inputs_for_ort,
|
||||
|
|
@ -14,9 +14,11 @@ from llama_inputs import (
|
|||
get_sample_inputs,
|
||||
get_sample_with_past_kv_inputs,
|
||||
)
|
||||
from llama_torch import setup_torch_model
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.transformers.benchmark_helper import setup_logger
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
|
@ -30,6 +32,7 @@ def get_sequence_lengths(args: argparse.Namespace):
|
|||
|
||||
def get_inputs(args: argparse.Namespace, config: LlamaConfig):
|
||||
# Dummy values for parity
|
||||
world_size = get_size()
|
||||
batch_size = 2
|
||||
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args)
|
||||
|
||||
|
|
@ -43,10 +46,17 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig):
|
|||
max_seq_len=max_sequence_length,
|
||||
use_fp16=args.use_fp16,
|
||||
return_dict=True,
|
||||
world_size=world_size,
|
||||
)
|
||||
elif args.use_past_kv:
|
||||
inputs = get_sample_with_past_kv_inputs(
|
||||
config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True
|
||||
config,
|
||||
args.device,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
use_fp16=args.use_fp16,
|
||||
return_dict=True,
|
||||
world_size=world_size,
|
||||
)
|
||||
else:
|
||||
inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
|
||||
|
|
@ -66,6 +76,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
|
|||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
logger.info(f"PyTorch took {end_time - start_time} s")
|
||||
del pt_model
|
||||
|
||||
# Run inference with ORT
|
||||
past_sequence_length, _, max_sequence_length = get_sequence_lengths(args)
|
||||
|
|
@ -76,12 +87,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
|
|||
past_seq_len=past_sequence_length,
|
||||
max_seq_len=max_sequence_length,
|
||||
device=args.execution_provider,
|
||||
device_id=int(args.device_id),
|
||||
device_id=int(args.rank),
|
||||
)
|
||||
|
||||
ep = f"{args.execution_provider.upper()}ExecutionProvider"
|
||||
if ep == "CUDAExecutionProvider":
|
||||
ep = (ep, {"device_id": args.device_id})
|
||||
ep = (ep, {"device_id": args.rank})
|
||||
ort_model = ort.InferenceSession(
|
||||
args.onnx_model_path,
|
||||
sess_options=ort.SessionOptions(),
|
||||
|
|
@ -91,7 +102,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
|
|||
# Add IO bindings for non-CPU execution providers
|
||||
if args.execution_provider != "cpu":
|
||||
io_binding, kv_cache_ortvalues = add_io_bindings(
|
||||
ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues
|
||||
ort_model, inputs, args.execution_provider, int(args.rank), kv_cache_ortvalues
|
||||
)
|
||||
|
||||
io_binding.synchronize_inputs()
|
||||
|
|
@ -101,6 +112,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
|
|||
end_time = time.time()
|
||||
|
||||
ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits
|
||||
del ort_model
|
||||
|
||||
else:
|
||||
start_time = time.time()
|
||||
|
|
@ -155,15 +167,6 @@ def get_args(argv: List[str]):
|
|||
help="Execution provider to verify parity with",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-id",
|
||||
"--device-id",
|
||||
required=False,
|
||||
type=str,
|
||||
default="0",
|
||||
help="Device ID for GPUs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
|
|
@ -195,6 +198,14 @@ def get_args(argv: List[str]):
|
|||
help="Precision of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default="./model_cache",
|
||||
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
|
||||
)
|
||||
|
||||
args = parser.parse_args() if argv == [] else parser.parse_args(argv)
|
||||
|
||||
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
|
||||
|
|
@ -210,21 +221,23 @@ def main(argv: List[str] = []): # noqa: B006
|
|||
args = get_args(argv)
|
||||
setup_logger(args.verbose)
|
||||
logger.info(f"Arguments: {args}")
|
||||
rank = get_rank()
|
||||
|
||||
# Load model and config
|
||||
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
|
||||
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{args.device_id}") # noqa: B010
|
||||
args.rank = rank
|
||||
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
|
||||
setattr(args, "device", torch.device(args.device_name)) # noqa: B010
|
||||
use_auth_token = args.torch_model_directory == os.path.join(".")
|
||||
location = args.model_name if use_auth_token else args.torch_model_directory
|
||||
|
||||
config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token)
|
||||
llama = LlamaForCausalLM.from_pretrained(
|
||||
config, llama = setup_torch_model(
|
||||
args,
|
||||
location,
|
||||
use_auth_token,
|
||||
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
|
||||
use_auth_token=use_auth_token,
|
||||
use_cache=True,
|
||||
).to(args.device)
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
kv_cache_ortvalues = {}
|
||||
if not args.merged:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from dist_settings import barrier, get_rank, get_size
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, device=None):
|
||||
world_size = get_size()
|
||||
logger.info(f"world_size: {world_size}")
|
||||
rank = get_rank()
|
||||
barrier()
|
||||
|
||||
if not os.path.exists(args.cache_dir):
|
||||
os.makedirs(args.cache_dir, exist_ok=True)
|
||||
|
||||
for i in range(world_size):
|
||||
if i == rank % (world_size):
|
||||
l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir)
|
||||
l_config.use_cache = True
|
||||
llama = LlamaForCausalLM.from_pretrained(
|
||||
location,
|
||||
use_auth_token=use_auth_token,
|
||||
config=l_config,
|
||||
torch_dtype=torch_dtype,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
if world_size > 1:
|
||||
llama.parallel_model()
|
||||
if device:
|
||||
llama.to(device)
|
||||
llama.eval()
|
||||
llama.requires_grad_(False)
|
||||
barrier()
|
||||
return l_config, llama
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
-r requirements.txt
|
||||
git+https://github.com/frankdongms/transformers.git@frdong/shard_llama
|
||||
mpi4py
|
||||
psutil
|
||||
|
|
@ -337,6 +337,18 @@ class OnnxModel:
|
|||
return i, matched, return_indice
|
||||
return -1, None, None
|
||||
|
||||
def match_parent_paths_all(self, node, paths, output_name_to_node):
|
||||
match_i, matches, return_indices = [], [], []
|
||||
for i, path in enumerate(paths):
|
||||
assert isinstance(path, (List, Tuple))
|
||||
return_indice = []
|
||||
matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
|
||||
if matched:
|
||||
match_i.append(i)
|
||||
matches.append(matched)
|
||||
return_indices.append(return_indice)
|
||||
return match_i, matches, return_indices
|
||||
|
||||
def match_parent_path(
|
||||
self,
|
||||
node,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,10 @@
|
|||
# license information.
|
||||
# -------------------------------------------------------------------------
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
|
@ -22,6 +25,8 @@ from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
|||
|
||||
torch.manual_seed(0)
|
||||
|
||||
pipeline_mode = True # Reduces number of tests so pipeline doesn't time out
|
||||
|
||||
|
||||
class Formats:
|
||||
BSNH = 0
|
||||
|
|
@ -159,7 +164,7 @@ def create_multihead_attention_graph(config):
|
|||
return model.SerializeToString()
|
||||
|
||||
|
||||
def create_group_query_attention_graph_no_past(config, causal=False):
|
||||
def create_group_query_attention_graph_no_past(config, causal=False, present_kv_format=Formats.BSNH):
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"GroupQueryAttention",
|
||||
|
|
@ -168,11 +173,12 @@ def create_group_query_attention_graph_no_past(config, causal=False):
|
|||
"key",
|
||||
"value",
|
||||
],
|
||||
["output"],
|
||||
["output", "present_key", "present_value"],
|
||||
"GroupQueryAttention_0",
|
||||
num_heads=config.num_heads,
|
||||
kv_num_heads=config.kv_num_heads,
|
||||
unidirectional=1 if causal else 0,
|
||||
is_past_bsnh=1 if present_kv_format == Formats.BSNH else 0,
|
||||
domain="com.microsoft",
|
||||
),
|
||||
]
|
||||
|
|
@ -213,6 +219,26 @@ def create_group_query_attention_graph_no_past(config, causal=False):
|
|||
TensorProto.FLOAT16,
|
||||
[config.batch_size, config.sequence_length, config.num_heads * config.head_size],
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"present_key",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads,
|
||||
config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length,
|
||||
config.head_size,
|
||||
],
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"present_value",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads,
|
||||
config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length,
|
||||
config.head_size,
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
|
|
@ -514,7 +540,6 @@ def generate_token_offset(cu_seqlens, max_seqlen):
|
|||
return numpy.asarray(token_offset + token_padset, dtype=numpy.int32)
|
||||
|
||||
|
||||
# TODO(aciddelgado): rename
|
||||
def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False):
|
||||
onnx_model_str = create_packed_multihead_attention_graph(config)
|
||||
qkv_unpad = torch.swapdims(qkv_unpad, 1, 2)
|
||||
|
|
@ -548,8 +573,8 @@ def mha_func(q, k, v, config):
|
|||
return output
|
||||
|
||||
|
||||
def gqa_no_past_func(q, k, v, config, causal=True):
|
||||
onnx_model_str = create_group_query_attention_graph_no_past(config, causal)
|
||||
def gqa_no_past_func(q, k, v, config, causal=True, present_kv_format=Formats.BSNH):
|
||||
onnx_model_str = create_group_query_attention_graph_no_past(config, causal, present_kv_format=present_kv_format)
|
||||
q = torch.reshape(q, (config.batch_size, config.sequence_length, -1))
|
||||
k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1))
|
||||
v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1))
|
||||
|
|
@ -560,7 +585,7 @@ def gqa_no_past_func(q, k, v, config, causal=True):
|
|||
}
|
||||
sess_options = SessionOptions()
|
||||
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
|
||||
ort_output = ort_session.run(None, ort_inputs)
|
||||
ort_output, _, _ = ort_session.run(None, ort_inputs)
|
||||
ort_output = numpy.array(ort_output)
|
||||
output = torch.tensor(ort_output)
|
||||
return output
|
||||
|
|
@ -689,17 +714,12 @@ def attention_ref(
|
|||
if key_padding_mask is not None:
|
||||
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
|
||||
if causal:
|
||||
# causal_mask = torch.triu(
|
||||
# torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
|
||||
# )
|
||||
causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device)
|
||||
scores.masked_fill_(causal_mask, float("-inf"))
|
||||
attention = torch.softmax(scores, dim=-1)
|
||||
if causal: # Some rows are completely masked out so we fill them with zero instead of NaN
|
||||
attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)
|
||||
dropout_scaling = 1.0 / (1 - dropout_p)
|
||||
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
|
||||
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
|
||||
if dropout_mask is not None:
|
||||
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
|
||||
else:
|
||||
|
|
@ -1072,12 +1092,6 @@ def parity_check_gqa_past_no_buff(
|
|||
out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
|
||||
out = out.detach().cpu().numpy()
|
||||
|
||||
# print(present_k[0, 0, config.past_sequence_length, :10])
|
||||
# print(k_cache_ref[0, 0, config.past_sequence_length, :10])
|
||||
# print(k_cache_ref.shape)
|
||||
|
||||
# print(present_k - k_cache_ref.detach().cpu().numpy())
|
||||
|
||||
# Make sure past-present buffer updating correctly
|
||||
if past_format == Formats.BSNH:
|
||||
assert numpy.allclose(
|
||||
|
|
@ -1141,84 +1155,185 @@ def parity_check_gqa_past_no_buff(
|
|||
)
|
||||
|
||||
|
||||
class TestMHA(unittest.TestCase):
|
||||
def test_packed_mha(self):
|
||||
if not torch.cuda.is_available() or platform.system() != "Linux":
|
||||
return
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 8:
|
||||
return
|
||||
print("-------- TEST PACKED MHA ---------")
|
||||
batches = [2] if pipeline_mode else [1, 5]
|
||||
seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]
|
||||
num_h = [1, 3] if pipeline_mode else [1, 6, 16]
|
||||
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
for b in batches:
|
||||
for s in seqs:
|
||||
for n in num_h:
|
||||
for h in h_sizes:
|
||||
config = Config(b, s, s, 0, n, n, h)
|
||||
parity_check_mha(config, True)
|
||||
|
||||
def test_mha(self):
|
||||
if not torch.cuda.is_available() or platform.system() != "Linux":
|
||||
return
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 8:
|
||||
return
|
||||
print("-------- TEST MHA ---------")
|
||||
batches = [2] if pipeline_mode else [1, 5]
|
||||
seqs = (
|
||||
[(1, 128), (113, 211), (2048, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(113, 203),
|
||||
(128, 217),
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(512, 256),
|
||||
(1024, 1024),
|
||||
(1023, 1024),
|
||||
(1024, 1023),
|
||||
(2048, 2048),
|
||||
]
|
||||
)
|
||||
num_h = [1, 3] if pipeline_mode else [1, 6, 16]
|
||||
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
for n in num_h:
|
||||
for h in h_sizes:
|
||||
config = Config(b, s, s2, 0, n, n, h)
|
||||
parity_check_mha(config, False)
|
||||
|
||||
|
||||
class TestGQA(unittest.TestCase):
|
||||
def test_gqa_no_past(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
torch.manual_seed(69)
|
||||
print("-------- TEST GQA ---------")
|
||||
batches = [2] if pipeline_mode else [1, 5]
|
||||
seqs = (
|
||||
[(1, 128), (113, 211), (2048, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(113, 203),
|
||||
(128, 217),
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(1024, 1024),
|
||||
(1023, 1024),
|
||||
(2048, 2048),
|
||||
]
|
||||
)
|
||||
num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
if major < 5 or (major == 5 and minor < 3):
|
||||
return
|
||||
print("------- MEMORY EFFICIENT ATTENTION ---------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for causal in [True, False]:
|
||||
config = Config(b, s, s2, 0, n, n2, h)
|
||||
parity_check_gqa_no_past(config, causal=causal)
|
||||
if major < 8 or platform.system() != "Linux":
|
||||
return
|
||||
print("------- FLASH ATTENTION --------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for causal in [True, False]:
|
||||
config = Config(b, s, s2, 0, n, n2, h)
|
||||
parity_check_gqa_no_past(config, causal=causal)
|
||||
|
||||
def test_gqa_past(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 5 or (major == 5 and minor < 3):
|
||||
return
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
print("-------- TEST GQA PAST ---------")
|
||||
print("-------- MEMORY EFFICEINT --------")
|
||||
batches = [2] if pipeline_mode else [1, 2]
|
||||
seqs = (
|
||||
[(1, 128), (3, 1024), (64, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
(1, 339),
|
||||
(3, 1024),
|
||||
(64, 800),
|
||||
(64, 256),
|
||||
(3, 799),
|
||||
(64, 2048),
|
||||
(16, 20000),
|
||||
(1, 128 * 512),
|
||||
(16, 128 * 512),
|
||||
(128, 128),
|
||||
]
|
||||
)
|
||||
num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for causal in [True]:
|
||||
for past_kv_format in [Formats.BNSH, Formats.BSNH]:
|
||||
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
|
||||
config = Config(b, s, s2, sp, n, n2, h)
|
||||
parity_check_gqa_past(
|
||||
config,
|
||||
causal=causal,
|
||||
past_format=past_kv_format,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
causal=causal,
|
||||
past_format=past_kv_format,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
if major < 8 or platform.system() != "Linux":
|
||||
return
|
||||
print("------- FLASH ATTENTION -------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for causal in [True]:
|
||||
for past_kv_format in [Formats.BNSH, Formats.BSNH]:
|
||||
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
|
||||
config = Config(b, s, s2, sp, n, n2, h)
|
||||
parity_check_gqa_past(
|
||||
config,
|
||||
causal=causal,
|
||||
past_format=past_kv_format,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
causal=causal,
|
||||
past_format=past_kv_format,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("-------- TEST PACKED MHA ---------")
|
||||
for b in [5]:
|
||||
for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]:
|
||||
for n in [6]:
|
||||
for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
|
||||
config = Config(b, s, s, 0, n, n, h)
|
||||
parity_check_mha(config, True)
|
||||
print("-------- TEST MHA ---------")
|
||||
for b in [5]:
|
||||
for s, s2 in [
|
||||
(113, 203),
|
||||
(128, 217),
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(512, 256),
|
||||
(1024, 1024),
|
||||
(1023, 1024),
|
||||
(1024, 1023),
|
||||
(2048, 2048),
|
||||
]:
|
||||
for n in [6]:
|
||||
for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
|
||||
config = Config(b, s, s2, 0, n, n, h)
|
||||
parity_check_mha(config, False)
|
||||
print("-------- TEST GQA ---------")
|
||||
for b in [5]:
|
||||
for s, s2 in [
|
||||
(113, 203),
|
||||
(128, 217),
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(512, 256),
|
||||
(1024, 1024),
|
||||
(1023, 1024),
|
||||
(1024, 1023),
|
||||
(2048, 2048),
|
||||
]:
|
||||
for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
|
||||
for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
|
||||
for causal in [True, False]:
|
||||
config = Config(b, s, s2, 0, n, n2, h)
|
||||
parity_check_gqa_no_past(config, causal=causal)
|
||||
print("-------- TEST GQA PAST ---------")
|
||||
random.seed(69)
|
||||
for b in [2]:
|
||||
for s, s2 in [
|
||||
(1, 128),
|
||||
(1, 339),
|
||||
(3, 1024),
|
||||
(64, 800),
|
||||
(64, 256),
|
||||
(3, 799),
|
||||
(64, 2048),
|
||||
(16, 20000),
|
||||
(1, 128 * 512),
|
||||
(16, 128 * 512),
|
||||
(128, 128),
|
||||
]:
|
||||
for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
|
||||
for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
|
||||
for causal in [True]:
|
||||
for past_kv_format in [Formats.BNSH, Formats.BSNH]:
|
||||
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
|
||||
config = Config(b, s, s2, sp, n, n2, h)
|
||||
parity_check_gqa_past(
|
||||
config,
|
||||
causal=causal,
|
||||
past_format=past_kv_format,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
causal=causal,
|
||||
past_format=past_kv_format,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class TestRotaryAttentionFusion(unittest.TestCase):
|
|||
helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]),
|
||||
helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size),
|
||||
]
|
||||
if model_type in {"past", "merged", "llama2_msft"}:
|
||||
if model_type in {"past", "merged", "llama2_msft", "70b_distributed_merged"}:
|
||||
inputs.extend(
|
||||
[
|
||||
helper.make_tensor_value_info(
|
||||
|
|
@ -164,14 +164,14 @@ class TestRotaryAttentionFusion(unittest.TestCase):
|
|||
if is_fused or model_type == "llama2_msft":
|
||||
# q_out/k_out
|
||||
return f"{node_type}_out"
|
||||
if model_type in {"no_past", "past", "merged"}:
|
||||
if model_type in {"no_past", "past", "merged", "70b_distributed_merged"}:
|
||||
if node_type == "k":
|
||||
return "k_before_rope"
|
||||
return "q_before_rope"
|
||||
return ""
|
||||
|
||||
def get_first_rope_output(node_type: str):
|
||||
if is_fused or model_type in {"llama2_msft", "past", "merged"}:
|
||||
if is_fused or model_type in {"llama2_msft", "past", "merged", "70b_distributed_merged"}:
|
||||
if node_type == "q":
|
||||
return "q_rope"
|
||||
return "k_rope"
|
||||
|
|
@ -295,23 +295,225 @@ class TestRotaryAttentionFusion(unittest.TestCase):
|
|||
)
|
||||
k_nodes = [reshape_k_node, transpose_k_1_node]
|
||||
|
||||
if model_type in {"past", "merged"}:
|
||||
if model_type == "70b_distributed_merged":
|
||||
concat_k_node = helper.make_node(
|
||||
"Concat",
|
||||
inputs=["past_key", "k_rope"],
|
||||
outputs=["present_key"],
|
||||
axis=2,
|
||||
)
|
||||
k_nodes.append(concat_k_node)
|
||||
shape_k1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k1_out"], name="Shape_k1")
|
||||
shape_k2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k2_out"], name="Shape_k2")
|
||||
shape_k3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k3_out"], name="Shape_k3")
|
||||
shape_k4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k4_out"], name="Shape_k4")
|
||||
|
||||
transpose_k_2_node = helper.make_node(
|
||||
"Transpose",
|
||||
inputs=["present_key"],
|
||||
outputs=["k"],
|
||||
name="Transpose_k_2",
|
||||
perm=[0, 1, 3, 2],
|
||||
)
|
||||
return k_nodes + [transpose_k_2_node] # noqa: RUF005
|
||||
gather_k_1 = helper.make_node(
|
||||
"Gather",
|
||||
inputs=["shape_k1_out", "one"],
|
||||
outputs=["gather_k1_out"],
|
||||
name="Gather_k_1",
|
||||
axis=0,
|
||||
)
|
||||
gather_k_2 = helper.make_node(
|
||||
"Gather",
|
||||
inputs=["shape_k2_out", "one"],
|
||||
outputs=["gather_k2_out"],
|
||||
name="Gather_k_2",
|
||||
axis=0,
|
||||
)
|
||||
gather_k_3 = helper.make_node(
|
||||
"Gather",
|
||||
inputs=["shape_k3_out", "one"],
|
||||
outputs=["gather_k3_out"],
|
||||
name="Gather_k_3",
|
||||
axis=0,
|
||||
)
|
||||
gather_k_4 = helper.make_node(
|
||||
"Gather",
|
||||
inputs=["shape_k4_out", "one"],
|
||||
outputs=["gather_k4_out"],
|
||||
name="Gather_k_4",
|
||||
axis=0,
|
||||
)
|
||||
|
||||
unsqueeze_k_1 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["present_value", "zero"],
|
||||
outputs=["unsqueeze_k1_out"],
|
||||
name="Unsqueeze_k1",
|
||||
)
|
||||
unsqueeze_k_2 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_k1_out", "zero"],
|
||||
outputs=["unsqueeze_k2_out"],
|
||||
name="Unsqueeze_k2",
|
||||
)
|
||||
unsqueeze_k_3 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_k2_out", "zero"],
|
||||
outputs=["unsqueeze_k3_out"],
|
||||
name="Unsqueeze_k3",
|
||||
)
|
||||
unsqueeze_k_4 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_k3_out", "zero"],
|
||||
outputs=["unsqueeze_k4_out"],
|
||||
name="Unsqueeze_k4",
|
||||
)
|
||||
unsqueeze_k_5 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_k4_out", "zero"],
|
||||
outputs=["unsqueeze_k5_out"],
|
||||
name="Unsqueeze_k5",
|
||||
)
|
||||
|
||||
concat_k_2 = helper.make_node(
|
||||
"Concat",
|
||||
inputs=["unsqueeze_k2_out", "unsqueeze_k3_out", "One", "unsqueeze_k4_out", "unsqueeze_k5_out"],
|
||||
outputs=["concat_k2_ouot"],
|
||||
name="Concat_k2",
|
||||
axis=0,
|
||||
)
|
||||
reshape_k_2 = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=["concat_k2_ouot", "One"],
|
||||
outputs=["reshape_k2_out"],
|
||||
name="Reshape_k_2",
|
||||
)
|
||||
shape_k5 = helper.make_node("Shape", inputs=["reshape_k2_out"], outputs=["shape_k5_out"], name="Shape_k5")
|
||||
constant_of_shape_k_1 = helper.make_node(
|
||||
"ConstantOfShape",
|
||||
inputs=["shape_k5_out"],
|
||||
outputs=["constant_of_shape_k1_out"],
|
||||
name="ConstantOfShape_k1",
|
||||
)
|
||||
mul_k_1 = helper.make_node(
|
||||
"Mul",
|
||||
inputs=["constant_of_shape_k1_out", "One"],
|
||||
outputs=["mul_k1_out"],
|
||||
name="mul_k1",
|
||||
)
|
||||
equal_k_1 = helper.make_node(
|
||||
"Equal",
|
||||
inputs=["reshape_k2_out", "mul_k1_out"],
|
||||
outputs=["equal_k_1_out"],
|
||||
name="equal_k1",
|
||||
)
|
||||
where_k_1 = helper.make_node(
|
||||
"Where",
|
||||
inputs=["equal_k_1_out", "constant_of_shape_k1_out", "reshape_k2_out"],
|
||||
outputs=["where_k_1_out"],
|
||||
name="where_k1",
|
||||
)
|
||||
unsqueeze_k_6 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_k1_out", "zero"],
|
||||
outputs=["unsqueeze_k6_out"],
|
||||
name="Unsqueeze_k6",
|
||||
)
|
||||
mul_k_2 = helper.make_node(
|
||||
"Mul",
|
||||
inputs=["gather_k2_out", "One"],
|
||||
outputs=["mul_k2_out"],
|
||||
name="mul_k2",
|
||||
)
|
||||
unsqueeze_k_7 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["mul_k2_out", "zero"],
|
||||
outputs=["unsqueeze_k7_out"],
|
||||
name="Unsqueeze_k7",
|
||||
)
|
||||
unsqueeze_k_8 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_k3_out", "zero"],
|
||||
outputs=["unsqueeze_k8_out"],
|
||||
name="Unsqueeze_k8",
|
||||
)
|
||||
unsqueeze_k_9 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_k4_out", "zero"],
|
||||
outputs=["unsqueeze_k9_out"],
|
||||
name="Unsqueeze_k9",
|
||||
)
|
||||
concat_k_3 = helper.make_node(
|
||||
"Concat",
|
||||
inputs=["unsqueeze_k6_out", "unsqueeze_k7_out", "unsqueeze_k8_out", "unsqueeze_k9_out"],
|
||||
outputs=["concat_k3_out"],
|
||||
name="Concat_k3",
|
||||
axis=0,
|
||||
)
|
||||
expand_k_1 = helper.make_node(
|
||||
"Expand",
|
||||
inputs=["unsqueeze_k1_out", "where_k_1_out"],
|
||||
outputs=["expand_k1_out"],
|
||||
name="expand_k1",
|
||||
)
|
||||
reshape_k_3 = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=["expand_k1_out", "concat_k3_out"],
|
||||
outputs=["reshape_k3_out"],
|
||||
name="Reshape_k_3",
|
||||
)
|
||||
transpose_k_2_node = helper.make_node(
|
||||
"Transpose",
|
||||
inputs=["reshape_k3_out"],
|
||||
outputs=["k"],
|
||||
name="Transpose_k_2",
|
||||
perm=[0, 1, 3, 2],
|
||||
)
|
||||
|
||||
k_nodes_for_70b_model = [
|
||||
concat_k_node,
|
||||
shape_k1,
|
||||
shape_k2,
|
||||
shape_k3,
|
||||
shape_k4,
|
||||
gather_k_1,
|
||||
gather_k_2,
|
||||
gather_k_3,
|
||||
gather_k_4,
|
||||
unsqueeze_k_1,
|
||||
unsqueeze_k_2,
|
||||
unsqueeze_k_3,
|
||||
unsqueeze_k_4,
|
||||
unsqueeze_k_5,
|
||||
concat_k_2,
|
||||
reshape_k_2,
|
||||
shape_k5,
|
||||
constant_of_shape_k_1,
|
||||
mul_k_1,
|
||||
equal_k_1,
|
||||
where_k_1,
|
||||
unsqueeze_k_6,
|
||||
mul_k_2,
|
||||
unsqueeze_k_7,
|
||||
unsqueeze_k_8,
|
||||
unsqueeze_k_9,
|
||||
concat_k_3,
|
||||
expand_k_1,
|
||||
reshape_k_3,
|
||||
transpose_k_2_node,
|
||||
]
|
||||
k_nodes.extend(k_nodes_for_70b_model)
|
||||
return k_nodes
|
||||
else:
|
||||
if model_type in {"past", "merged"}:
|
||||
concat_k_node = helper.make_node(
|
||||
"Concat",
|
||||
inputs=["past_key", "k_rope"],
|
||||
outputs=["present_key"],
|
||||
axis=2,
|
||||
)
|
||||
k_nodes.append(concat_k_node)
|
||||
|
||||
transpose_k_2_node = helper.make_node(
|
||||
"Transpose",
|
||||
inputs=["present_key"],
|
||||
outputs=["k"],
|
||||
name="Transpose_k_2",
|
||||
perm=[0, 1, 3, 2],
|
||||
)
|
||||
return k_nodes + [transpose_k_2_node] # noqa: RUF005
|
||||
|
||||
def create_k_path(self, model_type: str):
|
||||
if model_type == "llama2_msft":
|
||||
|
|
@ -505,7 +707,7 @@ class TestRotaryAttentionFusion(unittest.TestCase):
|
|||
if model_type == "no_past":
|
||||
return v_nodes
|
||||
|
||||
if model_type in {"past", "merged"}:
|
||||
if model_type in {"past", "merged", "70b_distributed_merged"}:
|
||||
concat_v_node = helper.make_node(
|
||||
"Concat",
|
||||
inputs=["past_value", "transpose_v_1_out"],
|
||||
|
|
@ -513,7 +715,194 @@ class TestRotaryAttentionFusion(unittest.TestCase):
|
|||
name="Concat_v",
|
||||
axis=2,
|
||||
)
|
||||
return v_nodes + [concat_v_node] # noqa: RUF005
|
||||
|
||||
if model_type != "70b_distributed_merged":
|
||||
return v_nodes + [concat_v_node] # noqa: RUF005
|
||||
|
||||
shape_v1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_1_out"], name="Shape_v1")
|
||||
shape_v2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_2_out"], name="Shape_v2")
|
||||
shape_v3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_3_out"], name="Shape_v3")
|
||||
shape_v4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_4_out"], name="Shape_v4")
|
||||
gather_v_1 = helper.make_node(
|
||||
"Gather",
|
||||
inputs=["shape_1_out", "one"],
|
||||
outputs=["gather_1_out"],
|
||||
name="Gather_v1",
|
||||
axis=0,
|
||||
)
|
||||
gather_v_2 = helper.make_node(
|
||||
"Gather",
|
||||
inputs=["shape_2_out", "one"],
|
||||
outputs=["gather_2_out"],
|
||||
name="Gather_v2",
|
||||
axis=0,
|
||||
)
|
||||
gather_v_3 = helper.make_node(
|
||||
"Gather",
|
||||
inputs=["shape_3_out", "one"],
|
||||
outputs=["gather_3_out"],
|
||||
name="Gather_v3",
|
||||
axis=0,
|
||||
)
|
||||
gather_v_4 = helper.make_node(
|
||||
"Gather",
|
||||
inputs=["shape_4_out", "one"],
|
||||
outputs=["gather_4_out"],
|
||||
name="Gather_v4",
|
||||
axis=0,
|
||||
)
|
||||
unsqueeze_v_1 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["present_value", "zero"],
|
||||
outputs=["unsqueeze_v1_out"],
|
||||
name="Unsqueeze_v1",
|
||||
)
|
||||
unsqueeze_v_2 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_1_out", "zero"],
|
||||
outputs=["unsqueeze_v2_out"],
|
||||
name="Unsqueeze_v2",
|
||||
)
|
||||
unsqueeze_v_3 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_2_out", "zero"],
|
||||
outputs=["unsqueeze_v3_out"],
|
||||
name="Unsqueeze_v3",
|
||||
)
|
||||
unsqueeze_v_4 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_3_out", "zero"],
|
||||
outputs=["unsqueeze_v4_out"],
|
||||
name="Unsqueeze_v4",
|
||||
)
|
||||
unsqueeze_v_5 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_4_out", "zero"],
|
||||
outputs=["unsqueeze_v5_out"],
|
||||
name="Unsqueeze_v5",
|
||||
)
|
||||
concat_v_2 = helper.make_node(
|
||||
"Concat",
|
||||
inputs=["unsqueeze_v2_out", "unsqueeze_v3_out", "One", "unsqueeze_v4_out", "unsqueeze_v5_out"],
|
||||
outputs=["concat_v2_ouot"],
|
||||
name="Concat_v2",
|
||||
axis=0,
|
||||
)
|
||||
reshape_v_2 = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=["concat_v2_ouot", "One"],
|
||||
outputs=["reshape_v2_out"],
|
||||
name="Reshape_v2",
|
||||
)
|
||||
shape_v5 = helper.make_node("Shape", inputs=["reshape_v2_out"], outputs=["shape_5_out"], name="Shape_v5")
|
||||
constant_of_shape_v_1 = helper.make_node(
|
||||
"ConstantOfShape",
|
||||
inputs=["shape_5_out"],
|
||||
outputs=["constant_of_shape_v1_out"],
|
||||
name="ConstantOfShape_v1",
|
||||
)
|
||||
mul_v_1 = helper.make_node(
|
||||
"Mul",
|
||||
inputs=["constant_of_shape_v1_out", "One"],
|
||||
outputs=["mul_v1_out"],
|
||||
name="mul_v1",
|
||||
)
|
||||
equal_v_1 = helper.make_node(
|
||||
"Equal",
|
||||
inputs=["reshape_v2_out", "mul_v1_out"],
|
||||
outputs=["equal_v_1_out"],
|
||||
name="equal_v1",
|
||||
)
|
||||
where_v_1 = helper.make_node(
|
||||
"Where",
|
||||
inputs=["equal_v_1_out", "constant_of_shape_v1_out", "reshape_v2_out"],
|
||||
outputs=["where_v_1_out"],
|
||||
name="where_v1",
|
||||
)
|
||||
unsqueeze_v_6 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_1_out", "zero"],
|
||||
outputs=["unsqueeze_v6_out"],
|
||||
name="Unsqueeze_v6",
|
||||
)
|
||||
mul_v_2 = helper.make_node(
|
||||
"Mul",
|
||||
inputs=["gather_2_out", "One"],
|
||||
outputs=["mul_v2_out"],
|
||||
name="mul_v2",
|
||||
)
|
||||
unsqueeze_v_7 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["mul_v2_out", "zero"],
|
||||
outputs=["unsqueeze_v7_out"],
|
||||
name="Unsqueeze_v7",
|
||||
)
|
||||
unsqueeze_v_8 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_3_out", "zero"],
|
||||
outputs=["unsqueeze_v8_out"],
|
||||
name="Unsqueeze_v8",
|
||||
)
|
||||
unsqueeze_v_9 = helper.make_node(
|
||||
"Unsqueeze",
|
||||
inputs=["gather_4_out", "zero"],
|
||||
outputs=["unsqueeze_v9_out"],
|
||||
name="Unsqueeze_v9",
|
||||
)
|
||||
concat_v_3 = helper.make_node(
|
||||
"Concat",
|
||||
inputs=["unsqueeze_v6_out", "unsqueeze_v7_out", "unsqueeze_v8_out", "unsqueeze_v9_out"],
|
||||
outputs=["concat_v3_out"],
|
||||
name="Concat_v3",
|
||||
axis=0,
|
||||
)
|
||||
expand_v_1 = helper.make_node(
|
||||
"Expand",
|
||||
inputs=["unsqueeze_v1_out", "where_v_1_out"],
|
||||
outputs=["expand_v1_out"],
|
||||
name="expand_v1",
|
||||
)
|
||||
reshape_v_3 = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=["expand_v1_out", "concat_v3_out"],
|
||||
outputs=["reshape_v3_out"],
|
||||
name="Reshape_v3",
|
||||
)
|
||||
|
||||
v_nodes_for_70b_model = [
|
||||
concat_v_node,
|
||||
shape_v1,
|
||||
shape_v2,
|
||||
shape_v3,
|
||||
shape_v4,
|
||||
gather_v_1,
|
||||
gather_v_2,
|
||||
gather_v_3,
|
||||
gather_v_4,
|
||||
unsqueeze_v_1,
|
||||
unsqueeze_v_2,
|
||||
unsqueeze_v_3,
|
||||
unsqueeze_v_4,
|
||||
unsqueeze_v_5,
|
||||
concat_v_2,
|
||||
reshape_v_2,
|
||||
shape_v5,
|
||||
constant_of_shape_v_1,
|
||||
mul_v_1,
|
||||
equal_v_1,
|
||||
where_v_1,
|
||||
unsqueeze_v_6,
|
||||
mul_v_2,
|
||||
unsqueeze_v_7,
|
||||
unsqueeze_v_8,
|
||||
unsqueeze_v_9,
|
||||
concat_v_3,
|
||||
expand_v_1,
|
||||
reshape_v_3,
|
||||
]
|
||||
v_nodes.extend(v_nodes_for_70b_model)
|
||||
|
||||
return v_nodes
|
||||
|
||||
# Create extra nodes for `position_ids`
|
||||
unsqueeze_v_node = helper.make_node(
|
||||
|
|
@ -672,7 +1061,28 @@ class TestRotaryAttentionFusion(unittest.TestCase):
|
|||
|
||||
return extra_nodes
|
||||
|
||||
def create_end_nodes(self):
|
||||
def create_end_nodes(self, model_type):
|
||||
if model_type == "70b_distributed_merged":
|
||||
matmul_o_node = helper.make_node(
|
||||
"MatMul",
|
||||
inputs=["attn_output", "o_weight"],
|
||||
outputs=["output_proj"],
|
||||
name="MatMul_o_proj",
|
||||
)
|
||||
all_reduce = helper.make_node(
|
||||
"AllReduce",
|
||||
inputs=["output_proj"],
|
||||
outputs=["allreduce_proj"],
|
||||
name="allreduce_proj",
|
||||
)
|
||||
end_node = helper.make_node(
|
||||
"Add",
|
||||
inputs=["zero", "allreduce_proj"],
|
||||
outputs=["output_0"],
|
||||
name="Add_normalize_node",
|
||||
)
|
||||
return [matmul_o_node, all_reduce, end_node]
|
||||
|
||||
matmul_o_node = helper.make_node(
|
||||
"MatMul",
|
||||
inputs=["attn_output", "o_weight"],
|
||||
|
|
@ -711,7 +1121,7 @@ class TestRotaryAttentionFusion(unittest.TestCase):
|
|||
num_heads=self.num_heads,
|
||||
)
|
||||
|
||||
end_nodes = self.create_end_nodes()
|
||||
end_nodes = self.create_end_nodes(model_type)
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes,
|
||||
|
|
@ -740,7 +1150,7 @@ class TestRotaryAttentionFusion(unittest.TestCase):
|
|||
reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes))
|
||||
extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes)
|
||||
|
||||
end_nodes = self.create_end_nodes()
|
||||
end_nodes = self.create_end_nodes(model_type)
|
||||
|
||||
first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes
|
||||
second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes
|
||||
|
|
@ -790,6 +1200,11 @@ class TestRotaryAttentionFusion(unittest.TestCase):
|
|||
interleaved = False
|
||||
self.check_models(model_type, interleaved)
|
||||
|
||||
def test_hf_70b_distributed_decoder_merged_model(self):
|
||||
model_type = "70b_distributed_merged"
|
||||
interleaved = False
|
||||
self.check_models(model_type, interleaved)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue