mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
optimize CPU implementation of Attention (#2496)
This commit is contained in:
parent
0f57e0a49e
commit
89824b35e9
1 changed file with 148 additions and 123 deletions
|
|
@@ -107,159 +107,184 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
|
||||
const auto dims = input->Shape().GetDims();
|
||||
int batch_size = static_cast<int>(dims[0]);
|
||||
int sequence_length = static_cast<int>(dims[1]);
|
||||
int hidden_size = static_cast<int>(dims[2]);
|
||||
int head_size = hidden_size / num_heads_;
|
||||
const int batch_size = static_cast<int>(dims[0]);
|
||||
const int sequence_length = static_cast<int>(dims[1]);
|
||||
const int hidden_size = static_cast<int>(dims[2]);
|
||||
const int head_size = hidden_size / num_heads_;
|
||||
|
||||
TensorShape output_shape(dims);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
const size_t element_size = sizeof(T);
|
||||
|
||||
// Use GEMM for fully connection.
|
||||
int m = batch_size * sequence_length;
|
||||
int n = 3 * hidden_size;
|
||||
int k = hidden_size;
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
|
||||
// STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(NH)
|
||||
// STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH)
|
||||
auto gemm_data = allocator->Alloc(batch_size * sequence_length * 3 * hidden_size * element_size);
|
||||
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
|
||||
auto Q = reinterpret_cast<T*>(gemm_data);
|
||||
auto K = Q + batch_size * sequence_length * hidden_size;
|
||||
auto V = K + batch_size * sequence_length * hidden_size;
|
||||
|
||||
auto gemm_data_mat = EigenMatrixMapRowMajor<T>(reinterpret_cast<T*>(gemm_data), m, n);
|
||||
gemm_data_mat.rowwise() = ConstEigenVectorMap<T>(bias->template Data<T>(), n).transpose();
|
||||
T* QKV[3] = {Q, K, V};
|
||||
|
||||
math::Gemm<T>(
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0f,
|
||||
input->template Data<T>(),
|
||||
weights->template Data<T>(),
|
||||
1.0f,
|
||||
reinterpret_cast<T*>(gemm_data),
|
||||
tp);
|
||||
{
|
||||
const int loop_len = 3 * batch_size * num_heads_;
|
||||
const auto input_data = input->template Data<T>();
|
||||
const auto weights_data = weights->template Data<T>();
|
||||
const auto bias_data = bias->template Data<T>();
|
||||
|
||||
// STEP.2: gemm_data_transposed(3, B, N, S, H) = transpose gemm_data(B, S, 3, N, H)
|
||||
auto gemm_data_transposed = allocator->Alloc(batch_size * sequence_length * 3 * hidden_size * element_size);
|
||||
BufferUniquePtr gemm_transposed_buffer(gemm_data_transposed, BufferDeleter(allocator));
|
||||
concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), loop_len, [&](int32_t i) {
|
||||
const int batch_index = (i / 3) / num_heads_;
|
||||
const int head_index = (i / 3) % num_heads_;
|
||||
const int qkv_index = i % 3;
|
||||
|
||||
Tensor gemm_data_tensor{input->DataType(), TensorShape{batch_size, sequence_length, 3, num_heads_, head_size}, gemm_data, allocator->Info()};
|
||||
Tensor gemm_data_transposed_tensor{input->DataType(), TensorShape{3, batch_size, num_heads_, sequence_length, head_size}, gemm_data_transposed, allocator->Info()};
|
||||
int input_offset = batch_index * sequence_length * hidden_size;
|
||||
int weights_offset = qkv_index * hidden_size + head_index * head_size;
|
||||
T* qkv_dest = QKV[qkv_index];
|
||||
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
|
||||
|
||||
static const std::vector<size_t> transpose_permutations{2, 0, 3, 1, 4};
|
||||
ORT_RETURN_IF_ERROR(TransposeBase::DoTranspose(transpose_permutations, gemm_data_tensor, gemm_data_transposed_tensor));
|
||||
// broadcast 3NH -> (3.B.N.S.H)
|
||||
const T* broadcast_data_src = bias_data + weights_offset;
|
||||
T* broadcast_data_dest = QKV[qkv_index] + qkv_offset;
|
||||
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
|
||||
memcpy(broadcast_data_dest, broadcast_data_src, head_size * sizeof(T));
|
||||
broadcast_data_dest += head_size;
|
||||
}
|
||||
|
||||
T* Q = reinterpret_cast<T*>(gemm_data_transposed);
|
||||
T* K = Q + (batch_size * hidden_size * sequence_length);
|
||||
T* V = K + (batch_size * hidden_size * sequence_length);
|
||||
// original transposed iteration
|
||||
// A: input (BxSxNxH) (B.)S x NH S x NH
|
||||
// B: weights (NxHx3xNxH) NH x (3.N.)H NH x H
|
||||
// C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H
|
||||
|
||||
// STEP.3: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1, 1, 1)
|
||||
math::GemmEx<float, concurrency::ThreadPool>(CblasNoTrans, // TransA = no
|
||||
CblasNoTrans, // TransB = no
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
hidden_size, // K = NH
|
||||
1.0f, // alpha
|
||||
input_data + input_offset, // A
|
||||
hidden_size, // lda = NH
|
||||
weights_data + weights_offset, // B
|
||||
3 * hidden_size, // ldb = 3NH
|
||||
1.0f, // beta
|
||||
qkv_dest + qkv_offset, // C
|
||||
head_size, // ldc
|
||||
nullptr // use single-thread
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
// STEP.2: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1, 1, 1)
|
||||
auto scratch_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * sequence_length * element_size);
|
||||
BufferUniquePtr scratch_buffer(scratch_data, BufferDeleter(allocator));
|
||||
|
||||
{
|
||||
memset(scratch_data, 0, batch_size * num_heads_ * sequence_length * sequence_length * element_size);
|
||||
auto mask_data = mask_index->template Data<int>();
|
||||
int size_each_batch = num_heads_ * sequence_length;
|
||||
T* p_current_data = reinterpret_cast<T*>(scratch_data);
|
||||
auto scratch_broadcast_data = allocator->Alloc(batch_size * sequence_length * element_size);
|
||||
BufferUniquePtr scratch_broadcast_buffer(scratch_broadcast_data, BufferDeleter(allocator));
|
||||
memset(scratch_broadcast_data, 0, batch_size * sequence_length * element_size);
|
||||
T* p_scratch_broadcast_current_data = reinterpret_cast<T*>(scratch_broadcast_data);
|
||||
for (int b_i = 0; b_i < batch_size; b_i++) {
|
||||
int mask = mask_data[b_i];
|
||||
for (int n_i = 0; n_i < size_each_batch; n_i++) {
|
||||
for (int m_i = mask; m_i < sequence_length; m_i++) {
|
||||
p_current_data[m_i] = static_cast<T>(-10000.0);
|
||||
// TODO: mask_index can be used in softmax to save some calculation.
|
||||
int mask = mask_index->template Data<int32_t>()[b_i];
|
||||
for (int m_i = mask; m_i < sequence_length; m_i++) {
|
||||
p_scratch_broadcast_current_data[m_i] = static_cast<T>(-10000.0);
|
||||
}
|
||||
p_scratch_broadcast_current_data += sequence_length;
|
||||
}
|
||||
|
||||
const int loop_len = batch_size * num_heads_;
|
||||
const float alpha = 1.0f / sqrt(static_cast<float>(head_size));
|
||||
|
||||
concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), loop_len, [&](int32_t i) {
|
||||
const int batch_index = i / num_heads_;
|
||||
// broadcast masks (B) -> (B.N.)S.S
|
||||
const T* broadcast_data_src = reinterpret_cast<T*>(scratch_broadcast_data) + batch_index * sequence_length;
|
||||
T* broadcast_data_dest = reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i;
|
||||
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
|
||||
memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T));
|
||||
broadcast_data_dest += sequence_length;
|
||||
}
|
||||
|
||||
// gemm
|
||||
|
||||
// original transposed iteration
|
||||
// A: Q (BxNxSxH) (B.N.)S x H S x H
|
||||
// B: K' (BxNxSxH) (B.N.)H x S H x S
|
||||
// C: scratch_data (BxNxSxS) (B.N.)S x S S x S
|
||||
|
||||
math::Gemm<T, concurrency::ThreadPool>(
|
||||
CblasNoTrans,
|
||||
CblasTrans,
|
||||
sequence_length,
|
||||
sequence_length,
|
||||
head_size,
|
||||
alpha,
|
||||
Q + sequence_length * head_size * i,
|
||||
K + sequence_length * head_size * i,
|
||||
1.0,
|
||||
reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i,
|
||||
nullptr);
|
||||
});
|
||||
}
|
||||
|
||||
// STEP.3: P(B, N, S, S) = Softmax(scratch)
|
||||
{
|
||||
const int N = batch_size * num_heads_ * sequence_length;
|
||||
const int D = sequence_length;
|
||||
|
||||
concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), N, [&](int j) {
|
||||
float* x = reinterpret_cast<T*>(scratch_data) + j * D;
|
||||
float* y = x;
|
||||
|
||||
for (int i = 0; i < D; i++)
|
||||
y[i] = expf(x[i]);
|
||||
|
||||
double sum = 0.0;
|
||||
|
||||
for (int i = 0; i < D; i++) {
|
||||
sum += x[i];
|
||||
}
|
||||
|
||||
if (sum == 0) {
|
||||
for (int i = 0; i < D; i++) {
|
||||
y[i] = 1.0f / (float)D;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < D; i++) {
|
||||
y[i] = x[i] / (float)sum;
|
||||
}
|
||||
p_current_data += sequence_length;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
int offset_Q = 0;
|
||||
int offset_Q_increment = sequence_length * head_size;
|
||||
int offset_scratch = 0;
|
||||
int offset_scratch_increment = sequence_length * sequence_length;
|
||||
|
||||
for (int b_i = 0; b_i < batch_size; b_i++) {
|
||||
for (int n_i = 0; n_i < num_heads_; n_i++) {
|
||||
math::Gemm<T>(
|
||||
CblasNoTrans,
|
||||
CblasTrans,
|
||||
sequence_length,
|
||||
sequence_length,
|
||||
head_size,
|
||||
1.0f / sqrt(static_cast<float>(head_size)),
|
||||
Q + offset_Q,
|
||||
K + offset_Q,
|
||||
1.0,
|
||||
reinterpret_cast<T*>(scratch_data) + offset_scratch,
|
||||
tp);
|
||||
offset_Q += offset_Q_increment;
|
||||
offset_scratch += offset_scratch_increment;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// STEP.4: P(B, N, S, S) = Softmax(scratch)
|
||||
auto p_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * sequence_length * element_size);
|
||||
BufferUniquePtr p_buffer(p_data, BufferDeleter(allocator));
|
||||
|
||||
{
|
||||
int N = batch_size * num_heads_ * sequence_length;
|
||||
int D = sequence_length;
|
||||
|
||||
Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> X_tensor(
|
||||
reinterpret_cast<T*>(scratch_data), N, D);
|
||||
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> Y_tensor(
|
||||
reinterpret_cast<T*>(p_data), N, D);
|
||||
#ifndef USE_OPENMP
|
||||
if (tp == nullptr)
|
||||
#endif
|
||||
ComputeSoftMax(Eigen::DefaultDevice(), X_tensor, Y_tensor, false);
|
||||
#ifndef USE_OPENMP
|
||||
else
|
||||
ComputeSoftMax(Eigen::ThreadPoolDevice(&tp->GetHandler(), tp->NumThreads()), X_tensor, Y_tensor, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
// STEP.5: out_tmp(B, N, S, H) = P(B, N, S, S) x V(B, N, S, H)
|
||||
// STEP.4: out_tmp(B, N, S, H) = P(B, N, S, S) x V(B, N, S, H)
|
||||
auto out_tmp_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * head_size * element_size);
|
||||
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
|
||||
{
|
||||
int offset_p = 0;
|
||||
int offset_p_increment = sequence_length * sequence_length;
|
||||
int offset_V = 0;
|
||||
int offset_V_increment = sequence_length * head_size;
|
||||
|
||||
for (int b_i = 0; b_i < batch_size; b_i++) {
|
||||
for (int n_i = 0; n_i < num_heads_; n_i++) {
|
||||
math::MatMul<T>(
|
||||
sequence_length,
|
||||
head_size,
|
||||
sequence_length,
|
||||
reinterpret_cast<T*>(p_data) + offset_p,
|
||||
V + offset_V,
|
||||
reinterpret_cast<T*>(out_tmp_data) + offset_V,
|
||||
tp);
|
||||
offset_p += offset_p_increment;
|
||||
offset_V += offset_V_increment;
|
||||
}
|
||||
concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), batch_size * num_heads_, [&](int i) {
|
||||
T* current_tmp_data = reinterpret_cast<T*>(out_tmp_data) + sequence_length * head_size * i;
|
||||
math::MatMul<T>(
|
||||
sequence_length,
|
||||
head_size,
|
||||
sequence_length,
|
||||
reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i,
|
||||
V + sequence_length * head_size * i,
|
||||
current_tmp_data,
|
||||
nullptr);
|
||||
|
||||
// transpose: out(B, S, N, H) = transpose out_tmp(B, N, S, H)
|
||||
const int batch_index = i / num_heads_;
|
||||
const int head_index = i % num_heads_;
|
||||
T* src = current_tmp_data;
|
||||
T* dest = output->template MutableData<T>() + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
|
||||
for (int j = 0; j < sequence_length; j++) {
|
||||
memcpy(dest, src, head_size * sizeof(T));
|
||||
src += head_size;
|
||||
dest += hidden_size;
|
||||
}
|
||||
}
|
||||
|
||||
// STEP.6: out(B, S, N, H) = transpose out_tmp(B, N, S, H)
|
||||
Tensor out_tmp_tensor{input->DataType(), TensorShape{batch_size, num_heads_, sequence_length, head_size}, out_tmp_data, allocator->Info()};
|
||||
Tensor output_tensor{input->DataType(), TensorShape{batch_size, sequence_length, num_heads_, head_size}, output->template MutableData<T>(), allocator->Info()};
|
||||
|
||||
static const std::vector<size_t> transpose_out_permutations{0, 2, 1, 3};
|
||||
ORT_RETURN_IF_ERROR(TransposeBase::DoTranspose(transpose_out_permutations, out_tmp_tensor, output_tensor));
|
||||
});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue