optimize CPU implementation of Attention (#2496)

This commit is contained in:
Yulong Wang 2019-12-01 14:43:38 -08:00 committed by GitHub
parent 0f57e0a49e
commit 89824b35e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -107,159 +107,184 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
const Tensor* mask_index = context->Input<Tensor>(3);
const auto dims = input->Shape().GetDims();
int batch_size = static_cast<int>(dims[0]);
int sequence_length = static_cast<int>(dims[1]);
int hidden_size = static_cast<int>(dims[2]);
int head_size = hidden_size / num_heads_;
const int batch_size = static_cast<int>(dims[0]);
const int sequence_length = static_cast<int>(dims[1]);
const int hidden_size = static_cast<int>(dims[2]);
const int head_size = hidden_size / num_heads_;
TensorShape output_shape(dims);
Tensor* output = context->Output(0, output_shape);
const size_t element_size = sizeof(T);
// Use GEMM for the fully connected layer.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int k = hidden_size;
constexpr size_t element_size = sizeof(T);
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
// STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(NH)
// STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH)
auto gemm_data = allocator->Alloc(batch_size * sequence_length * 3 * hidden_size * element_size);
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
auto Q = reinterpret_cast<T*>(gemm_data);
auto K = Q + batch_size * sequence_length * hidden_size;
auto V = K + batch_size * sequence_length * hidden_size;
auto gemm_data_mat = EigenMatrixMapRowMajor<T>(reinterpret_cast<T*>(gemm_data), m, n);
gemm_data_mat.rowwise() = ConstEigenVectorMap<T>(bias->template Data<T>(), n).transpose();
T* QKV[3] = {Q, K, V};
math::Gemm<T>(
CblasNoTrans,
CblasNoTrans,
m,
n,
k,
1.0f,
input->template Data<T>(),
weights->template Data<T>(),
1.0f,
reinterpret_cast<T*>(gemm_data),
tp);
{
const int loop_len = 3 * batch_size * num_heads_;
const auto input_data = input->template Data<T>();
const auto weights_data = weights->template Data<T>();
const auto bias_data = bias->template Data<T>();
// STEP.2: gemm_data_transposed(3, B, N, S, H) = transpose gemm_data(B, S, 3, N, H)
auto gemm_data_transposed = allocator->Alloc(batch_size * sequence_length * 3 * hidden_size * element_size);
BufferUniquePtr gemm_transposed_buffer(gemm_data_transposed, BufferDeleter(allocator));
concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), loop_len, [&](int32_t i) {
const int batch_index = (i / 3) / num_heads_;
const int head_index = (i / 3) % num_heads_;
const int qkv_index = i % 3;
Tensor gemm_data_tensor{input->DataType(), TensorShape{batch_size, sequence_length, 3, num_heads_, head_size}, gemm_data, allocator->Info()};
Tensor gemm_data_transposed_tensor{input->DataType(), TensorShape{3, batch_size, num_heads_, sequence_length, head_size}, gemm_data_transposed, allocator->Info()};
int input_offset = batch_index * sequence_length * hidden_size;
int weights_offset = qkv_index * hidden_size + head_index * head_size;
T* qkv_dest = QKV[qkv_index];
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
static const std::vector<size_t> transpose_permutations{2, 0, 3, 1, 4};
ORT_RETURN_IF_ERROR(TransposeBase::DoTranspose(transpose_permutations, gemm_data_tensor, gemm_data_transposed_tensor));
// broadcast 3NH -> (3.B.N.S.H)
const T* broadcast_data_src = bias_data + weights_offset;
T* broadcast_data_dest = QKV[qkv_index] + qkv_offset;
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
memcpy(broadcast_data_dest, broadcast_data_src, head_size * sizeof(T));
broadcast_data_dest += head_size;
}
T* Q = reinterpret_cast<T*>(gemm_data_transposed);
T* K = Q + (batch_size * hidden_size * sequence_length);
T* V = K + (batch_size * hidden_size * sequence_length);
// original transposed iteration
// A: input (BxSxNxH) (B.)S x NH S x NH
// B: weights (NxHx3xNxH) NH x (3.N.)H NH x H
// C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H
// STEP.3: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1, 1, 1)
math::GemmEx<float, concurrency::ThreadPool>(CblasNoTrans, // TransA = no
CblasNoTrans, // TransB = no
sequence_length, // M = S
head_size, // N = H
hidden_size, // K = NH
1.0f, // alpha
input_data + input_offset, // A
hidden_size, // lda = NH
weights_data + weights_offset, // B
3 * hidden_size, // ldb = 3NH
1.0f, // beta
qkv_dest + qkv_offset, // C
head_size, // ldc
nullptr // use single-thread
);
});
}
// STEP.2: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1, 1, 1)
auto scratch_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * sequence_length * element_size);
BufferUniquePtr scratch_buffer(scratch_data, BufferDeleter(allocator));
{
memset(scratch_data, 0, batch_size * num_heads_ * sequence_length * sequence_length * element_size);
auto mask_data = mask_index->template Data<int>();
int size_each_batch = num_heads_ * sequence_length;
T* p_current_data = reinterpret_cast<T*>(scratch_data);
auto scratch_broadcast_data = allocator->Alloc(batch_size * sequence_length * element_size);
BufferUniquePtr scratch_broadcast_buffer(scratch_broadcast_data, BufferDeleter(allocator));
memset(scratch_broadcast_data, 0, batch_size * sequence_length * element_size);
T* p_scratch_broadcast_current_data = reinterpret_cast<T*>(scratch_broadcast_data);
for (int b_i = 0; b_i < batch_size; b_i++) {
int mask = mask_data[b_i];
for (int n_i = 0; n_i < size_each_batch; n_i++) {
for (int m_i = mask; m_i < sequence_length; m_i++) {
p_current_data[m_i] = static_cast<T>(-10000.0);
// TODO: mask_index can be used in softmax to save some calculation.
int mask = mask_index->template Data<int32_t>()[b_i];
for (int m_i = mask; m_i < sequence_length; m_i++) {
p_scratch_broadcast_current_data[m_i] = static_cast<T>(-10000.0);
}
p_scratch_broadcast_current_data += sequence_length;
}
const int loop_len = batch_size * num_heads_;
const float alpha = 1.0f / sqrt(static_cast<float>(head_size));
concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), loop_len, [&](int32_t i) {
const int batch_index = i / num_heads_;
// broadcast masks (B) -> (B.N.)S.S
const T* broadcast_data_src = reinterpret_cast<T*>(scratch_broadcast_data) + batch_index * sequence_length;
T* broadcast_data_dest = reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i;
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T));
broadcast_data_dest += sequence_length;
}
// gemm
// original transposed iteration
// A: Q (BxNxSxH) (B.N.)S x H S x H
// B: K' (BxNxSxH) (B.N.)H x S H x S
// C: scratch_data (BxNxSxS) (B.N.)S x S S x S
math::Gemm<T, concurrency::ThreadPool>(
CblasNoTrans,
CblasTrans,
sequence_length,
sequence_length,
head_size,
alpha,
Q + sequence_length * head_size * i,
K + sequence_length * head_size * i,
1.0,
reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i,
nullptr);
});
}
// STEP.3: P(B, N, S, S) = Softmax(scratch)
{
const int N = batch_size * num_heads_ * sequence_length;
const int D = sequence_length;
concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), N, [&](int j) {
float* x = reinterpret_cast<T*>(scratch_data) + j * D;
float* y = x;
for (int i = 0; i < D; i++)
y[i] = expf(x[i]);
double sum = 0.0;
for (int i = 0; i < D; i++) {
sum += x[i];
}
if (sum == 0) {
for (int i = 0; i < D; i++) {
y[i] = 1.0f / (float)D;
}
} else {
for (int i = 0; i < D; i++) {
y[i] = x[i] / (float)sum;
}
p_current_data += sequence_length;
}
}
});
}
{
int offset_Q = 0;
int offset_Q_increment = sequence_length * head_size;
int offset_scratch = 0;
int offset_scratch_increment = sequence_length * sequence_length;
for (int b_i = 0; b_i < batch_size; b_i++) {
for (int n_i = 0; n_i < num_heads_; n_i++) {
math::Gemm<T>(
CblasNoTrans,
CblasTrans,
sequence_length,
sequence_length,
head_size,
1.0f / sqrt(static_cast<float>(head_size)),
Q + offset_Q,
K + offset_Q,
1.0,
reinterpret_cast<T*>(scratch_data) + offset_scratch,
tp);
offset_Q += offset_Q_increment;
offset_scratch += offset_scratch_increment;
}
}
}
// STEP.4: P(B, N, S, S) = Softmax(scratch)
auto p_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * sequence_length * element_size);
BufferUniquePtr p_buffer(p_data, BufferDeleter(allocator));
{
int N = batch_size * num_heads_ * sequence_length;
int D = sequence_length;
Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> X_tensor(
reinterpret_cast<T*>(scratch_data), N, D);
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> Y_tensor(
reinterpret_cast<T*>(p_data), N, D);
#ifndef USE_OPENMP
if (tp == nullptr)
#endif
ComputeSoftMax(Eigen::DefaultDevice(), X_tensor, Y_tensor, false);
#ifndef USE_OPENMP
else
ComputeSoftMax(Eigen::ThreadPoolDevice(&tp->GetHandler(), tp->NumThreads()), X_tensor, Y_tensor, false);
#endif
}
// STEP.5: out_tmp(B, N, S, H) = P(B, N, S, S) x V(B, N, S, H)
// STEP.4: out_tmp(B, N, S, H) = P(B, N, S, S) x V(B, N, S, H)
auto out_tmp_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * head_size * element_size);
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
{
int offset_p = 0;
int offset_p_increment = sequence_length * sequence_length;
int offset_V = 0;
int offset_V_increment = sequence_length * head_size;
for (int b_i = 0; b_i < batch_size; b_i++) {
for (int n_i = 0; n_i < num_heads_; n_i++) {
math::MatMul<T>(
sequence_length,
head_size,
sequence_length,
reinterpret_cast<T*>(p_data) + offset_p,
V + offset_V,
reinterpret_cast<T*>(out_tmp_data) + offset_V,
tp);
offset_p += offset_p_increment;
offset_V += offset_V_increment;
}
concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), batch_size * num_heads_, [&](int i) {
T* current_tmp_data = reinterpret_cast<T*>(out_tmp_data) + sequence_length * head_size * i;
math::MatMul<T>(
sequence_length,
head_size,
sequence_length,
reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i,
V + sequence_length * head_size * i,
current_tmp_data,
nullptr);
// transpose: out(B, S, N, H) = transpose out_tmp(B, N, S, H)
const int batch_index = i / num_heads_;
const int head_index = i % num_heads_;
T* src = current_tmp_data;
T* dest = output->template MutableData<T>() + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
for (int j = 0; j < sequence_length; j++) {
memcpy(dest, src, head_size * sizeof(T));
src += head_size;
dest += hidden_size;
}
}
// STEP.6: out(B, S, N, H) = transpose out_tmp(B, N, S, H)
Tensor out_tmp_tensor{input->DataType(), TensorShape{batch_size, num_heads_, sequence_length, head_size}, out_tmp_data, allocator->Info()};
Tensor output_tensor{input->DataType(), TensorShape{batch_size, sequence_length, num_heads_, head_size}, output->template MutableData<T>(), allocator->Info()};
static const std::vector<size_t> transpose_out_permutations{0, 2, 1, 3};
ORT_RETURN_IF_ERROR(TransposeBase::DoTranspose(transpose_out_permutations, out_tmp_tensor, output_tensor));
});
return Status::OK();
}