Change mask_index input of Attention op to be optional (#3459)

Change Mask Index to optional
This commit is contained in:
Tianlei Wu 2020-04-12 22:55:37 -07:00 committed by GitHub
parent 7f6e407e09
commit 54bbbb78ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 133 additions and 46 deletions

View file

@ -35,7 +35,7 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Input 1 - weights : (hidden_size, 3 * hidden_size)
// Input 2 - bias : (3 * hidden_size)
// Input 3 - mask_index : (batch_size)
// Input 3 - mask_index : (batch_size) if presented
// Output : (batch_size, sequence_length, hidden_size)
const Tensor* input = context->Input<Tensor>(0);
@ -77,13 +77,15 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
}
const Tensor* mask_index = context->Input<Tensor>(3);
const auto mask_dims = mask_index->Shape().GetDims();
if (mask_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 is expected to have 1 dimension, got ",
mask_dims.size());
}
if (static_cast<int>(mask_dims[0]) != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 3 and 0 shall have same length at dimension 0");
if (mask_index != nullptr) { // mask_index is optional
const auto mask_dims = mask_index->Shape().GetDims();
if (mask_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 is expected to have 1 dimension, got ",
mask_dims.size());
}
if (static_cast<int>(mask_dims[0]) != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 3 and 0 shall have same length at dimension 0");
}
}
return Status::OK();
@ -179,22 +181,36 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
// STEP.2: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1,
// 1, 1)
auto scratch_data =
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * sequence_length * element_size);
size_t scratch_data_bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * sequence_length * element_size;
auto scratch_data = allocator->Alloc(scratch_data_bytes);
BufferUniquePtr scratch_buffer(scratch_data, BufferDeleter(allocator));
{
auto scratch_broadcast_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * element_size);
BufferUniquePtr scratch_broadcast_buffer(scratch_broadcast_data, BufferDeleter(allocator));
memset(scratch_broadcast_data, 0, batch_size * sequence_length * element_size);
T* p_scratch_broadcast_current_data = reinterpret_cast<T*>(scratch_broadcast_data);
for (int b_i = 0; b_i < batch_size; b_i++) {
// TODO: mask_index can be used in softmax to save some calculation.
int mask = mask_index->template Data<int32_t>()[b_i];
for (int m_i = mask; m_i < sequence_length; m_i++) {
p_scratch_broadcast_current_data[m_i] = static_cast<T>(-10000.0);
size_t mask_data_bytes = 0;
if (mask_index != nullptr) {
mask_data_bytes = SafeInt<size_t>(batch_size) * sequence_length * element_size;
}
void* mask_data = nullptr;
if (mask_data_bytes > 0) {
mask_data = allocator->Alloc(mask_data_bytes);
memset(mask_data, 0, mask_data_bytes);
}
BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
if (mask_index != nullptr) {
T* p_mask = reinterpret_cast<T*>(mask_data);
for (int b_i = 0; b_i < batch_size; b_i++) {
// TODO: mask_index can be used in softmax to save some calculation.
// Convert mask_index to mask (-10000 means out of range, which will be 0 after softmax): B => BxS
int valid_length = mask_index->template Data<int32_t>()[b_i];
for (int m_i = valid_length; m_i < sequence_length; m_i++) {
p_mask[m_i] = static_cast<T>(-10000.0);
}
p_mask += sequence_length;
}
p_scratch_broadcast_current_data += sequence_length;
} else {
memset(scratch_data, 0, scratch_data_bytes);
}
const int loop_len = batch_size * num_heads_;
@ -206,12 +222,15 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const std::ptrdiff_t batch_index = i / num_heads_;
// broadcast masks (B) -> (B.N.)S.S
const T* broadcast_data_src = reinterpret_cast<T*>(scratch_broadcast_data) + batch_index * sequence_length;
T* broadcast_data_dest = reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i;
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T));
broadcast_data_dest += sequence_length;
if (mask_index != nullptr) {
const T* broadcast_data_src = reinterpret_cast<T*>(mask_data) + batch_index * sequence_length;
T* broadcast_data_dest = reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i;
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T));
broadcast_data_dest += sequence_length;
}
}
// gemm

View file

