From 89824b35e950190da1ebeb9e3c8a0bc9ace0c86f Mon Sep 17 00:00:00 2001 From: Yulong Wang Date: Sun, 1 Dec 2019 14:43:38 -0800 Subject: [PATCH] optimize CPU implementation of Attention (#2496) --- onnxruntime/contrib_ops/cpu/bert/attention.cc | 271 ++++++++++-------- 1 file changed, 148 insertions(+), 123 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 58ccbfe956..2027dd1da0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -107,159 +107,184 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const auto dims = input->Shape().GetDims(); - int batch_size = static_cast(dims[0]); - int sequence_length = static_cast(dims[1]); - int hidden_size = static_cast(dims[2]); - int head_size = hidden_size / num_heads_; + const int batch_size = static_cast(dims[0]); + const int sequence_length = static_cast(dims[1]); + const int hidden_size = static_cast(dims[2]); + const int head_size = hidden_size / num_heads_; TensorShape output_shape(dims); Tensor* output = context->Output(0, output_shape); - const size_t element_size = sizeof(T); - - // Use GEMM for fully connection. - int m = batch_size * sequence_length; - int n = 3 * hidden_size; - int k = hidden_size; + constexpr size_t element_size = sizeof(T); AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); - - // STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(NH) + // STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH) auto gemm_data = allocator->Alloc(batch_size * sequence_length * 3 * hidden_size * element_size); BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator)); + auto Q = reinterpret_cast(gemm_data); + auto K = Q + batch_size * sequence_length * hidden_size; + auto V = K + batch_size * sequence_length * hidden_size; - auto gemm_data_mat = EigenMatrixMapRowMajor(reinterpret_cast(gemm_data), m, n); - gemm_data_mat.rowwise() = ConstEigenVectorMap(bias->template Data(), n).transpose(); + T* QKV[3] = {Q, K, V}; - math::Gemm( - CblasNoTrans, - CblasNoTrans, - m, - n, - k, - 1.0f, - input->template Data(), - weights->template Data(), - 1.0f, - reinterpret_cast(gemm_data), - tp); + { + const int loop_len = 3 * batch_size * num_heads_; + const auto input_data = input->template Data(); + const auto weights_data = weights->template Data(); + const auto bias_data = bias->template Data(); - // STEP.2: gemm_data_transposed(3, B, N, S, H) = transpose gemm_data(B, S, 3, N, H) - auto gemm_data_transposed = allocator->Alloc(batch_size * sequence_length * 3 * hidden_size * element_size); - BufferUniquePtr gemm_transposed_buffer(gemm_data_transposed, BufferDeleter(allocator)); + concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), loop_len, [&](int32_t i) { + const int batch_index = (i / 3) / num_heads_; + const int head_index = (i / 3) % num_heads_; + const int qkv_index = i % 3; - Tensor gemm_data_tensor{input->DataType(), TensorShape{batch_size, sequence_length, 3, num_heads_, head_size}, gemm_data, allocator->Info()}; - Tensor gemm_data_transposed_tensor{input->DataType(), TensorShape{3, batch_size, num_heads_, sequence_length, head_size}, gemm_data_transposed, allocator->Info()}; + int input_offset = batch_index * sequence_length * hidden_size; + int weights_offset = qkv_index * hidden_size + head_index * head_size; + T* qkv_dest = QKV[qkv_index]; + int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size); - static const std::vector transpose_permutations{2, 0, 3, 1, 4}; - ORT_RETURN_IF_ERROR(TransposeBase::DoTranspose(transpose_permutations, gemm_data_tensor, gemm_data_transposed_tensor)); + // broadcast 3NH -> (3.B.N.S.H) + const T* broadcast_data_src = bias_data + weights_offset; + T* broadcast_data_dest = QKV[qkv_index] + qkv_offset; + for (int seq_index = 0; seq_index < sequence_length; seq_index++) { + memcpy(broadcast_data_dest, broadcast_data_src, head_size * sizeof(T)); + broadcast_data_dest += head_size; + } - T* Q = reinterpret_cast(gemm_data_transposed); - T* K = Q + (batch_size * hidden_size * sequence_length); - T* V = K + (batch_size * hidden_size * sequence_length); + // original transposed iteration + // A: input (BxSxNxH) (B.)S x NH S x NH + // B: weights (NxHx3xNxH) NH x (3.N.)H NH x H + // C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H - // STEP.3: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1, 1, 1) + math::GemmEx(CblasNoTrans, // TransA = no + CblasNoTrans, // TransB = no + sequence_length, // M = S + head_size, // N = H + hidden_size, // K = NH + 1.0f, // alpha + input_data + input_offset, // A + hidden_size, // lda = NH + weights_data + weights_offset, // B + 3 * hidden_size, // ldb = 3NH + 1.0f, // beta + qkv_dest + qkv_offset, // C + head_size, // ldc + nullptr // use single-thread + ); + }); + } + + // STEP.2: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1, 1, 1) auto scratch_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * sequence_length * element_size); BufferUniquePtr scratch_buffer(scratch_data, BufferDeleter(allocator)); { - memset(scratch_data, 0, batch_size * num_heads_ * sequence_length * sequence_length * element_size); - auto mask_data = mask_index->template Data(); - int size_each_batch = num_heads_ * sequence_length; - T* p_current_data = reinterpret_cast(scratch_data); + auto scratch_broadcast_data = allocator->Alloc(batch_size * sequence_length * element_size); + BufferUniquePtr scratch_broadcast_buffer(scratch_broadcast_data, BufferDeleter(allocator)); + memset(scratch_broadcast_data, 0, batch_size * sequence_length * element_size); + T* p_scratch_broadcast_current_data = reinterpret_cast(scratch_broadcast_data); for (int b_i = 0; b_i < batch_size; b_i++) { - int mask = mask_data[b_i]; - for (int n_i = 0; n_i < size_each_batch; n_i++) { - for (int m_i = mask; m_i < sequence_length; m_i++) { - p_current_data[m_i] = static_cast(-10000.0); + // TODO: mask_index can be used in softmax to save some calculation. + int mask = mask_index->template Data()[b_i]; + for (int m_i = mask; m_i < sequence_length; m_i++) { + p_scratch_broadcast_current_data[m_i] = static_cast(-10000.0); + } + p_scratch_broadcast_current_data += sequence_length; + } + + const int loop_len = batch_size * num_heads_; + const float alpha = 1.0f / sqrt(static_cast(head_size)); + + concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), loop_len, [&](int32_t i) { + const int batch_index = i / num_heads_; + // broadcast masks (B) -> (B.N.)S.S + const T* broadcast_data_src = reinterpret_cast(scratch_broadcast_data) + batch_index * sequence_length; + T* broadcast_data_dest = reinterpret_cast(scratch_data) + sequence_length * sequence_length * i; + for (int seq_index = 0; seq_index < sequence_length; seq_index++) { + memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T)); + broadcast_data_dest += sequence_length; + } + + // gemm + + // original transposed iteration + // A: Q (BxNxSxH) (B.N.)S x H S x H + // B: K' (BxNxSxH) (B.N.)H x S H x S + // C: scratch_data (BxNxSxS) (B.N.)S x S S x S + + math::Gemm( + CblasNoTrans, + CblasTrans, + sequence_length, + sequence_length, + head_size, + alpha, + Q + sequence_length * head_size * i, + K + sequence_length * head_size * i, + 1.0, + reinterpret_cast(scratch_data) + sequence_length * sequence_length * i, + nullptr); + }); + } + + // STEP.3: P(B, N, S, S) = Softmax(scratch) + { + const int N = batch_size * num_heads_ * sequence_length; + const int D = sequence_length; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), N, [&](int j) { + float* x = reinterpret_cast(scratch_data) + j * D; + float* y = x; + + for (int i = 0; i < D; i++) + y[i] = expf(x[i]); + + double sum = 0.0; + + for (int i = 0; i < D; i++) { + sum += x[i]; + } + + if (sum == 0) { + for (int i = 0; i < D; i++) { + y[i] = 1.0f / (float)D; + } + } else { + for (int i = 0; i < D; i++) { + y[i] = x[i] / (float)sum; } - p_current_data += sequence_length; } - } + }); } - { - int offset_Q = 0; - int offset_Q_increment = sequence_length * head_size; - int offset_scratch = 0; - int offset_scratch_increment = sequence_length * sequence_length; - - for (int b_i = 0; b_i < batch_size; b_i++) { - for (int n_i = 0; n_i < num_heads_; n_i++) { - math::Gemm( - CblasNoTrans, - CblasTrans, - sequence_length, - sequence_length, - head_size, - 1.0f / sqrt(static_cast(head_size)), - Q + offset_Q, - K + offset_Q, - 1.0, - reinterpret_cast(scratch_data) + offset_scratch, - tp); - offset_Q += offset_Q_increment; - offset_scratch += offset_scratch_increment; - } - } - } - - // STEP.4: P(B, N, S, S) = Softmax(scratch) - auto p_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * sequence_length * element_size); - BufferUniquePtr p_buffer(p_data, BufferDeleter(allocator)); - - { - int N = batch_size * num_heads_ * sequence_length; - int D = sequence_length; - - Eigen::TensorMap, Eigen::Aligned> X_tensor( - reinterpret_cast(scratch_data), N, D); - Eigen::TensorMap, Eigen::Aligned> Y_tensor( - reinterpret_cast(p_data), N, D); -#ifndef USE_OPENMP - if (tp == nullptr) -#endif - ComputeSoftMax(Eigen::DefaultDevice(), X_tensor, Y_tensor, false); -#ifndef USE_OPENMP - else - ComputeSoftMax(Eigen::ThreadPoolDevice(&tp->GetHandler(), tp->NumThreads()), X_tensor, Y_tensor, false); -#endif - } - - // STEP.5: out_tmp(B, N, S, H) = P(B, N, S, S) x V(B, N, S, H) + // STEP.4: out_tmp(B, N, S, H) = P(B, N, S, S) x V(B, N, S, H) auto out_tmp_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * head_size * element_size); BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator)); - { - int offset_p = 0; - int offset_p_increment = sequence_length * sequence_length; - int offset_V = 0; - int offset_V_increment = sequence_length * head_size; - for (int b_i = 0; b_i < batch_size; b_i++) { - for (int n_i = 0; n_i < num_heads_; n_i++) { - math::MatMul( - sequence_length, - head_size, - sequence_length, - reinterpret_cast(p_data) + offset_p, - V + offset_V, - reinterpret_cast(out_tmp_data) + offset_V, - tp); - offset_p += offset_p_increment; - offset_V += offset_V_increment; - } + concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), batch_size * num_heads_, [&](int i) { + T* current_tmp_data = reinterpret_cast(out_tmp_data) + sequence_length * head_size * i; + math::MatMul( + sequence_length, + head_size, + sequence_length, + reinterpret_cast(scratch_data) + sequence_length * sequence_length * i, + V + sequence_length * head_size * i, + current_tmp_data, + nullptr); + + // transpose: out(B, S, N, H) = transpose out_tmp(B, N, S, H) + const int batch_index = i / num_heads_; + const int head_index = i % num_heads_; + T* src = current_tmp_data; + T* dest = output->template MutableData() + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + for (int j = 0; j < sequence_length; j++) { + memcpy(dest, src, head_size * sizeof(T)); + src += head_size; + dest += hidden_size; } - } - - // STEP.6: out(B, S, N, H) = transpose out_tmp(B, N, S, H) - Tensor out_tmp_tensor{input->DataType(), TensorShape{batch_size, num_heads_, sequence_length, head_size}, out_tmp_data, allocator->Info()}; - Tensor output_tensor{input->DataType(), TensorShape{batch_size, sequence_length, num_heads_, head_size}, output->template MutableData(), allocator->Info()}; - - static const std::vector transpose_out_permutations{0, 2, 1, 3}; - ORT_RETURN_IF_ERROR(TransposeBase::DoTranspose(transpose_out_permutations, out_tmp_tensor, output_tensor)); + }); return Status::OK(); }