@ -40,7 +40,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Input 1 - weights : (hidden_size, 3 * hidden_size)
// Input 2 - bias : (3 * hidden_size)
// Input 3 - mask_index : (batch_size)
// Input 3 - mask_index : (batch_size) if presented
// Output : (batch_size, sequence_length, hidden_size)
const Tensor* input = context->Input<Tensor>(0);
const Tensor* weights = context->Input<Tensor>(1);
@ -88,7 +88,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);
if (!LaunchAttentionKernel(
reinterpret_cast<const CudaT*>(gemm_buffer.get()),
mask_index->template Data<int>(),
nullptr == mask_index ? nullptr : mask_index->template Data<int>(),
output->template MutableData<T>(),
batch_size,
sequence_length,

View file

@ -152,6 +152,38 @@ __device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T*
}
}
template <typename T, unsigned TPB>
__global__ void SoftmaxKernelSmall(const int sequence_length, const T* input, T* output) {
SoftmaxSmall<T, TPB>(sequence_length, sequence_length, input, output);
}
template <typename T, unsigned TPB>
__global__ void SoftmaxKernel(const int sequence_length, const T* input, T* output) {
Softmax<T, TPB>(sequence_length, sequence_length, input, output);
}
template <typename T>
bool ComputeSoftmax(
cudaStream_t stream, const int sequence_length, const int batch_size, const int num_heads,
const T* input, T* output) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (sequence_length <= 32) {
const int blockSize = 32;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(sequence_length, input, output);
} else if (sequence_length <= 128) {
const int blockSize = 128;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(sequence_length, input, output);
} else if (sequence_length == 384) {
const int blockSize = 384;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(sequence_length, input, output);
} else {
const int blockSize = 256;
SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(sequence_length, input, output);
}
return CUDA_CALL(cudaPeekAtLastError());
}
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernelSmall(const int sequence_length, const int* mask_index, const T* input, T* output) {
__shared__ int num_valid;
@ -390,8 +422,14 @@ bool QkvToContext(
}
// apply softmax and store result P to scratch2: BxNxSxS
if (!ComputeMaskedSoftmax<T>(stream, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2)) {
return false;
if (nullptr != mask_index) {
if (!ComputeMaskedSoftmax<T>(stream, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2)) {
return false;
}
} else {
if (!ComputeSoftmax<T>(stream, sequence_length, batch_size, num_heads, scratch1, scratch2)) {
return false;
}
}
// compute P*V (as V*P), and store in scratch3: BxNxSxH

View file

@ -7,20 +7,20 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
size_t GetAttentionWorkspaceSize(size_t element_size, int batchsize, int num_heads, int head_size, int sequence_length);
size_t GetAttentionWorkspaceSize(size_t element_size, int batchsize, int num_heads, int head_size, int sequence_length);
bool LaunchAttentionKernel(
const void* input, // Input tensor
const int* mask_index, // Nask index where each element is length of a sequence
void* output, // Output tensor
int batch_size, // Batch size (B)
int sequence_length, // Sequence length (S)
int num_heads, // Number of attention heads (N)
int head_size, // Hidden layer size per head (H)
void* workspace, // Temporary buffer
cublasHandle_t& cublas, // Cublas handle
const size_t element_size // Element size of input tensor
);
bool LaunchAttentionKernel(
const void* input, // Input tensor
const int* mask_index, // Mask index (length of each sequence). NULL means no mask.
void* output, // Output tensor
int batch_size, // Batch size (B)
int sequence_length, // Sequence length (S)
int num_heads, // Number of attention heads (N)
int head_size, // Hidden layer size per head (H)
void* workspace, // Temporary buffer
cublasHandle_t& cublas, // Cublas handle
const size_t element_size // Element size of input tensor
);
} // namespace cuda
} // namespace contrib

View file

@ -299,7 +299,7 @@ void RegisterBertSchemas() {
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T")
.Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
.Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
.Input(3, "mask_index", "Attention mask index with shape (batch_size)", "M")
.Input(3, "mask_index", "Attention mask index with shape (batch_size).", "M", OpSchema::Optional)
.Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
@ -2325,7 +2325,6 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
RegisterBertSchemas();
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -39,16 +39,18 @@ static void RunAttentionTest(
tester.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
tester.AddInput<MLFloat16>("weight", weights_dims, ToFloat16(weights_data));
tester.AddInput<MLFloat16>("bias", bias_dims, ToFloat16(bias_data));
tester.AddInput<int32_t>("mask_index", mask_index_dims, mask_index_data);
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
} else {
tester.AddInput<float>("input", input_dims, input_data);
tester.AddInput<float>("weight", weights_dims, weights_data);
tester.AddInput<float>("bias", bias_dims, bias_data);
tester.AddInput<int32_t>("mask_index", mask_index_dims, mask_index_data);
tester.AddOutput<float>("output", output_dims, output_data);
}
if (mask_index_data.size() > 0) { // mask index is optional.
tester.AddInput<int32_t>("mask_index", mask_index_dims, mask_index_data);
}
tester.Run();
}
}
@ -204,5 +206,34 @@ TEST(AttentionTest, AttentionMaskExceedSequence) {
batch_size, sequence_length, hidden_size, number_of_heads);
}
TEST(AttentionTest, AttentionNoMaskIndex) {
int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
// No mask_index
std::vector<int32_t> mask_index_data = {};
std::vector<float> output_data = {
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads);
}
} // namespace test
} // namespace onnxruntime